diff --git a/server_multithreaded.py b/server_multithreaded.py index 39d47ae2..e656a892 100644 --- a/server_multithreaded.py +++ b/server_multithreaded.py @@ -1,23 +1,26 @@ import asyncio +import configparser import datetime import io import json import logging import os -import sys import threading import uuid import wave from concurrent.futures import ThreadPoolExecutor -from sortedcontainers import SortedDict -import configparser + import jax.numpy as jnp -from aiohttp import web +from aiohttp import web + from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import (MediaRelay) from av import AudioFifo +from sortedcontainers import SortedDict from whisper_jax import FlaxWhisperPipline + +from utils.server_utils import Mutex + ROOT = os.path.dirname(__file__) config = configparser.ConfigParser() @@ -46,14 +49,14 @@ total_bytes_handled = 0 executor = ThreadPoolExecutor() -frame_lock = threading.Lock() -total_bytes_handled_lock = threading.Lock() +frame_lock = Mutex(audio_buffer) + def channel_log(channel, t, message): print("channel(%s) %s %s" % (channel.label, t, message)) + def thread_queue_channel_send(): - print("M-thread created") loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: @@ -62,25 +65,15 @@ def thread_queue_channel_send(): if message: del sorted_message_queue[least_time] data_channel.send(message) - print("M-thread sent message to client") - with total_bytes_handled_lock: - print("Bytes handled :", total_bytes_handled, " Time : ", datetime.datetime.now() - start_time) except Exception as e: print("Exception", str(e)) pass loop.run_forever() -# async def channel_send(channel, message): -# channel_log(channel, ">", message) -# if channel and message: -# channel.send(message) -def get_transcription(local_thread_id): - # Block 1 - print("T-thread -> ", str(local_thread_id) , "created") - global frame_lock +def get_transcription(): while True: - with frame_lock: + with frame_lock.lock() as 
audio_buffer: frames = audio_buffer.read_many(CHUNK_SIZE * 960, partial=False) if not frames: transcribe = False @@ -89,7 +82,6 @@ def get_transcription(local_thread_id): if transcribe: try: - print("T-thread ", str(local_thread_id), "is transcribing") sorted_message_queue[frames[0].time] = None out_file = io.BytesIO() wf = wave.open(out_file, "wb") @@ -102,10 +94,6 @@ def get_transcription(local_thread_id): wf.close() whisper_result = pipeline(out_file.getvalue()) - - global total_bytes_handled - with total_bytes_handled_lock: - total_bytes_handled += sys.getsizeof(wf) item = {'text': whisper_result["text"], 'start_time': str(frames[0].time), 'time': str(datetime.datetime.now()) @@ -115,6 +103,7 @@ def get_transcription(local_thread_id): except Exception as e: print("Exception -> ", str(e)) + class AudioStreamTrack(MediaStreamTrack): """ A video stream track that transforms frames from an another track. @@ -127,7 +116,6 @@ class AudioStreamTrack(MediaStreamTrack): self.track = track async def recv(self): - # print("Awaiting track in server") frame = await self.track.recv() audio_buffer.write(frame) return frame @@ -136,7 +124,7 @@ class AudioStreamTrack(MediaStreamTrack): def start_messaging_thread(): message_thread = threading.Thread(target=thread_queue_channel_send) message_thread.start() - # message_thread.join() + def start_transcription_thread(max_threads): t_threads = [] @@ -145,12 +133,9 @@ def start_transcription_thread(max_threads): t_threads.append(t_thread) t_thread.start() - # for t_thread in t_threads: - # t_thread.join() async def offer(request): params = await request.json() - print("Request received") offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) pc = RTCPeerConnection() @@ -185,11 +170,8 @@ async def offer(request): @pc.on("track") def on_track(track): - print("Track %s received", track.kind) log_info("Track %s received", track.kind) - # Trials to listen to the correct track 
pc.addTrack(AudioStreamTrack(relay.subscribe(track))) - # pc.addTrack(AudioStreamTrack(track)) # handle offer await pc.setRemoteDescription(offer) @@ -197,7 +179,6 @@ async def offer(request): # send answer answer = await pc.createAnswer() await pc.setLocalDescription(answer) - print("Response sent") return web.Response( content_type="application/json", text=json.dumps( @@ -207,7 +188,6 @@ async def offer(request): async def on_shutdown(app): - # close peer connections coros = [pc.close() for pc in pcs] await asyncio.gather(*coros) pcs.clear() @@ -221,5 +201,3 @@ if __name__ == "__main__": web.run_app( app, access_log=None, host="127.0.0.1", port=1250 ) - - diff --git a/stream_client.py b/stream_client.py index 508f3092..d660f079 100644 --- a/stream_client.py +++ b/stream_client.py @@ -1,16 +1,17 @@ +import ast import asyncio +import configparser import logging import time import uuid -import threading -import configparser + import httpx import pyaudio import requests -import ast import stamina from aiortc import (RTCPeerConnection, RTCSessionDescription) from aiortc.contrib.media import (MediaPlayer, MediaRelay) + from utils.server_utils import Mutex logger = logging.getLogger("pc") @@ -19,13 +20,14 @@ file_lock = Mutex(open("test_sm_6.txt", "a")) config = configparser.ConfigParser() config.read('config.ini') + class StreamClient: def __init__( - self, - signaling, - url="http://127.0.0.1:1250", - play_from=None, - ping_pong=False + self, + signaling, + url="http://127.0.0.1:1250", + play_from=None, + ping_pong=False ): self.signaling = signaling self.server_url = url @@ -103,7 +105,6 @@ class StreamClient: channel = pc.createDataChannel("data-channel") self.channel_log(channel, "-", "created by local party") - async def send_pings(): while True: self.channel_send(channel, "ping %d" % self.current_stamp()) diff --git a/utils/file_utilities.py b/utils/file_utilities.py index 6a4a4e40..e6c1fe9a 100644 --- a/utils/file_utilities.py +++ b/utils/file_utilities.py @@ 
-1,6 +1,7 @@ +import configparser + import boto3 import botocore -import configparser from loguru import logger config = configparser.ConfigParser() @@ -12,6 +13,7 @@ s3 = boto3.client('s3', aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"], aws_secret_access_key=config["DEFAULT"]["AWS_SECRET_KEY"]) + def upload_files(files_to_upload): """ Upload a list of files to the configured S3 bucket @@ -45,6 +47,7 @@ def download_files(files_to_download): if __name__ == "__main__": import sys + if sys.argv[1] == "download": download_files([sys.argv[2]]) elif sys.argv[1] == "upload": diff --git a/utils/server_utils.py b/utils/server_utils.py index 2e46e094..5236a67d 100644 --- a/utils/server_utils.py +++ b/utils/server_utils.py @@ -1,9 +1,10 @@ import asyncio -from functools import partial import contextlib +from functools import partial from threading import Lock from typing import ContextManager, Generic, TypeVar + def run_in_executor(func, *args, executor=None, **kwargs): callback = partial(func, *args, **kwargs) loop = asyncio.get_event_loop() @@ -11,6 +12,8 @@ def run_in_executor(func, *args, executor=None, **kwargs): T = TypeVar("T") + + class Mutex(Generic[T]): def __init__(self, value: T): self.__value = value @@ -22,4 +25,4 @@ class Mutex(Generic[T]): try: yield self.__value finally: - self.__lock.release() \ No newline at end of file + self.__lock.release() diff --git a/utils/text_utilities.py b/utils/text_utilities.py index dbd6c6cc..0515ba36 100644 --- a/utils/text_utilities.py +++ b/utils/text_utilities.py @@ -1,23 +1,27 @@ -import torch import configparser + import nltk -from transformers import BartTokenizer, BartForConditionalGeneration +import torch from loguru import logger from nltk.corpus import stopwords -from sklearn.feature_extraction.text import TfidfVectorizer from nltk.tokenize import word_tokenize +from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity +from transformers import 
BartTokenizer, BartForConditionalGeneration + nltk.download('punkt', quiet=True) config = configparser.ConfigParser() config.read('config.ini') + def preprocess_sentence(sentence): stop_words = set(stopwords.words('english')) tokens = word_tokenize(sentence.lower()) tokens = [token for token in tokens if token.isalnum() and token not in stop_words] return ' '.join(tokens) + def compute_similarity(sent1, sent2): """ Compute the similarity @@ -28,6 +32,7 @@ def compute_similarity(sent1, sent2): return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0] return 0.0 + def remove_almost_alike_sentences(sentences, threshold=0.7): num_sentences = len(sentences) removed_indices = set() @@ -55,12 +60,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7): filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices] return filtered_sentences + def remove_outright_duplicate_sentences_from_chunk(chunk): chunk_text = chunk["text"] sentences = nltk.sent_tokenize(chunk_text) nonduplicate_sentences = list(dict.fromkeys(sentences)) return nonduplicate_sentences + def remove_whisper_repetitive_hallucination(nonduplicate_sentences): chunk_sentences = [] @@ -80,6 +87,7 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences): chunk_sentences.append(temp_result) return chunk_sentences + def post_process_transcription(whisper_result): transcript_text = "" for chunk in whisper_result["chunks"]: @@ -107,12 +115,13 @@ def summarize_chunks(chunks, tokenizer, model): input_ids = input_ids.to(device) with torch.no_grad(): summary_ids = model.generate(input_ids, - num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, - max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) + num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, + max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) summaries.append(summary) 
return summaries + def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): """ Split text into smaller chunks. @@ -132,6 +141,7 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"]) chunks.append(current_chunk.strip()) return chunks + def summarize(transcript_text, timestamp, real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/utils/viz_utilities.py b/utils/viz_utilities.py index f6a2baf7..e3e19a5d 100644 --- a/utils/viz_utilities.py +++ b/utils/viz_utilities.py @@ -1,15 +1,16 @@ -import matplotlib.pyplot as plt -from wordcloud import WordCloud, STOPWORDS -from nltk.corpus import stopwords -import collections -import spacy -import os -from pathlib import Path -import pickle import ast +import collections +import configparser +import os +import pickle +from pathlib import Path + +import matplotlib.pyplot as plt import pandas as pd import scattertext as st -import configparser +import spacy +from nltk.corpus import stopwords +from wordcloud import WordCloud, STOPWORDS config = configparser.ConfigParser() config.read('config.ini') @@ -29,7 +30,7 @@ def create_wordcloud(timestamp, real_time=False): if real_time: filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" else: - filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" with open("./artefacts/" + filename, "r") as f: transcription_text = f.read() @@ -202,4 +203,4 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): if real_time: open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) else: - open('./artefacts/scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) \ No newline at end of file + open('./artefacts/scatter_' + 
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) diff --git a/whisjax.py b/whisjax.py index bee4b6d8..bb98d1af 100644 --- a/whisjax.py +++ b/whisjax.py @@ -6,25 +6,24 @@ import argparse import configparser -import jax.numpy as jnp +import os +import re +import subprocess +import tempfile +from datetime import datetime +from urllib.parse import urlparse +import jax.numpy as jnp import moviepy.editor import moviepy.editor import nltk -import os -import subprocess -import re -import tempfile -from loguru import logger import yt_dlp as youtube_dl - -from urllib.parse import urlparse +from loguru import logger from whisper_jax import FlaxWhisperPipline -from datetime import datetime from utils.file_utilities import upload_files, download_files -from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz from utils.text_utilities import summarize, post_process_transcription +from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz nltk.download('punkt', quiet=True) nltk.download('stopwords', quiet=True) @@ -36,6 +35,7 @@ config.read('config.ini') WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] NOW = datetime.now() + def init_argparse() -> argparse.ArgumentParser: """ Parse the CLI arguments @@ -52,7 +52,6 @@ def init_argparse() -> argparse.ArgumentParser: return parser - def main(): parser = init_argparse() args = parser.parse_args() @@ -140,10 +139,10 @@ def main(): with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file: transcript_file.write(transcript_text) - with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps: + with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", + "w") as transcript_file_timestamps: transcript_file_timestamps.write(str(whisper_result)) - logger.info("Creating word cloud") create_wordcloud(NOW) diff --git 
a/whisjax_realtime.py b/whisjax_realtime.py index 2d299bb0..8db8444d 100644 --- a/whisjax_realtime.py +++ b/whisjax_realtime.py @@ -1,18 +1,20 @@ #!/usr/bin/env python3 import configparser -import pyaudio -from whisper_jax import FlaxWhisperPipline -from pynput import keyboard -import jax.numpy as jnp +import time import wave from datetime import datetime -from utils.file_utilities import upload_files -from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz -from utils.text_utilities import summarize, post_process_transcription + +import jax.numpy as jnp +import pyaudio from loguru import logger -import time +from pynput import keyboard from termcolor import colored +from whisper_jax import FlaxWhisperPipline + +from utils.file_utilities import upload_files +from utils.text_utilities import summarize, post_process_transcription +from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz config = configparser.ConfigParser() config.read('config.ini')