From b7fbfb2a5456ae11114d2db491bbe3b68d9726c4 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Tue, 11 Jul 2023 11:06:27 +0530 Subject: [PATCH] refactor --- client.py | 12 +++++---- reflector-local/0-reflector-local.py | 7 +++--- reflector-local/1-transcript-generator.py | 11 ++++++--- reflector-local/2-agenda-transcript-diff.py | 6 ++++- reflector-local/3-transcript-summarizer.py | 8 ++++++ reflector-local/whisper_summarizer_bart.py | 13 +++++++--- server_executor_cleaned.py | 27 +++------------------ server_multithreaded.py | 4 +-- stream_client.py | 2 +- utils/file_utils.py | 1 + utils/text_utilities.py | 3 ++- whisjax.py | 2 +- whisjax_realtime.py | 2 +- 13 files changed, 54 insertions(+), 44 deletions(-) diff --git a/client.py b/client.py index 709d9f44..4816ea41 100644 --- a/client.py +++ b/client.py @@ -1,12 +1,14 @@ import argparse import asyncio import signal +from utils.log_utils import logger from aiortc.contrib.signaling import (add_signaling_arguments, create_signaling) from stream_client import StreamClient + async def main(): parser = argparse.ArgumentParser(description="Data channels ping/pong") @@ -35,17 +37,17 @@ async def main(): async def shutdown(signal, loop): """Cleanup tasks tied to the service's shutdown.""" - logging.info(f"Received exit signal {signal.name}...") - logging.info("Closing database connections") - logging.info("Nacking outstanding messages") + logger.info(f"Received exit signal {signal.name}...") + logger.info("Closing database connections") + logger.info("Nacking outstanding messages") tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] [task.cancel() for task in tasks] - logging.info(f"Cancelling {len(tasks)} outstanding tasks") + logger.info(f"Cancelling {len(tasks)} outstanding tasks") await asyncio.gather(*tasks, return_exceptions=True) - logging.info(f"Flushing metrics") + logger.info(f"Flushing metrics") loop.stop() signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) diff --git a/reflector-local/0-reflector-local.py b/reflector-local/0-reflector-local.py index 4b93b408..4d5cebda 100644 --- a/reflector-local/0-reflector-local.py +++ b/reflector-local/0-reflector-local.py @@ -1,10 +1,11 @@ import os import subprocess import sys + from loguru import logger # Get the input file name from the command line argument -input_file = sys.argv[1] +input_file = sys.argv[1] # example use: python 0-reflector-local.py input.m4a agenda.txt # Get the agenda file name from the command line argument if provided @@ -21,7 +22,7 @@ if not os.path.exists(agenda_file): # Check if the input file is .m4a, if so convert to .mp4 if input_file.endswith(".m4a"): subprocess.run(["ffmpeg", "-i", input_file, f"{input_file}.mp4"]) - input_file = f"{input_file}.mp4" + input_file = f"{input_file}.mp4" # Run the first script to generate the transcript subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_file}_transcript.txt"]) @@ -30,4 +31,4 @@ subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_fil subprocess.run(["python3", "2-agenda-transcript-diff.py", agenda_file, f"{input_file}_transcript.txt"]) # Run the third script to summarize the transcript -subprocess.run(["python3", "3-transcript-summarizer.py", f"{input_file}_transcript.txt", f"{input_file}_summary.txt"]) \ No newline at end of file +subprocess.run(["python3", "3-transcript-summarizer.py", f"{input_file}_transcript.txt", f"{input_file}_summary.txt"]) diff --git a/reflector-local/1-transcript-generator.py b/reflector-local/1-transcript-generator.py index 100eeae3..e19c41da 100755 --- a/reflector-local/1-transcript-generator.py +++ b/reflector-local/1-transcript-generator.py @@ -1,11 +1,13 @@ import argparse import os + import moviepy.editor -from loguru import logger import whisper +from loguru import logger WHISPER_MODEL_SIZE = "base" + def init_argparse() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( usage="%(prog)s ", @@ -15,6 +17,7 @@ def init_argparse() -> argparse.ArgumentParser: parser.add_argument("output", help="Output file path") return parser + def main(): import sys sys.setrecursionlimit(10000) @@ -26,10 +29,11 @@ def main(): logger.info(f"Processing file: {media_file}") # Check if the media file is a valid audio or video file - if os.path.isfile(media_file) and not media_file.endswith(('.mp3', '.wav', '.ogg', '.flac', '.mp4', '.avi', '.flv')): + if os.path.isfile(media_file) and not media_file.endswith( + ('.mp3', '.wav', '.ogg', '.flac', '.mp4', '.avi', '.flv')): logger.error(f"Invalid file format: {media_file}") return - + # If the media file we just retrieved is an audio file then skip extraction step audio_filename = media_file logger.info(f"Found audio-only file, skipping audio extraction") @@ -53,5 +57,6 @@ def main(): transcript_file.write(whisper_result["text"]) transcript_file.close() + if __name__ == "__main__": main() diff --git a/reflector-local/2-agenda-transcript-diff.py b/reflector-local/2-agenda-transcript-diff.py index 4972e3d3..30886dc0 100644 --- a/reflector-local/2-agenda-transcript-diff.py +++ b/reflector-local/2-agenda-transcript-diff.py @@ -1,7 +1,9 @@ import argparse + import spacy from loguru import logger + # Define the paths for agenda and transcription files def init_argparse() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( @@ -11,6 +13,8 @@ def init_argparse() -> argparse.ArgumentParser: parser.add_argument("agenda", help="Location of the agenda file") parser.add_argument("transcription", help="Location of the transcription file") return parser + + args = init_argparse().parse_args() agenda_path = args.agenda transcription_path = args.transcription @@ -19,7 +23,7 @@ transcription_path = args.transcription spaCy_model = "en_core_web_md" nlp = spacy.load(spaCy_model) nlp.add_pipe('sentencizer') -logger.info("Loaded spaCy model " + spaCy_model ) +logger.info("Loaded spaCy model " + spaCy_model) # Load the agenda with open(agenda_path, "r") as f: diff --git a/reflector-local/3-transcript-summarizer.py b/reflector-local/3-transcript-summarizer.py index 4a58c198..58a75451 100644 --- a/reflector-local/3-transcript-summarizer.py +++ b/reflector-local/3-transcript-summarizer.py @@ -1,11 +1,14 @@ import argparse + import nltk + nltk.download('stopwords') from nltk.corpus import stopwords from nltk.tokenize import word_tokenize, sent_tokenize from heapq import nlargest from loguru import logger + # Function to initialize the argument parser def init_argparse(): parser = argparse.ArgumentParser( @@ -17,12 +20,14 @@ def init_argparse(): parser.add_argument("--num_sentences", type=int, default=5, help="Number of sentences to include in the summary") return parser + # Function to read the input transcript file def read_transcript(file_path): with open(file_path, "r") as file: transcript = file.read() return transcript + # Function to preprocess the text by removing stop words and special characters def preprocess_text(text): stop_words = set(stopwords.words('english')) @@ -30,6 +35,7 @@ def preprocess_text(text): words = [w.lower() for w in words if w.isalpha() and w.lower() not in stop_words] return words + # Function to score each sentence based on the frequency of its words and return the top sentences def summarize_text(text, num_sentences): # Tokenize the text into sentences @@ -61,6 +67,7 @@ def summarize_text(text, num_sentences): return " ".join(summary) + def main(): # Initialize the argument parser and parse the arguments parser = init_argparse() @@ -82,5 +89,6 @@ def main(): logger.info("Summarization completed") + if __name__ == "__main__": main() diff --git a/reflector-local/whisper_summarizer_bart.py b/reflector-local/whisper_summarizer_bart.py index b0de87f7..4184fafe 100644 --- a/reflector-local/whisper_summarizer_bart.py +++ b/reflector-local/whisper_summarizer_bart.py @@ -1,15 +1,18 @@ import argparse import os import tempfile + import moviepy.editor +import nltk +import whisper from loguru import logger from transformers import BartTokenizer, BartForConditionalGeneration -import whisper -import nltk + nltk.download('punkt', quiet=True) WHISPER_MODEL_SIZE = "base" + def init_argparse() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( usage="%(prog)s [OPTIONS] ", @@ -30,6 +33,7 @@ def init_argparse() -> argparse.ArgumentParser: return parser + # NLTK chunking function def chunk_text(txt, max_chunk_length=500): "Split text into smaller chunks." @@ -45,6 +49,7 @@ def chunk_text(txt, max_chunk_length=500): chunks.append(current_chunk.strip()) return chunks + # BART summary function def summarize_chunks(chunks, tokenizer, model): summaries = [] @@ -56,6 +61,7 @@ def summarize_chunks(chunks, tokenizer, model): summaries.append(summary) return summaries + def main(): import sys sys.setrecursionlimit(10000) @@ -103,7 +109,7 @@ def main(): chunks = chunk_text(whisper_result['text']) logger.info( - f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable + f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable logger.info(f"Writing summary text in {args.language} to: {args.output}") with open(args.output, 'w') as f: @@ -114,5 +120,6 @@ def main(): logger.info("Summarization completed") + if __name__ == "__main__": main() diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index ffbed6cb..0a6dbe3f 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -2,7 +2,6 @@ import asyncio import datetime import io import json -from loguru import logger import sys import uuid import wave @@ -13,6 +12,7 @@ from aiohttp import web from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import MediaRelay from av import AudioFifo +from loguru import logger from whisper_jax import FlaxWhisperPipline from utils.server_utils import run_in_executor @@ -23,7 +23,9 @@ pcs = set() relay = MediaRelay() data_channel = None total_bytes_handled = 0 -pipeline = FlaxWhisperPipline("openai/whisper-tiny", dtype=jnp.float16, batch_size=16) +pipeline = FlaxWhisperPipline("openai/whisper-tiny", + dtype=jnp.float16, + batch_size=16) CHANNELS = 2 RATE = 48000 @@ -50,18 +52,6 @@ def channel_send(channel, message): def get_transcription(frames): - print("Transcribing..") - # samples = np.ndarray( - # np.concatenate([f.to_ndarray() for f in frames], axis=None), - # dtype=np.float32, - # ) - # whisper_result = pipeline( - # { - # "array": samples, - # "sampling_rate": 48000, - # }, - # return_timestamps=True, - # ) out_file = io.BytesIO() wf = wave.open(out_file, "wb") wf.setnchannels(CHANNELS) @@ -108,7 +98,6 @@ class AudioStreamTrack(MediaStreamTrack): async def offer(request): params = await request.json() - print("Request received") offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) pc = RTCPeerConnection() @@ -132,7 +121,6 @@ async def offer(request): channel_log(channel, "<", message) if isinstance(message, str) and message.startswith("ping"): - # reply channel_send(channel, "pong" + message[4:]) @pc.on("connectionstatechange") @@ -144,19 +132,13 @@ async def offer(request): @pc.on("track") def on_track(track): - print("Track %s received" % track.kind) log_info("Track %s received", track.kind) - # Trials to listen to the correct track pc.addTrack(AudioStreamTrack(relay.subscribe(track))) - # pc.addTrack(AudioStreamTrack(track)) - # handle offer await pc.setRemoteDescription(offer) - # send answer answer = await pc.createAnswer() await pc.setLocalDescription(answer) - print("Response sent") return web.Response( content_type="application/json", text=json.dumps( @@ -166,7 +148,6 @@ async def offer(request): async def on_shutdown(app): - # close peer connections coros = [pc.close() for pc in pcs] await asyncio.gather(*coros) pcs.clear() diff --git a/server_multithreaded.py b/server_multithreaded.py index a1c74a68..9bb24031 100644 --- a/server_multithreaded.py +++ b/server_multithreaded.py @@ -3,21 +3,21 @@ import configparser import datetime import io import json -from utils.log_utils import logger import os import threading import uuid import wave from concurrent.futures import ThreadPoolExecutor +from aiohttp import web import jax.numpy as jnp - from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import (MediaRelay) from av import AudioFifo from sortedcontainers import SortedDict from whisper_jax import FlaxWhisperPipline +from utils.log_utils import logger from utils.server_utils import Mutex ROOT = os.path.dirname(__file__) diff --git a/stream_client.py b/stream_client.py index 68fc4709..82f38e95 100644 --- a/stream_client.py +++ b/stream_client.py @@ -1,7 +1,6 @@ import ast import asyncio import configparser -from utils.log_utils import logger import time import uuid @@ -12,6 +11,7 @@ import stamina from aiortc import (RTCPeerConnection, RTCSessionDescription) from aiortc.contrib.media import (MediaPlayer, MediaRelay) +from utils.log_utils import logger from utils.server_utils import Mutex file_lock = Mutex(open("test_sm_6.txt", "a")) diff --git a/utils/file_utils.py b/utils/file_utils.py index 70108a49..d9fcc08f 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -2,6 +2,7 @@ import configparser import boto3 import botocore + from log_utils import logger config = configparser.ConfigParser() diff --git a/utils/text_utilities.py b/utils/text_utilities.py index cfa6c9dd..d67caf66 100644 --- a/utils/text_utilities.py +++ b/utils/text_utilities.py @@ -2,13 +2,14 @@ import configparser import nltk import torch -from log_utils import logger from nltk.corpus import stopwords from nltk.tokenize import word_tokenize from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from transformers import BartTokenizer, BartForConditionalGeneration +from log_utils import logger + nltk.download('punkt', quiet=True) config = configparser.ConfigParser() diff --git a/whisjax.py b/whisjax.py index 36dfb785..ebfe1056 100644 --- a/whisjax.py +++ b/whisjax.py @@ -18,10 +18,10 @@ import moviepy.editor import moviepy.editor import nltk import yt_dlp as youtube_dl -from utils.log_utils import logger from whisper_jax import FlaxWhisperPipline from utils.file_utils import upload_files, download_files +from utils.log_utils import logger from utils.text_utilities import summarize, post_process_transcription from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz diff --git a/whisjax_realtime.py b/whisjax_realtime.py index 24185186..60b06a8c 100644 --- a/whisjax_realtime.py +++ b/whisjax_realtime.py @@ -7,12 +7,12 @@ from datetime import datetime import jax.numpy as jnp import pyaudio -from utils.log_utils import logger from pynput import keyboard from termcolor import colored from whisper_jax import FlaxWhisperPipline from utils.file_utils import upload_files +from utils.log_utils import logger from utils.text_utilities import summarize, post_process_transcription from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz