Merge pull request #23 from Monadical-SAS/whisper-jax-gokul

Whisper jax gokul minor refactor
This commit is contained in:
projects-g
2023-07-10 22:41:39 +05:30
committed by GitHub
10 changed files with 50 additions and 22 deletions

2
.gitignore vendored
View File

@@ -167,7 +167,7 @@ transcript_timestamps.txt
*.pkl
transcript_*.txt
test_*.txt
*.png
wordcloud*.png
*.ini
test_samples/
*.wav

15
scripts/clear_artefacts.sh Executable file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
# Delete generated artefacts (transcripts, pickles, HTML reports and
# word-cloud images) found anywhere below $directory.

# Root directory that is searched recursively.
directory="."

# Filename patterns (as understood by `find -name`) of the files to delete.
text_file_pattern="transcript_*.txt"
pickle_file_pattern="*.pkl"
# NOTE: matches every .html file under $directory, not only generated ones.
html_file_pattern="*.html"
png_file_pattern="wordcloud*.png"

# Remove every regular file whose name matches one of the patterns.
for pattern in \
    "$text_file_pattern" \
    "$pickle_file_pattern" \
    "$html_file_pattern" \
    "$png_file_pattern"; do
    find "$directory" -type f -name "$pattern" -delete
done

View File

@@ -26,7 +26,7 @@ pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
# Update to latest version
pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
pip install -r requirements.txt
pip install -r ../requirements.txt
# download spacy models
spacy download en_core_web_sm

View File

@@ -15,7 +15,7 @@ from aiortc.contrib.media import MediaRelay
from av import AudioFifo
from whisper_jax import FlaxWhisperPipline
from reflector.utils.server_utils import run_in_executor
from utils.server_utils import run_in_executor
logger = logging.getLogger(__name__)

View File

@@ -47,8 +47,6 @@ total_bytes_handled = 0
executor = ThreadPoolExecutor()
frame_lock = threading.Lock()
file_lock = threading.Lock()
total_bytes_handled_lock = threading.Lock()
def channel_log(channel, t, message):

View File

@@ -11,9 +11,10 @@ import ast
import stamina
from aiortc import (RTCPeerConnection, RTCSessionDescription)
from aiortc.contrib.media import (MediaPlayer, MediaRelay)
from utils.server_utils import Mutex
logger = logging.getLogger("pc")
file_lock = threading.Lock()
file_lock = Mutex(open("test_sm_6.txt", "a"))
config = configparser.ConfigParser()
config.read('config.ini')
@@ -24,8 +25,7 @@ class StreamClient:
signaling,
url="http://127.0.0.1:1250",
play_from=None,
ping_pong=False,
audio_stream=None
ping_pong=False
):
self.signaling = signaling
self.server_url = url
@@ -36,7 +36,6 @@ class StreamClient:
self.pc = RTCPeerConnection()
self.loop = asyncio.get_event_loop()
# self.loop = asyncio.new_event_loop()
self.relay = None
self.pcs = set()
self.time_start = None
@@ -68,7 +67,6 @@ class StreamClient:
channel.send(message)
def current_stamp(self):
if self.time_start is None:
self.time_start = time.time()
return 0
@@ -94,9 +92,7 @@ class StreamClient:
@pc.on("track")
def on_track(track):
print("Sending %s" % track.kind)
# Trials
self.pc.addTrack(track)
# self.pc.addTrack(self.microphone)
@track.on("ended")
async def on_ended():
@@ -104,7 +100,6 @@ class StreamClient:
self.pc.addTrack(audio)
# DataChannel
channel = pc.createDataChannel("data-channel")
self.channel_log(channel, "-", "created by local party")
@@ -155,14 +150,12 @@ class StreamClient:
while True:
msg = await self.queue.get()
msg = ast.literal_eval(msg)
with file_lock:
with open("test_sm_6.txt", "a") as f:
f.write(msg["text"])
with file_lock.lock() as file:
file.write(msg["text"])
yield msg["text"]
self.queue.task_done()
async def start(self):
print("Starting stream client")
coro = self.run_offer(self.pc, self.signaling)
task = asyncio.create_task(coro)
await task

View File

@@ -1,7 +1,25 @@
import asyncio
import contextlib
from functools import partial
from threading import Lock
from typing import ContextManager, Generic, Iterator, TypeVar
def run_in_executor(func, *args, executor=None, **kwargs):
    """
    Schedule a callable on an executor through the current event loop.

    :param func: callable to run in the executor
    :param args: positional arguments forwarded to ``func``
    :param executor: executor passed to ``loop.run_in_executor``; ``None``
        lets the event loop use its default executor
    :param kwargs: keyword arguments forwarded to ``func``
    :return: awaitable future that resolves to the return value of ``func``
    """
    # loop.run_in_executor() only forwards positional arguments, so bind
    # everything (keyword arguments included) into a single callable first.
    callback = partial(func, *args, **kwargs)
    loop = asyncio.get_event_loop()
    return loop.run_in_executor(executor, callback)
T = TypeVar("T")


class Mutex(Generic[T]):
    """
    Pair a value with the lock that guards access to it.

    The wrapped value is only handed out by :meth:`lock`, while the
    internal ``threading.Lock`` is held.
    """

    def __init__(self, value: T):
        """
        :param value: object to protect with the lock
        """
        self.__value = value
        self.__lock = Lock()

    @contextlib.contextmanager
    def lock(self) -> Iterator[T]:
        """
        Hold the lock for the duration of a ``with`` block.

        The lock is released when the block exits, including when it exits
        with an exception. ``threading.Lock`` is not reentrant, so entering
        ``lock()`` again from inside the block on the same thread blocks
        forever.

        :return: context manager that yields the wrapped value
        """
        with self.__lock:
            yield self.__value

View File

@@ -19,6 +19,9 @@ def preprocess_sentence(sentence):
return ' '.join(tokens)
def compute_similarity(sent1, sent2):
"""
Compute the similarity
"""
tfidf_vectorizer = TfidfVectorizer()
if sent1 is not None and sent2 is not None:
tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])

View File

@@ -19,6 +19,7 @@ spacy_stopwords = en.Defaults.stop_words
STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
def create_wordcloud(timestamp, real_time=False):
"""
Create a basic word cloud visualization of transcribed text
@@ -30,7 +31,7 @@ def create_wordcloud(timestamp, real_time=False):
else:
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
with open(filename, "r") as f:
with open("./artefacts/" + filename, "r") as f:
transcription_text = f.read()
# python_mask = np.array(PIL.Image.open("download1.png"))
@@ -199,6 +200,6 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
transform=st.Scalers.dense_rank
)
if real_time:
open('./real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
else:
open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
open('./artefacts/scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)

View File

@@ -137,10 +137,10 @@ def main():
for chunk in whisper_result["chunks"]:
transcript_text += chunk["text"]
with open("transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
transcript_file.write(transcript_text)
with open("transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps:
with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps:
transcript_file_timestamps.write(str(whisper_result))