Merge pull request #6 from Monadical-SAS/whisper-jax-gokul

Push whisper-jax real time trial code : https://github.com/Monadical-SAS/reflector/issues/3
This commit is contained in:
projects-g
2023-06-09 16:04:43 +05:30
committed by GitHub
3 changed files with 93 additions and 3 deletions

8
.gitignore vendored
View File

@@ -160,6 +160,8 @@ cython_debug/
#.idea/
*.mp4
*.txt
config.ini
test_samples/
summary.txt
transcript.txt
*.ini
test_samples/
*.wav

4
requirements.txt Normal file
View File

@@ -0,0 +1,4 @@
pyaudio==0.2.13
keyboard==0.13.5
pynput==1.7.6
wave==0.0.2

84
whisjax_realtime_trial.py Normal file
View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python3
import configparser
import pyaudio
from whisper_jax import FlaxWhisperPipline
from pynput import keyboard
import jax.numpy as jnp
import wave
# Runtime settings come from ``config.ini``, resolved relative to the current
# working directory. ConfigParser.read() ignores a missing file, so in that
# case the lookups below are what fail (with KeyError).
config = configparser.ConfigParser()
config.read('config.ini')
# Whisper checkpoint size; main() appends it to "openai/whisper-" to build the
# model name handed to the pipeline.
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
# Read from the config but not referenced anywhere else in the visible code.
OPENAI_APIKEY = config['DEFAULT']["OPENAI_APIKEY"]
# Audio capture parameters used by main() when opening and reading the stream.
FRAMES_PER_BUFFER = 8000  # frames requested per stream.read() call
FORMAT = pyaudio.paInt16  # sample format passed to PyAudio.open()
CHANNELS = 1  # channel count for both the input stream and the WAV file
RATE = 44100  # sampling rate for both the input stream and the WAV file
# Approximate length of audio gathered before each transcription; main()
# rounds RATE / FRAMES_PER_BUFFER * RECORD_SECONDS down to whole buffers.
RECORD_SECONDS = 10
def main():
    """Record microphone audio in fixed-length segments and transcribe them.

    Loads the Whisper-JAX pipeline for ``WHISPER_MODEL_SIZE``, then repeatedly
    captures roughly ``RECORD_SECONDS`` of audio from the input device, writes
    it to ``temp_audio.wav``, runs the pipeline on that file, prints the
    recognised text and appends it to ``transcript.txt``. Pressing Esc stops
    the loop once the segment currently being recorded has been processed.

    Any ``Exception`` raised while recording or transcribing is printed and
    ends the loop. The audio stream, the PyAudio instance, the keyboard
    listener and the transcript file are always released on the way out, and
    any transcript text still buffered at that point is written first.
    """
    # Function-scope import: the stop flag is the only user of this module.
    import threading

    TEMP_AUDIO_FILE = "temp_audio.wav"
    # Whole buffers per segment; loop-invariant, so computed once.
    buffers_per_segment = int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)

    # Build the pipeline before opening the microphone so the input device is
    # not left open and unread while the model is being loaded.
    pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
                                  dtype=jnp.float16,
                                  batch_size=16)

    # Set from the keyboard listener's thread, polled by the recording loop.
    stop_requested = threading.Event()

    def on_press(key):
        """Request shutdown when the Esc key is pressed."""
        if key == keyboard.Key.esc:
            stop_requested.set()

    p = pyaudio.PyAudio()
    stream = None
    listener = None
    try:
        sample_width = p.get_sample_size(FORMAT)
        stream = p.open(
            format=FORMAT,
            channels=CHANNELS,
            rate=RATE,
            input=True,
            frames_per_buffer=FRAMES_PER_BUFFER,
        )
        listener = keyboard.Listener(on_press=on_press)
        listener.start()

        with open("transcript.txt", "w", encoding="utf-8") as transcript_file:
            transcription = ""
            try:
                while not stop_requested.is_set():
                    # Overflowed input is dropped rather than raised so a slow
                    # transcription step does not abort the recording.
                    frames = [
                        stream.read(FRAMES_PER_BUFFER,
                                    exception_on_overflow=False)
                        for _ in range(buffers_per_segment)
                    ]
                    print("Collected Input", len(frames))

                    # The context manager closes the file even if a write
                    # fails part-way through.
                    with wave.open(TEMP_AUDIO_FILE, 'wb') as wf:
                        wf.setnchannels(CHANNELS)
                        wf.setsampwidth(sample_width)
                        wf.setframerate(RATE)
                        wf.writeframes(b''.join(frames))

                    whisper_result = pipeline(TEMP_AUDIO_FILE,
                                              return_timestamps=True)
                    print(whisper_result['text'])
                    transcription += whisper_result['text']

                    # Very short results are held back and emitted together
                    # with the next segment's text.
                    if len(transcription) > 10:
                        transcript_file.write(transcription + "\n")
                        # Keep the on-disk transcript current while running.
                        transcript_file.flush()
                        transcription = ""
            except Exception as e:
                # Top-level boundary: report the failure and stop recording.
                print(e)
            finally:
                # Do not lose a short tail that never crossed the threshold.
                if transcription:
                    transcript_file.write(transcription + "\n")
    finally:
        try:
            if listener is not None:
                listener.stop()
            if stream is not None:
                stream.stop_stream()
                stream.close()
        finally:
            p.terminate()
# Start recording only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()