mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
Merge pull request #23 from Monadical-SAS/whisper-jax-gokul
Whisper jax gokul minor refactor
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -167,7 +167,7 @@ transcript_timestamps.txt
|
||||
*.pkl
|
||||
transcript_*.txt
|
||||
test_*.txt
|
||||
*.png
|
||||
wordcloud*.png
|
||||
*.ini
|
||||
test_samples/
|
||||
*.wav
|
||||
|
||||
15
scripts/clear_artefacts.sh
Executable file
15
scripts/clear_artefacts.sh
Executable file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
#
# Delete generated artefact files found anywhere under $directory.
# Only regular files whose names match one of the patterns below are removed.

# Root directory that is searched recursively for artefacts
directory="."

# Filename patterns (shell globs, matched by `find -name`) of files to delete
text_file_pattern="transcript_*.txt"
pickle_file_pattern="*.pkl"
# NOTE(review): this matches every .html file under $directory, not only
# generated ones -- confirm that is intended before running in a shared tree.
html_file_pattern="*.html"
png_file_pattern="wordcloud*.png"

# Walk the tree once and delete any regular file matching one of the patterns.
# The parentheses group the OR-ed -name tests so that -type f and -delete
# apply to all of them.
find "$directory" -type f \( \
    -name "$text_file_pattern" -o \
    -name "$pickle_file_pattern" -o \
    -name "$html_file_pattern" -o \
    -name "$png_file_pattern" \
\) -delete
|
||||
@@ -26,7 +26,7 @@ pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
|
||||
# Update to latest version
|
||||
pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
|
||||
|
||||
pip install -r requirements.txt
|
||||
pip install -r ../requirements.txt
|
||||
|
||||
# download spacy models
|
||||
spacy download en_core_web_sm
|
||||
@@ -15,7 +15,7 @@ from aiortc.contrib.media import MediaRelay
|
||||
from av import AudioFifo
|
||||
from whisper_jax import FlaxWhisperPipline
|
||||
|
||||
from reflector.utils.server_utils import run_in_executor
|
||||
from utils.server_utils import run_in_executor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -47,8 +47,6 @@ total_bytes_handled = 0
|
||||
executor = ThreadPoolExecutor()
|
||||
|
||||
frame_lock = threading.Lock()
|
||||
file_lock = threading.Lock()
|
||||
|
||||
total_bytes_handled_lock = threading.Lock()
|
||||
|
||||
def channel_log(channel, t, message):
|
||||
|
||||
@@ -11,9 +11,10 @@ import ast
|
||||
import stamina
|
||||
from aiortc import (RTCPeerConnection, RTCSessionDescription)
|
||||
from aiortc.contrib.media import (MediaPlayer, MediaRelay)
|
||||
from utils.server_utils import Mutex
|
||||
|
||||
logger = logging.getLogger("pc")
|
||||
file_lock = threading.Lock()
|
||||
file_lock = Mutex(open("test_sm_6.txt", "a"))
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read('config.ini')
|
||||
@@ -24,8 +25,7 @@ class StreamClient:
|
||||
signaling,
|
||||
url="http://127.0.0.1:1250",
|
||||
play_from=None,
|
||||
ping_pong=False,
|
||||
audio_stream=None
|
||||
ping_pong=False
|
||||
):
|
||||
self.signaling = signaling
|
||||
self.server_url = url
|
||||
@@ -36,7 +36,6 @@ class StreamClient:
|
||||
self.pc = RTCPeerConnection()
|
||||
|
||||
self.loop = asyncio.get_event_loop()
|
||||
# self.loop = asyncio.new_event_loop()
|
||||
self.relay = None
|
||||
self.pcs = set()
|
||||
self.time_start = None
|
||||
@@ -68,7 +67,6 @@ class StreamClient:
|
||||
channel.send(message)
|
||||
|
||||
def current_stamp(self):
|
||||
|
||||
if self.time_start is None:
|
||||
self.time_start = time.time()
|
||||
return 0
|
||||
@@ -94,9 +92,7 @@ class StreamClient:
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
print("Sending %s" % track.kind)
|
||||
# Trials
|
||||
self.pc.addTrack(track)
|
||||
# self.pc.addTrack(self.microphone)
|
||||
|
||||
@track.on("ended")
|
||||
async def on_ended():
|
||||
@@ -104,7 +100,6 @@ class StreamClient:
|
||||
|
||||
self.pc.addTrack(audio)
|
||||
|
||||
# DataChannel
|
||||
channel = pc.createDataChannel("data-channel")
|
||||
self.channel_log(channel, "-", "created by local party")
|
||||
|
||||
@@ -155,14 +150,12 @@ class StreamClient:
|
||||
while True:
|
||||
msg = await self.queue.get()
|
||||
msg = ast.literal_eval(msg)
|
||||
with file_lock:
|
||||
with open("test_sm_6.txt", "a") as f:
|
||||
f.write(msg["text"])
|
||||
with file_lock.lock() as file:
|
||||
file.write(msg["text"])
|
||||
yield msg["text"]
|
||||
self.queue.task_done()
|
||||
|
||||
async def start(self):
|
||||
print("Starting stream client")
|
||||
coro = self.run_offer(self.pc, self.signaling)
|
||||
task = asyncio.create_task(coro)
|
||||
await task
|
||||
|
||||
@@ -1,7 +1,25 @@
|
||||
import asyncio
|
||||
from functools import partial
|
||||
import contextlib
|
||||
from threading import Lock
|
||||
from typing import ContextManager, Generic, TypeVar
|
||||
|
||||
def run_in_executor(func, *args, executor=None, **kwargs):
    """Schedule ``func(*args, **kwargs)`` on an executor of the current event loop.

    Args:
        func: Callable to run outside the event loop thread.
        *args: Positional arguments forwarded to ``func``.
        executor: Executor handed to ``loop.run_in_executor``; ``None`` is
            passed through unchanged, leaving the choice to the event loop.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The awaitable returned by ``loop.run_in_executor``.
    """
    # loop.run_in_executor only forwards positional arguments, so bind
    # everything (including keyword arguments) up front.
    callback = partial(func, *args, **kwargs)
    loop = asyncio.get_event_loop()
    return loop.run_in_executor(executor, callback)
|
||||
|
||||
|
||||
T = TypeVar("T")


class Mutex(Generic[T]):
    """Pair a value with a lock so the value is only handed out while locked.

    The wrapped value is reachable through :meth:`lock`, which holds the
    internal ``threading.Lock`` for the duration of the ``with`` block.
    """

    def __init__(self, value: T):
        """Store ``value`` and create the lock that guards access to it."""
        self.__value = value
        self.__lock = Lock()

    # NOTE(review): the undecorated generator yields T, so type checkers may
    # expect ``Iterator[T]`` here; the annotation is kept as-is to avoid
    # adding a new import.
    @contextlib.contextmanager
    def lock(self) -> ContextManager[T]:
        """Acquire the lock and yield the wrapped value.

        The lock is released when the ``with`` block exits, including when
        the block raises.
        """
        with self.__lock:
            yield self.__value
|
||||
@@ -19,6 +19,9 @@ def preprocess_sentence(sentence):
|
||||
return ' '.join(tokens)
|
||||
|
||||
def compute_similarity(sent1, sent2):
|
||||
"""
|
||||
Compute the similarity
|
||||
"""
|
||||
tfidf_vectorizer = TfidfVectorizer()
|
||||
if sent1 is not None and sent2 is not None:
|
||||
tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
|
||||
|
||||
@@ -19,6 +19,7 @@ spacy_stopwords = en.Defaults.stop_words
|
||||
|
||||
STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
|
||||
|
||||
|
||||
def create_wordcloud(timestamp, real_time=False):
|
||||
"""
|
||||
Create a basic word cloud visualization of transcribed text
|
||||
@@ -30,7 +31,7 @@ def create_wordcloud(timestamp, real_time=False):
|
||||
else:
|
||||
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||
|
||||
with open(filename, "r") as f:
|
||||
with open("./artefacts/" + filename, "r") as f:
|
||||
transcription_text = f.read()
|
||||
|
||||
# python_mask = np.array(PIL.Image.open("download1.png"))
|
||||
@@ -199,6 +200,6 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
||||
transform=st.Scalers.dense_rank
|
||||
)
|
||||
if real_time:
|
||||
open('./real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
|
||||
open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
|
||||
else:
|
||||
open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
|
||||
open('./artefacts/scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
|
||||
@@ -137,10 +137,10 @@ def main():
|
||||
for chunk in whisper_result["chunks"]:
|
||||
transcript_text += chunk["text"]
|
||||
|
||||
with open("transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
|
||||
with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
|
||||
transcript_file.write(transcript_text)
|
||||
|
||||
with open("transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps:
|
||||
with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps:
|
||||
transcript_file_timestamps.write(str(whisper_result))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user