#!/usr/bin/env python3
"""Live microphone transcription with Whisper-JAX.

Captures audio from the default input device in RECORD_SECONDS chunks,
writes each chunk to a temporary WAV file, transcribes it with a Flax
Whisper pipeline, and appends the text to transcript.txt. Recording
stops when the user presses <Esc>.
"""

import configparser
import wave

import jax.numpy as jnp
import pyaudio
from pynput import keyboard
from whisper_jax import FlaxWhisperPipline

config = configparser.ConfigParser()
config.read('config.ini')

# Raises KeyError if config.ini is missing or lacks these keys -- same
# strictness as the original configuration handling.
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
# NOTE(review): OPENAI_APIKEY is read but never used in this script --
# confirm whether it is needed for a follow-up feature before removing.
OPENAI_APIKEY = config['DEFAULT']["OPENAI_APIKEY"]

FRAMES_PER_BUFFER = 8000            # samples pulled per stream read
FORMAT = pyaudio.paInt16            # 16-bit signed PCM
CHANNELS = 1                        # mono capture
RATE = 44100                        # sample rate in Hz
RECORD_SECONDS = 10                 # length of each transcription chunk
TEMP_AUDIO_FILE = "temp_audio.wav"  # scratch file handed to the pipeline


def _write_wav(path, frames, sample_width):
    """Persist raw PCM *frames* to *path* as a mono WAV file."""
    with wave.open(path, 'wb') as wf:
        wf.setnchannels(CHANNELS)
        wf.setsampwidth(sample_width)
        wf.setframerate(RATE)
        wf.writeframes(b''.join(frames))


def main():
    """Run the record/transcribe loop until the user presses <Esc>."""
    p = pyaudio.PyAudio()
    stream = p.open(
        format=FORMAT,
        channels=CHANNELS,
        rate=RATE,
        input=True,
        frames_per_buffer=FRAMES_PER_BUFFER,
    )

    pipeline = FlaxWhisperPipline(
        "openai/whisper-" + WHISPER_MODEL_SIZE,
        dtype=jnp.float16,
        batch_size=16,
    )

    proceed = True

    def on_press(key):
        # Fix: the original leaked a module-level global; a closure
        # variable is enough since on_press is defined inside main().
        nonlocal proceed
        if key == keyboard.Key.esc:
            proceed = False

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    transcription = ""
    reads_per_chunk = int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)
    sample_width = p.get_sample_size(FORMAT)

    try:
        with open("transcript.txt", "w+") as transcript_file:
            while proceed:
                try:
                    frames = [
                        stream.read(FRAMES_PER_BUFFER,
                                    exception_on_overflow=False)
                        for _ in range(reads_per_chunk)
                    ]
                    print("Collected Input", len(frames))

                    _write_wav(TEMP_AUDIO_FILE, frames, sample_width)

                    whisper_result = pipeline(TEMP_AUDIO_FILE,
                                              return_timestamps=True)
                    print(whisper_result['text'])

                    # Buffer text and flush once it exceeds 10 characters,
                    # one flushed chunk per line, as the original did.
                    transcription += whisper_result['text']
                    if len(transcription) > 10:
                        transcript_file.write(transcription + "\n")
                        transcription = ""
                except Exception as e:
                    # Best-effort loop: report and stop on any capture or
                    # transcription failure (original behavior).
                    print(e)
                    break

            # Fix: the original silently dropped any trailing text of
            # 10 characters or fewer; flush it before exiting.
            if transcription:
                transcript_file.write(transcription + "\n")
    finally:
        # Fix: the original leaked the stream, the PyAudio instance, the
        # keyboard listener, and the transcript file on every exit path.
        listener.stop()
        stream.stop_stream()
        stream.close()
        p.terminate()


if __name__ == "__main__":
    main()