From 58c9cdf6766b8165387e46a07b7e0386dd87a9eb Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Tue, 11 Jul 2023 11:01:22 +0530 Subject: [PATCH 1/9] add singleton logging --- .gitignore | 1 + client.py | 4 ---- server_executor_cleaned.py | 4 +--- server_multithreaded.py | 5 +---- stream_client.py | 3 +-- utils/{file_utilities.py => file_utils.py} | 2 +- utils/log_utils.py | 14 ++++++++++++++ utils/text_utilities.py | 2 +- whisjax.py | 4 ++-- whisjax_realtime.py | 4 ++-- 10 files changed, 24 insertions(+), 19 deletions(-) rename utils/{file_utilities.py => file_utils.py} (97%) create mode 100644 utils/log_utils.py diff --git a/.gitignore b/.gitignore index 5c2931cb..fd3e8b20 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,5 @@ test_samples/ *.mp3 *.m4a .DS_Store/ +.DS_Store .vscode/ diff --git a/client.py b/client.py index 98db7922..709d9f44 100644 --- a/client.py +++ b/client.py @@ -1,6 +1,5 @@ import argparse import asyncio -import logging import signal from aiortc.contrib.signaling import (add_signaling_arguments, @@ -8,9 +7,6 @@ from aiortc.contrib.signaling import (add_signaling_arguments, from stream_client import StreamClient -logger = logging.getLogger("pc") - - async def main(): parser = argparse.ArgumentParser(description="Data channels ping/pong") diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index ca42a1b0..ffbed6cb 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -2,7 +2,7 @@ import asyncio import datetime import io import json -import logging +from loguru import logger import sys import uuid import wave @@ -17,8 +17,6 @@ from whisper_jax import FlaxWhisperPipline from utils.server_utils import run_in_executor -logger = logging.getLogger(__name__) - transcription = "" pcs = set() diff --git a/server_multithreaded.py b/server_multithreaded.py index e656a892..a1c74a68 100644 --- a/server_multithreaded.py +++ b/server_multithreaded.py @@ -3,7 +3,7 @@ import configparser import datetime import io import json -import logging +from utils.log_utils import logger import os import threading import uuid @@ -11,7 +11,6 @@ import wave from concurrent.futures import ThreadPoolExecutor import jax.numpy as jnp -from aiohttp import webq from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import (MediaRelay) @@ -27,8 +26,6 @@ config = configparser.ConfigParser() config.read('config.ini') WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] - -logger = logging.getLogger("pc") pcs = set() relay = MediaRelay() data_channel = None diff --git a/stream_client.py b/stream_client.py index d660f079..68fc4709 100644 --- a/stream_client.py +++ b/stream_client.py @@ -1,7 +1,7 @@ import ast import asyncio import configparser -import logging +from utils.log_utils import logger import time import uuid @@ -14,7 +14,6 @@ from aiortc.contrib.media import (MediaPlayer, MediaRelay) from utils.server_utils import Mutex -logger = logging.getLogger("pc") file_lock = Mutex(open("test_sm_6.txt", "a")) config = configparser.ConfigParser() diff --git a/utils/file_utilities.py b/utils/file_utils.py similarity index 97% rename from utils/file_utilities.py rename to utils/file_utils.py index e6c1fe9a..70108a49 100644 --- a/utils/file_utilities.py +++ b/utils/file_utils.py @@ -2,7 +2,7 @@ import configparser import boto3 import botocore -from loguru import logger +from log_utils import logger config = configparser.ConfigParser() config.read('config.ini') diff --git a/utils/log_utils.py 
b/utils/log_utils.py new file mode 100644 index 00000000..3b874363 --- /dev/null +++ b/utils/log_utils.py @@ -0,0 +1,14 @@ +from loguru import logger + + +class SingletonLogger: + __instance = None + + @staticmethod + def get_logger(): + if not SingletonLogger.__instance: + SingletonLogger.__instance = logger + return SingletonLogger.__instance + + +logger = SingletonLogger.get_logger()
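The new utils/log_utils.py above hands every module the same loguru instance behind a singleton accessor. Since loguru's `logger` is already process-wide, the wrapper mainly pins down a single import path; a minimal usage sketch (the sink and file name below are illustrative assumptions, not part of this patch):

```python
# Hypothetical consumer; assumes the process runs from the repo root so
# `utils` is importable. The file sink below is an illustrative example.
from utils.log_utils import logger

# Attach a sink once at startup; every module importing `logger` from
# utils.log_utils shares this instance, so the sink applies globally.
logger.add("reflector.log", rotation="10 MB")

logger.info("transcription worker started")
```

One practical payoff of the indirection: swapping the logging backend later only touches utils/log_utils.py, not every call site.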
outstanding messages") tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] [task.cancel() for task in tasks] - logging.info(f"Cancelling {len(tasks)} outstanding tasks") + logger.info(f"Cancelling {len(tasks)} outstanding tasks") await asyncio.gather(*tasks, return_exceptions=True) - logging.info(f"Flushing metrics") + logger.info(f"Flushing metrics") loop.stop() signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) diff --git a/reflector-local/0-reflector-local.py b/reflector-local/0-reflector-local.py index 4b93b408..4d5cebda 100644 --- a/reflector-local/0-reflector-local.py +++ b/reflector-local/0-reflector-local.py @@ -1,10 +1,11 @@ import os import subprocess import sys + from loguru import logger # Get the input file name from the command line argument -input_file = sys.argv[1] +input_file = sys.argv[1] # example use: python 0-reflector-local.py input.m4a agenda.txt # Get the agenda file name from the command line argument if provided @@ -21,7 +22,7 @@ if not os.path.exists(agenda_file): # Check if the input file is .m4a, if so convert to .mp4 if input_file.endswith(".m4a"): subprocess.run(["ffmpeg", "-i", input_file, f"{input_file}.mp4"]) - input_file = f"{input_file}.mp4" + input_file = f"{input_file}.mp4" # Run the first script to generate the transcript subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_file}_transcript.txt"]) @@ -30,4 +31,4 @@ subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_fil subprocess.run(["python3", "2-agenda-transcript-diff.py", agenda_file, f"{input_file}_transcript.txt"]) # Run the third script to summarize the transcript -subprocess.run(["python3", "3-transcript-summarizer.py", f"{input_file}_transcript.txt", f"{input_file}_summary.txt"]) \ No newline at end of file +subprocess.run(["python3", "3-transcript-summarizer.py", f"{input_file}_transcript.txt", f"{input_file}_summary.txt"]) diff --git a/reflector-local/1-transcript-generator.py b/reflector-local/1-transcript-generator.py index 100eeae3..e19c41da 100755 --- a/reflector-local/1-transcript-generator.py +++ b/reflector-local/1-transcript-generator.py @@ -1,11 +1,13 @@ import argparse import os + import moviepy.editor -from loguru import logger import whisper +from loguru import logger WHISPER_MODEL_SIZE = "base" + def init_argparse() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( usage="%(prog)s ", @@ -15,6 +17,7 @@ def init_argparse() -> argparse.ArgumentParser: parser.add_argument("output", help="Output file path") return parser + def main(): import sys sys.setrecursionlimit(10000) @@ -26,10 +29,11 @@ def main(): logger.info(f"Processing file: {media_file}") # Check if the media file is a valid audio or video file - if os.path.isfile(media_file) and not media_file.endswith(('.mp3', '.wav', '.ogg', '.flac', '.mp4', '.avi', '.flv')): + if os.path.isfile(media_file) and not media_file.endswith( + ('.mp3', '.wav', '.ogg', '.flac', '.mp4', '.avi', '.flv')): logger.error(f"Invalid file format: {media_file}") return - + # If the media file we just retrieved is an audio file then skip extraction step audio_filename = media_file logger.info(f"Found audio-only file, skipping audio extraction") @@ -53,5 +57,6 @@ def main(): transcript_file.write(whisper_result["text"]) transcript_file.close() + if __name__ == "__main__": main() diff --git a/reflector-local/2-agenda-transcript-diff.py b/reflector-local/2-agenda-transcript-diff.py index 4972e3d3..30886dc0 100644 --- a/reflector-local/2-agenda-transcript-diff.py 
+++ b/reflector-local/2-agenda-transcript-diff.py @@ -1,7 +1,9 @@ import argparse + import spacy from loguru import logger + # Define the paths for agenda and transcription files def init_argparse() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( @@ -11,6 +13,8 @@ def init_argparse() -> argparse.ArgumentParser: parser.add_argument("agenda", help="Location of the agenda file") parser.add_argument("transcription", help="Location of the transcription file") return parser + + args = init_argparse().parse_args() agenda_path = args.agenda transcription_path = args.transcription @@ -19,7 +23,7 @@ transcription_path = args.transcription spaCy_model = "en_core_web_md" nlp = spacy.load(spaCy_model) nlp.add_pipe('sentencizer') -logger.info("Loaded spaCy model " + spaCy_model ) +logger.info("Loaded spaCy model " + spaCy_model) # Load the agenda with open(agenda_path, "r") as f: diff --git a/reflector-local/3-transcript-summarizer.py b/reflector-local/3-transcript-summarizer.py index 4a58c198..58a75451 100644 --- a/reflector-local/3-transcript-summarizer.py +++ b/reflector-local/3-transcript-summarizer.py @@ -1,11 +1,14 @@ import argparse + import nltk + nltk.download('stopwords') from nltk.corpus import stopwords from nltk.tokenize import word_tokenize, sent_tokenize from heapq import nlargest from loguru import logger + # Function to initialize the argument parser def init_argparse(): parser = argparse.ArgumentParser( @@ -17,12 +20,14 @@ def init_argparse(): parser.add_argument("--num_sentences", type=int, default=5, help="Number of sentences to include in the summary") return parser + # Function to read the input transcript file def read_transcript(file_path): with open(file_path, "r") as file: transcript = file.read() return transcript + # Function to preprocess the text by removing stop words and special characters def preprocess_text(text): stop_words = set(stopwords.words('english')) @@ -30,6 +35,7 @@ def preprocess_text(text): words = [w.lower() for w in words if w.isalpha() and w.lower() not in stop_words] return words + # Function to score each sentence based on the frequency of its words and return the top sentences def summarize_text(text, num_sentences): # Tokenize the text into sentences @@ -61,6 +67,7 @@ def summarize_text(text, num_sentences): return " ".join(summary) + def main(): # Initialize the argument parser and parse the arguments parser = init_argparse() @@ -82,5 +89,6 @@ def main(): logger.info("Summarization completed") + if __name__ == "__main__": main() diff --git a/reflector-local/whisper_summarizer_bart.py b/reflector-local/whisper_summarizer_bart.py index b0de87f7..4184fafe 100644 --- a/reflector-local/whisper_summarizer_bart.py +++ b/reflector-local/whisper_summarizer_bart.py @@ -1,15 +1,18 @@ import argparse import os import tempfile + import moviepy.editor +import nltk +import whisper from loguru import logger from transformers import BartTokenizer, BartForConditionalGeneration -import whisper -import nltk + nltk.download('punkt', quiet=True) WHISPER_MODEL_SIZE = "base" + def init_argparse() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( usage="%(prog)s [OPTIONS] ", @@ -30,6 +33,7 @@ def init_argparse() -> argparse.ArgumentParser: return parser + # NLTK chunking function def chunk_text(txt, max_chunk_length=500): "Split text into smaller chunks." 
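The hunks around this point only show fragments of `chunk_text` and `summarize_chunks`, so here is a self-contained sketch of the chunk-then-summarize pattern the script follows (the function bodies are reconstructions under stated assumptions, not the repo's exact code; `transcript.txt` is an illustrative input):

```python
# Sketch of sentence-aligned chunking followed by per-chunk BART summaries.
import nltk
from transformers import BartForConditionalGeneration, BartTokenizer

nltk.download('punkt', quiet=True)


def chunk_text(txt, max_chunk_length=500):
    """Split text into sentence-aligned chunks of roughly max_chunk_length words."""
    chunks, current = [], ""
    for sentence in nltk.sent_tokenize(txt):
        if current and len((current + " " + sentence).split()) > max_chunk_length:
            chunks.append(current.strip())
            current = ""
        current += " " + sentence
    if current.strip():
        chunks.append(current.strip())
    return chunks


def summarize_chunks(chunks, tokenizer, model):
    """Summarize each chunk independently and collect the pieces."""
    summaries = []
    for chunk in chunks:
        inputs = tokenizer(chunk, truncation=True, max_length=1024,
                           return_tensors="pt")
        ids = model.generate(inputs["input_ids"], num_beams=4, max_length=150)
        summaries.append(tokenizer.decode(ids[0], skip_special_tokens=True))
    return summaries


if __name__ == "__main__":
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    text = open("transcript.txt").read()
    print("\n".join(summarize_chunks(chunk_text(text), tokenizer, model)))
```

Chunking on sentence boundaries keeps each piece under BART's 1024-token encoder limit without cutting sentences in half, which is why the script summarizes per chunk and joins the results.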
@@ -45,6 +49,7 @@ def chunk_text(txt, max_chunk_length=500): chunks.append(current_chunk.strip()) return chunks + # BART summary function def summarize_chunks(chunks, tokenizer, model): summaries = [] @@ -56,6 +61,7 @@ def summarize_chunks(chunks, tokenizer, model): summaries.append(summary) return summaries + def main(): import sys sys.setrecursionlimit(10000) @@ -103,7 +109,7 @@ def main(): chunks = chunk_text(whisper_result['text']) logger.info( - f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable + f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable logger.info(f"Writing summary text in {args.language} to: {args.output}") with open(args.output, 'w') as f: @@ -114,5 +120,6 @@ def main(): logger.info("Summarization completed") + if __name__ == "__main__": main() diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index ffbed6cb..0a6dbe3f 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -2,7 +2,6 @@ import asyncio import datetime import io import json -from loguru import logger import sys import uuid import wave @@ -13,6 +12,7 @@ from aiohttp import web from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import MediaRelay from av import AudioFifo +from loguru import logger from whisper_jax import FlaxWhisperPipline from utils.server_utils import run_in_executor @@ -23,7 +23,9 @@ pcs = set() relay = MediaRelay() data_channel = None total_bytes_handled = 0 -pipeline = FlaxWhisperPipline("openai/whisper-tiny", dtype=jnp.float16, batch_size=16) +pipeline = FlaxWhisperPipline("openai/whisper-tiny", + dtype=jnp.float16, + batch_size=16) CHANNELS = 2 RATE = 48000 @@ -50,18 +52,6 @@ def channel_send(channel, message): def get_transcription(frames): - print("Transcribing..") - # samples = np.ndarray( - # np.concatenate([f.to_ndarray() for f in frames], axis=None), - # dtype=np.float32, - # ) - # whisper_result = pipeline( - # { - # "array": samples, - # "sampling_rate": 48000, - # }, - # return_timestamps=True, - # ) out_file = io.BytesIO() wf = wave.open(out_file, "wb") wf.setnchannels(CHANNELS) @@ -108,7 +98,6 @@ class AudioStreamTrack(MediaStreamTrack): async def offer(request): params = await request.json() - print("Request received") offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) pc = RTCPeerConnection() @@ -132,7 +121,6 @@ async def offer(request): channel_log(channel, "<", message) if isinstance(message, str) and message.startswith("ping"): - # reply channel_send(channel, "pong" + message[4:]) @pc.on("connectionstatechange") @@ -144,19 +132,13 @@ async def offer(request): @pc.on("track") def on_track(track): - print("Track %s received" % track.kind) log_info("Track %s received", track.kind) - # Trials to listen to the correct track pc.addTrack(AudioStreamTrack(relay.subscribe(track))) - # pc.addTrack(AudioStreamTrack(track)) - # handle offer await pc.setRemoteDescription(offer) - # send answer answer = await pc.createAnswer() await pc.setLocalDescription(answer) - print("Response sent") return web.Response( content_type="application/json", text=json.dumps( @@ -166,7 +148,6 @@ async def offer(request): async def on_shutdown(app): - # close peer connections coros = [pc.close() for pc in pcs] await asyncio.gather(*coros) pcs.clear() diff --git a/server_multithreaded.py b/server_multithreaded.py index a1c74a68..9bb24031 100644 --- a/server_multithreaded.py +++ b/server_multithreaded.py @@ -3,21 +3,21 @@ 
import configparser import datetime import io import json -from utils.log_utils import logger import os import threading import uuid import wave from concurrent.futures import ThreadPoolExecutor +from aiohttp import web import jax.numpy as jnp - from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import (MediaRelay) from av import AudioFifo from sortedcontainers import SortedDict from whisper_jax import FlaxWhisperPipline +from utils.log_utils import logger from utils.server_utils import Mutex ROOT = os.path.dirname(__file__) diff --git a/stream_client.py b/stream_client.py index 68fc4709..82f38e95 100644 --- a/stream_client.py +++ b/stream_client.py @@ -1,7 +1,6 @@ import ast import asyncio import configparser -from utils.log_utils import logger import time import uuid @@ -12,6 +11,7 @@ import stamina from aiortc import (RTCPeerConnection, RTCSessionDescription) from aiortc.contrib.media import (MediaPlayer, MediaRelay) +from utils.log_utils import logger from utils.server_utils import Mutex file_lock = Mutex(open("test_sm_6.txt", "a")) diff --git a/utils/file_utils.py b/utils/file_utils.py index 70108a49..d9fcc08f 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -2,6 +2,7 @@ import configparser import boto3 import botocore + from log_utils import logger config = configparser.ConfigParser() diff --git a/utils/text_utilities.py b/utils/text_utilities.py index cfa6c9dd..d67caf66 100644 --- a/utils/text_utilities.py +++ b/utils/text_utilities.py @@ -2,13 +2,14 @@ import configparser import nltk import torch -from log_utils import logger from nltk.corpus import stopwords from nltk.tokenize import word_tokenize from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from transformers import BartTokenizer, BartForConditionalGeneration +from log_utils import logger + nltk.download('punkt', quiet=True) config = configparser.ConfigParser() diff --git a/whisjax.py b/whisjax.py index 36dfb785..ebfe1056 100644 --- a/whisjax.py +++ b/whisjax.py @@ -18,10 +18,10 @@ import moviepy.editor import moviepy.editor import nltk import yt_dlp as youtube_dl -from utils.log_utils import logger from whisper_jax import FlaxWhisperPipline from utils.file_utils import upload_files, download_files +from utils.log_utils import logger from utils.text_utilities import summarize, post_process_transcription from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz diff --git a/whisjax_realtime.py b/whisjax_realtime.py index 24185186..60b06a8c 100644 --- a/whisjax_realtime.py +++ b/whisjax_realtime.py @@ -7,12 +7,12 @@ from datetime import datetime import jax.numpy as jnp import pyaudio -from utils.log_utils import logger from pynput import keyboard from termcolor import colored from whisper_jax import FlaxWhisperPipline from utils.file_utils import upload_files +from utils.log_utils import logger from utils.text_utilities import summarize, post_process_transcription from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz From 8e9cd6c56846396a9d617abd3dfcae5ed7f98053 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Tue, 11 Jul 2023 12:09:30 +0530 Subject: [PATCH 3/9] code cleanup --- README.md | 125 ++++++++++++++++++++----------------- client.py | 30 ++++----- config.ini | 30 ++++----- requirements.txt | 4 +- server_executor_cleaned.py | 26 ++++---- server_multithreaded.py | 27 ++++---- stream_client.py | 12 ++-- utils/file_utils.py | 18 +++--- 
utils/log_utils.py | 4 ++ utils/run_utils.py | 66 ++++++++++++++++++++ utils/server_utils.py | 28 --------- utils/text_utilities.py | 8 +-- utils/viz_utilities.py | 28 ++++----- whisjax.py | 29 ++++----- whisjax_realtime.py | 29 +++++--- 15 files changed, 249 insertions(+), 215 deletions(-) create mode 100644 utils/run_utils.py delete mode 100644 utils/server_utils.py diff --git a/README.md b/README.md index 01cf5a87..b1cda0a0 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,34 @@ # Reflector -This is the code base for the Reflector demo (formerly called agenda-talk-diff) for the leads : Troy Web Consulting panel (A Chat with AWS about AI: Real AI/ML AWS projects and what you should know) on 6/14 at 430PM. - -The target deliverable is a local-first live transcription and visualization tool to compare a discussion's target agenda/objectives to the actual discussion live. +This is the code base for the Reflector demo (formerly called agenda-talk-diff) for the leads: Troy Web Consulting +panel (A Chat with AWS about AI: Real AI/ML AWS projects and what you should know) on 6/14 at 4:30 PM. +The target deliverable is a local-first live transcription and visualization tool to compare a discussion's target +agenda/objectives to the actual discussion live. **S3 bucket:** Everything you need for S3 is already configured in config.ini. Only edit it if you need to change it deliberately. -S3 bucket name is mentioned in config.ini. All transfers will happen between this bucket and the local computer where the -script is run. You need AWS_ACCESS_KEY / AWS_SECRET_KEY to authenticate your calls to S3 (done in config.ini). +The S3 bucket name is mentioned in config.ini. All transfers will happen between this bucket and the local computer where +the +script is run. You need AWS_ACCESS_KEY / AWS_SECRET_KEY to authenticate your calls to S3 (done in config.ini). For AWS S3 Web UI, + 1) Log in to the AWS management console. 2) Search for S3 in the search bar at the top. 3) Navigate to list the buckets under the current account, if needed, and choose your bucket [```reflector-bucket```] 4) You should be able to see items in the bucket. You can upload/download files here directly. - -For CLI, +For CLI, Refer to the FILE UTIL section below. - **FILE UTIL MODULE:** -A file_util module has been created to upload/download files with AWS S3 bucket pre-configured using config.ini. -Though not needed for the workflow, if you need to upload / download file, separately on your own, apart from the pipeline workflow in the script, you can do so by : +A file_util module has been created to upload/download files with the AWS S3 bucket pre-configured using config.ini. +Though not needed for the workflow, if you need to upload/download files separately on your own, apart from the +pipeline workflow in the script, you can do so by: Upload: @@ -39,37 +41,37 @@ Download: If you want to access the S3 artefacts, from another machine, you can either use the python file_util with the commands mentioned above or simply use the GUI of AWS Management Console. - -To setup, +To set up, 1) Check values in config.ini file. Specifically add your OPENAI_APIKEY if you plan to use OpenAI API requests. -2) Run ``` export KMP_DUPLICATE_LIB_OK=True``` in Terminal. [This is taken care of in code, but not reflecting, Will fix this issue later.] +2) Run ``` export KMP_DUPLICATE_LIB_OK=True``` in + Terminal. [This is taken care of in code, but not reflected yet; will fix this issue later.]
NOTE: If you don't have portaudio installed already, run ```brew install portaudio``` 3) Run the script setup_dependencies.sh. - ``` chmod +x setup_dependencies.sh ``` + ``` chmod +x setup_dependencies.sh ``` - ``` sh setup_dependencies.sh ``` + ``` sh setup_dependencies.sh ``` - - ENV refers to the intended environment for JAX. JAX is available in several variants, [CPU | GPU | Colab TPU | Google Cloud TPU] - - ```ENV``` is : - - cpu -> JAX CPU installation +ENV refers to the intended environment for JAX. JAX is available in several +variants, [CPU | GPU | Colab TPU | Google Cloud TPU] - cuda11 -> JAX CUDA 11.x version +```ENV``` is : - cuda12 -> JAX CUDA 12.x version (Core Weave has CUDA 12 version, can check with ```nvidia-smi```) +cpu -> JAX CPU installation + +cuda11 -> JAX CUDA 11.x version + +cuda12 -> JAX CUDA 12.x version (Core Weave has CUDA 12 version, can check with ```nvidia-smi```) ```sh setup_dependencies.sh cuda12``` 4) If not already done, install ffmpeg. ```brew install ffmpeg``` -For NLTK SSL error, check [here](https://stackoverflow.com/questions/38916452/nltk-download-ssl-certificate-verify-failed) - +For NLTK SSL error, +check [here](https://stackoverflow.com/questions/38916452/nltk-download-ssl-certificate-verify-failed) 5) Run the Whisper-JAX pipeline. Currently, the repo can take a YouTube video and transcribe/summarize it. @@ -79,83 +81,92 @@ You can even run it on local file or a file in your configured S3 bucket. ``` python3 whisjax.py "startup.mp4"``` -The script will take care of a few cases like youtube file, local file, video file, audio-only file, +The script will take care of a few cases like youtube file, local file, video file, audio-only file, file in S3, etc. If local file is not present, it can automatically take the file from S3. **OFFLINE WORKFLOW:** -1) Specify the input source file] from a local, youtube link or upload to S3 if needed and pass it as input to the script.If the source file is in +1) Specify the input source file from a local path or YouTube link, or upload it to S3 if needed, and pass it as input to the + script. If the source file is in ```.m4a``` format, it will get converted to ```.mp4``` automatically. -2) Keep the agenda header topics in a local file named ```agenda-headers.txt```. This needs to be present where the script is run. +2) Keep the agenda header topics in a local file named ```agenda-headers.txt```. This needs to be present where the + script is run. This version of the pipeline compares covered agenda topics using agenda headers in the following format. - 1) ```agenda_topic : ``` -3) Check all the values in ```config.ini```. You need to predefine 2 categories for which you need to scatter plot the - topic modelling visualization in the config file. This is the default visualization. But, from the dataframe artefact called - ```df_.pkl``` , you can load the df and choose different topics to plot. You can filter using certain words to search for the + 1) ```agenda_topic : ``` +3) Check all the values in ```config.ini```. You need to predefine 2 categories for which you need to scatter plot the + topic modelling visualization in the config file. This is the default visualization. But, from the dataframe artefact + called + ```df_.pkl``` , you can load the df and choose different topics to plot. You can filter using certain + words to search for the transcriptions and you can see the top influencers and characteristics in each topic we have chosen to plot in the - interactive HTML document.
I have added a new jupyter notebook that gives the base template to play around with, named ```Viz_experiments.ipynb```. -4) Run the script. The script automatically transcribes, summarizes and creates a scatter plot of words & topics in the form of an interactive -HTML file, a sample word cloud and uploads them to the S3 bucket +4) Run the script. The script automatically transcribes, summarizes and creates a scatter plot of words & topics in the + form of an interactive + HTML file, a sample word cloud and uploads them to the S3 bucket 5) Additional artefacts pushed to S3: - 1) HTML visualization file - 2) pandas df in pickle format for others to collaborate and make their own visualizations - 3) Summary, transcript and transcript with timestamps file in text format. + 1) HTML visualization file + 2) pandas df in pickle format for others to collaborate and make their own visualizations + 3) Summary, transcript and transcript with timestamps file in text format. The script also creates 2 types of mappings. - 1) Timestamp -> The top 2 matched agenda topic - 2) Topic -> All matched timestamps in the transcription -Other visualizations can be planned based on available artefacts or new ones can be created. Refer the section ```Viz-experiments```. + The script also creates 2 types of mappings. + 1) Timestamp -> The top 2 matched agenda topics + 2) Topic -> All matched timestamps in the transcription Other visualizations can be planned based on available artefacts or new ones can be created. Refer to the +section ```Viz-experiments```. **Visualization experiments:** -This is a jupyter notebook playground with template instructions on handling the metadata and data artefacts generated from the -pipeline. Follow the instructions given and tweak your own logic into it or use it as a playground to experiment libraries and +This is a jupyter notebook playground with template instructions on handling the metadata and data artefacts generated +from the +pipeline. Follow the instructions given and tweak your own logic into it or use it as a playground to experiment with +libraries and visualizations on top of the metadata. **WHISPER-JAX REALTIME TRANSCRIPTION PIPELINE:** -We also support a provision to perform real-time transcripton using whisper-jax pipeline. But, there are -a few pre-requisites before you run it on your local machine. The instructions are for +We also support a provision to perform real-time transcription using the whisper-jax pipeline. But, there are +a few pre-requisites before you run it on your local machine. The instructions are for +configuring on macOS. We need a way to route audio from an application opened via the browser, e.g. "Whereby", and audio from your local -microphone input which you will be using for speaking. We use [Blackhole](https://github.com/ExistentialAudio/BlackHole). +microphone input which you will be using for speaking. We +use [Blackhole](https://github.com/ExistentialAudio/BlackHole). 1) Install Blackhole-2ch (2 channels is enough) by one of the two options listed. 2) Set up [Aggregate device](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and local microphone input.
- Be sure to mirror the settings given ![here](./images/aggregate_input.png) + Be sure to mirror the settings given ![here](./images/aggregate_input.png) 3) Set up [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device) - + Refer ![here](./images/multi-output.png) 4) Set the aggregator input device name created in step 2 in config.ini as ```BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME``` 5) Then go to ``` System Preferences -> Sound ``` and choose the devices created from the Output and -Input tabs. + Input tabs. -6) The input from your local microphone, the browser run meeting should be aggregated into one virtual stream to listen to -and the output should be fed back to your specified output devices if everything is configured properly. Check this -before trying out the trial. +6) The input from your local microphone and the browser-run meeting should be aggregated into one virtual stream to listen + to, + and the output should be fed back to your specified output devices if everything is configured properly. Check this + before trying it out. **Permissions:** -You may have to add permission for "Terminal"/Code Editors [Pycharm/VSCode, etc.] microphone access to record audio in +You may have to grant "Terminal"/code editors [PyCharm/VSCode, etc.] microphone access to record audio in ```System Preferences -> Privacy & Security -> Microphone```, ```System Preferences -> Privacy & Security -> Accessibility```, ```System Preferences -> Privacy & Security -> Input Monitoring```. -From the reflector root folder, +From the reflector root folder, run ```python3 whisjax_realtime.py``` The transcription text should be written to ```real_time_transcription_.txt```. - NEXT STEPS: 1) Create a RunPod setup for this feature (mentioned in 1 & 2) and test it end-to-end diff --git a/client.py b/client.py index 4816ea41..d6393712 100644 --- a/client.py +++ b/client.py @@ -1,33 +1,33 @@ import argparse import asyncio import signal -from utils.log_utils import logger from aiortc.contrib.signaling import (add_signaling_arguments, create_signaling) from stream_client import StreamClient +from utils.log_utils import logger async def main(): parser = argparse.ArgumentParser(description="Data channels ping/pong") parser.add_argument( - "--url", type=str, nargs="?", default="http://127.0.0.1:1250/offer" + "--url", type=str, nargs="?", default="http://127.0.0.1:1250/offer" ) parser.add_argument( - "--ping-pong", - help="Benchmark data channel with ping pong", - type=eval, - choices=[True, False], - default="False", + "--ping-pong", + help="Benchmark data channel with ping pong", + type=eval, + choices=[True, False], + default="False", ) parser.add_argument( - "--play-from", - type=str, - default="", + "--play-from", + type=str, + default="", ) add_signaling_arguments(parser) @@ -54,14 +54,14 @@ async def main(): loop = asyncio.get_event_loop() for s in signals: loop.add_signal_handler( - s, lambda s=s: asyncio.create_task(shutdown(s, loop))) + s, lambda s=s: asyncio.create_task(shutdown(s, loop))) # Init client sc = StreamClient( - signaling=signaling, - url=args.url, - play_from=args.play_from, - ping_pong=args.ping_pong + signaling=signaling, + url=args.url, + play_from=args.play_from, + ping_pong=args.ping_pong ) await sc.start() print("Stream client started") diff --git a/config.ini b/config.ini index c0a41bbf..0092129f 100644 --- a/config.ini +++ b/config.ini @@ -1,22 +1,22 @@ [DEFAULT] # Set exception rule for OpenMP error to allow duplicate lib initialization
-KMP_DUPLICATE_LIB_OK=TRUE +KMP_DUPLICATE_LIB_OK = TRUE # Export OpenAI API Key -OPENAI_APIKEY= +OPENAI_APIKEY = # Export Whisper Model Size -WHISPER_MODEL_SIZE=tiny -WHISPER_REAL_TIME_MODEL_SIZE=tiny +WHISPER_MODEL_SIZE = tiny +WHISPER_REAL_TIME_MODEL_SIZE = tiny # AWS config -AWS_ACCESS_KEY=***REMOVED*** -AWS_SECRET_KEY=***REMOVED*** -BUCKET_NAME='reflector-bucket' +AWS_ACCESS_KEY = ***REMOVED*** +AWS_SECRET_KEY = ***REMOVED*** +BUCKET_NAME = reflector-bucket # Summarizer config -SUMMARY_MODEL=facebook/bart-large-cnn -INPUT_ENCODING_MAX_LENGTH=1024 -MAX_LENGTH=2048 -BEAM_SIZE=6 -MAX_CHUNK_LENGTH=1024 -SUMMARIZE_USING_CHUNKS=YES +SUMMARY_MODEL = facebook/bart-large-cnn +INPUT_ENCODING_MAX_LENGTH = 1024 +MAX_LENGTH = 2048 +BEAM_SIZE = 6 +MAX_CHUNK_LENGTH = 1024 +SUMMARIZE_USING_CHUNKS = YES # Audio device -BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME=aggregator -AV_FOUNDATION_DEVICE_ID=2 \ No newline at end of file +BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME = aggregator +AV_FOUNDATION_DEVICE_ID = 2 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c983bb1a..23b8e38d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ networkx==3.1 numba==0.57.0 numpy==1.24.3 openai==0.27.7 -openai-whisper @ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8 +openai-whisper@ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8 Pillow==9.5.0 proglog==0.1.10 pytube==15.0.0 @@ -56,5 +56,5 @@ cached_property==1.5.2 stamina==23.1.0 httpx==0.24.1 sortedcontainers==2.4.0 -openai-whisper @ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8 +openai-whisper@ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8 https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index 0a6dbe3f..1a40b4af 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -15,7 +15,7 @@ from av import AudioFifo from loguru import logger from whisper_jax import FlaxWhisperPipline -from utils.server_utils import run_in_executor +from utils.run_utils import run_in_executor transcription = "" @@ -44,10 +44,10 @@ def channel_send(channel, message): if channel: channel.send(message) print( - "Bytes handled :", - total_bytes_handled, - " Time : ", - datetime.datetime.now() - start_time, + "Bytes handled :", + total_bytes_handled, + " Time : ", + datetime.datetime.now() - start_time, ) @@ -86,12 +86,12 @@ class AudioStreamTrack(MediaStreamTrack): audio_buffer.write(frame) if local_frames := audio_buffer.read_many(256 * 960, partial=False): whisper_result = run_in_executor( - get_transcription, local_frames, executor=executor + get_transcription, local_frames, executor=executor ) whisper_result.add_done_callback( - lambda f: channel_send(data_channel, str(whisper_result.result())) - if (f.result()) - else None + lambda f: channel_send(data_channel, str(whisper_result.result())) + if (f.result()) + else None ) return frame @@ -140,10 +140,10 @@ async def offer(request): answer = await pc.createAnswer() await pc.setLocalDescription(answer) return web.Response( - content_type="application/json", - text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} - ), + content_type="application/json", + text=json.dumps( + { "sdp": pc.localDescription.sdp, "type": pc.localDescription.type } + ), ) diff --git a/server_multithreaded.py b/server_multithreaded.py index 9bb24031..5b1baf88
100644 --- a/server_multithreaded.py +++ b/server_multithreaded.py @@ -1,5 +1,5 @@ import asyncio -import configparser +from utils.run_utils import config import datetime import io import json @@ -8,9 +8,9 @@ import threading import uuid import wave from concurrent.futures import ThreadPoolExecutor -from aiohttp import web import jax.numpy as jnp +from aiohttp import web from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import (MediaRelay) from av import AudioFifo @@ -18,13 +18,10 @@ from sortedcontainers import SortedDict from whisper_jax import FlaxWhisperPipline from utils.log_utils import logger -from utils.server_utils import Mutex +from utils.run_utils import Mutex ROOT = os.path.dirname(__file__) -config = configparser.ConfigParser() -config.read('config.ini') - WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] pcs = set() relay = MediaRelay() @@ -91,10 +88,10 @@ def get_transcription(): wf.close() whisper_result = pipeline(out_file.getvalue()) - item = {'text': whisper_result["text"], - 'start_time': str(frames[0].time), - 'time': str(datetime.datetime.now()) - } + item = { 'text': whisper_result["text"], + 'start_time': str(frames[0].time), + 'time': str(datetime.datetime.now()) + } sorted_message_queue[frames[0].time] = str(item) start_messaging_thread() except Exception as e: @@ -177,10 +174,10 @@ async def offer(request): answer = await pc.createAnswer() await pc.setLocalDescription(answer) return web.Response( - content_type="application/json", - text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} - ), + content_type="application/json", + text=json.dumps( + { "sdp": pc.localDescription.sdp, "type": pc.localDescription.type } + ), ) @@ -196,5 +193,5 @@ if __name__ == "__main__": start_transcription_thread(6) app.router.add_post("/offer", offer) web.run_app( - app, access_log=None, host="127.0.0.1", port=1250 + app, access_log=None, host="127.0.0.1", port=1250 ) diff --git a/stream_client.py b/stream_client.py index 82f38e95..bd0eb159 100644 --- a/stream_client.py +++ b/stream_client.py @@ -1,6 +1,6 @@ import ast import asyncio -import configparser +from utils.run_utils import config import time import uuid @@ -12,12 +12,10 @@ from aiortc import (RTCPeerConnection, RTCSessionDescription) from aiortc.contrib.media import (MediaPlayer, MediaRelay) from utils.log_utils import logger -from utils.server_utils import Mutex +from utils.run_utils import Mutex file_lock = Mutex(open("test_sm_6.txt", "a")) -config = configparser.ConfigParser() -config.read('config.ini') class StreamClient: @@ -42,7 +40,7 @@ class StreamClient: self.time_start = None self.queue = asyncio.Queue() self.player = MediaPlayer(':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]), - format='avfoundation', options={'channels': '2'}) + format='avfoundation', options={ 'channels': '2' }) def stop(self): self.loop.run_until_complete(self.signaling.close()) @@ -127,8 +125,8 @@ class StreamClient: await pc.setLocalDescription(await pc.createOffer()) sdp = { - "sdp": pc.localDescription.sdp, - "type": pc.localDescription.type + "sdp": pc.localDescription.sdp, + "type": pc.localDescription.type } @stamina.retry(on=httpx.HTTPError, attempts=5) diff --git a/utils/file_utils.py b/utils/file_utils.py index d9fcc08f..504f12c5 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -1,14 +1,12 @@ import configparser +import sys import boto3 import botocore - +from run_utils import config from log_utils import logger 
-config = configparser.ConfigParser() -config.read('config.ini') - -BUCKET_NAME = 'reflector-bucket' +BUCKET_NAME = config["DEFAULT"]["BUCKET_NAME"] s3 = boto3.client('s3', aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"], @@ -18,8 +16,8 @@ s3 = boto3.client('s3', def upload_files(files_to_upload): """ Upload a list of files to the configured S3 bucket - :param files_to_upload: - :return: + :param files_to_upload: List of files to upload + :return: None """ for KEY in files_to_upload: logger.info("Uploading file " + KEY) @@ -32,8 +30,8 @@ def upload_files(files_to_upload): def download_files(files_to_download): """ Download a list of files from the configured S3 bucket - :param files_to_download: - :return: + :param files_to_download: List of files to download + :return: None """ for KEY in files_to_download: logger.info("Downloading file " + KEY) @@ -47,8 +45,6 @@ def download_files(files_to_download): if __name__ == "__main__": - import sys - if sys.argv[1] == "download": download_files([sys.argv[2]]) elif sys.argv[1] == "upload": diff --git a/utils/log_utils.py b/utils/log_utils.py index 3b874363..0cdb30f4 100644 --- a/utils/log_utils.py +++ b/utils/log_utils.py @@ -6,6 +6,10 @@ class SingletonLogger: @staticmethod def get_logger(): + """ + Create or return the singleton instance for the SingletonLogger class + :return: SingletonLogger instance + """ if not SingletonLogger.__instance: SingletonLogger.__instance = logger return SingletonLogger.__instance diff --git a/utils/run_utils.py b/utils/run_utils.py new file mode 100644 index 00000000..0ccd6942 --- /dev/null +++ b/utils/run_utils.py @@ -0,0 +1,66 @@ +import asyncio +import configparser +import contextlib +from functools import partial +from threading import Lock +from typing import ContextManager, Generic, TypeVar + + +class ConfigParser: + __config = configparser.ConfigParser() + # Read once at class-creation time so every importer shares a + # populated config instance. + __config.read('config.ini') + + @staticmethod + def get_config(): + return ConfigParser.__config + + +config = ConfigParser.get_config() + + +def run_in_executor(func, *args, executor=None, **kwargs): + """ + Run the function in an executor, unblocking the main loop + :param func: Function to be run in executor + :param args: function parameters + :param executor: executor instance [Thread | Process] + :param kwargs: Additional parameters + :return: Future of function result upon completion + """ + callback = partial(func, *args, **kwargs) + loop = asyncio.get_event_loop() + return loop.run_in_executor(executor, callback) + + +# Generic type template +T = TypeVar("T") + + +class Mutex(Generic[T]): + """ + Mutex class to implement lock/release of a shared + protected variable + """ + + def __init__(self, value: T): + """ + Create an instance of Mutex wrapper for the given resource + :param value: Shared resources to be thread protected + """ + self.__value = value + self.__lock = Lock() + + @contextlib.contextmanager + def lock(self) -> ContextManager[T]: + """ + Lock the resource with a mutex to be used within a context block + The lock is automatically released on context exit + :return: Shared resource + """ + self.__lock.acquire() + try: + yield self.__value + finally: + self.__lock.release()
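utils/run_utils.py now carries the shared config, the executor helper, and the `Mutex` wrapper that server_multithreaded.py and stream_client.py rely on. A short usage sketch (file name and frame payload are illustrative; assumes the process runs from the repo root):

```python
# Illustrative use of Mutex and run_in_executor from utils/run_utils.py.
import asyncio
from concurrent.futures import ThreadPoolExecutor

from utils.run_utils import Mutex, run_in_executor

file_lock = Mutex(open("transcript_demo.txt", "a"))  # thread-protected handle
executor = ThreadPoolExecutor()


def transcribe(frames):
    text = f"transcribed {len(frames)} frames"  # stand-in for whisper-jax
    with file_lock.lock() as f:  # lock is released automatically on exit
        f.write(text + "\n")
    return text


async def main():
    # The blocking call runs on the executor, keeping the event loop free.
    print(await run_in_executor(transcribe, [b"frame"], executor=executor))


asyncio.run(main())
```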
diff --git a/utils/server_utils.py b/utils/server_utils.py deleted file mode 100644 index 5236a67d..00000000 --- a/utils/server_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -import asyncio -import contextlib -from functools import partial -from threading import Lock -from typing import ContextManager, Generic, TypeVar - - -def run_in_executor(func, *args, executor=None, **kwargs): - callback = partial(func, *args, **kwargs) - loop = asyncio.get_event_loop() - return asyncio.get_event_loop().run_in_executor(executor, callback) - - -T = TypeVar("T") - - -class Mutex(Generic[T]): - def __init__(self, value: T): - self.__value = value - self.__lock = Lock() - - @contextlib.contextmanager - def lock(self) -> ContextManager[T]: - self.__lock.acquire() - try: - yield self.__value - finally: - self.__lock.release() diff --git a/utils/text_utilities.py b/utils/text_utilities.py index d67caf66..4fc292bb 100644 --- a/utils/text_utilities.py +++ b/utils/text_utilities.py @@ -6,14 +6,12 @@ from nltk.corpus import stopwords from nltk.tokenize import word_tokenize from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity -from transformers import BartTokenizer, BartForConditionalGeneration - +from transformers import BartForConditionalGeneration, BartTokenizer +from run_utils import config from log_utils import logger nltk.download('punkt', quiet=True) @@ -74,7 +72,7 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences): for sent in nonduplicate_sentences: temp_result = "" - seen = {} + seen = { } words = nltk.word_tokenize(sent) n_gram_filter = 3 for i in range(len(words)): diff --git a/utils/viz_utilities.py b/utils/viz_utilities.py index e3e19a5d..77aa556f 100644 --- a/utils/viz_utilities.py +++ b/utils/viz_utilities.py @@ -1,6 +1,5 @@ import ast import collections -import configparser import os import pickle from pathlib import Path @@ -10,10 +9,7 @@ import pandas as pd import scattertext as st import spacy from nltk.corpus import stopwords -from wordcloud import WordCloud, STOPWORDS - -config = configparser.ConfigParser() -config.read('config.ini') +from wordcloud import STOPWORDS, WordCloud en = spacy.load('en_core_web_md') spacy_stopwords = en.Defaults.stop_words @@ -92,11 +88,11 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): # create df for processing df = pd.DataFrame.from_dict(res["chunks"]) - covered_items = {} + covered_items = { } # ts: timestamp # Map each timestamped chunk with top1 and top2 matched agenda - ts_to_topic_mapping_top_1 = {} - ts_to_topic_mapping_top_2 = {} + ts_to_topic_mapping_top_1 = { } + ts_to_topic_mapping_top_2 = { } # Also create a mapping of the different timestamps in which each topic was covered topic_to_ts_mapping_top_1 = collections.defaultdict(list) @@ -189,16 +185,16 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): # Scatter plot of topics df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences)) corpus = st.CorpusFromParsedDocuments( - df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse' + df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse' ).build().get_unigram_corpus().compact(st.AssociationCompactor(2000)) html = st.produce_scattertext_explorer( - corpus, - category=cat_1, - category_name=cat_1_name, - not_category_name=cat_2_name, - minimum_term_frequency=0, pmi_threshold_coefficient=0, - width_in_pixels=1000, - transform=st.Scalers.dense_rank + corpus, + category=cat_1, + category_name=cat_1_name, + not_category_name=cat_2_name, + minimum_term_frequency=0, pmi_threshold_coefficient=0, + width_in_pixels=1000, + transform=st.Scalers.dense_rank ) if real_time:
open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) diff --git a/whisjax.py b/whisjax.py index ebfe1056..8f6c7239 100644 --- a/whisjax.py +++ b/whisjax.py @@ -20,18 +20,15 @@ import nltk import yt_dlp as youtube_dl from whisper_jax import FlaxWhisperPipline -from utils.file_utils import upload_files, download_files +from utils.file_utils import download_files, upload_files from utils.log_utils import logger -from utils.text_utilities import summarize, post_process_transcription -from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz +from utils.run_utils import config +from utils.text_utilities import post_process_transcription, summarize +from utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud nltk.download('punkt', quiet=True) nltk.download('stopwords', quiet=True) -# Configurations can be found in config.ini. Set them properly before executing -config = configparser.ConfigParser() -config.read('config.ini') - WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] NOW = datetime.now() @@ -42,8 +39,8 @@ def init_argparse() -> argparse.ArgumentParser: :return: parser object """ parser = argparse.ArgumentParser( - usage="%(prog)s [OPTIONS] ", - description="Creates a transcript of a video or audio file, then summarizes it using ChatGPT." + usage="%(prog)s [OPTIONS] ", + description="Creates a transcript of a video or audio file, then summarizes it using ChatGPT." ) parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str, @@ -74,13 +71,13 @@ def main(): # Create options for the download ydl_opts = { - 'format': 'bestaudio/best', - 'postprocessors': [{ - 'key': 'FFmpegExtractAudio', - 'preferredcodec': 'mp3', - 'preferredquality': '192', - }], - 'outtmpl': 'audio', # Specify the output file path and name + 'format': 'bestaudio/best', + 'postprocessors': [{ + 'key': 'FFmpegExtractAudio', + 'preferredcodec': 'mp3', + 'preferredquality': '192', + }], + 'outtmpl': 'audio', # Specify the output file path and name } # Download the audio diff --git a/whisjax_realtime.py b/whisjax_realtime.py index 60b06a8c..48deef2a 100644 --- a/whisjax_realtime.py +++ b/whisjax_realtime.py @@ -13,11 +13,10 @@ from whisper_jax import FlaxWhisperPipline from utils.file_utils import upload_files from utils.log_utils import logger -from utils.text_utilities import summarize, post_process_transcription -from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz +from utils.run_utils import config +from utils.text_utilities import post_process_transcription, summarize +from utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud -config = configparser.ConfigParser() -config.read('config.ini') WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] @@ -37,12 +36,12 @@ def main(): AUDIO_DEVICE_ID = i audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID) stream = p.open( - format=FORMAT, - channels=CHANNELS, - rate=RATE, - input=True, - frames_per_buffer=FRAMES_PER_BUFFER, - input_device_index=int(audio_devices['index']) + format=FORMAT, + channels=CHANNELS, + rate=RATE, + input=True, + frames_per_buffer=FRAMES_PER_BUFFER, + input_device_index=int(audio_devices['index']) ) pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"], @@ -60,7 +59,7 @@ def main(): global proceed proceed = False - transcript_with_timestamp = {"text": "", "chunks": []} + transcript_with_timestamp = { 
"text": "", "chunks": [] } last_transcribed_time = 0.0 listener = keyboard.Listener(on_press=on_press) @@ -90,10 +89,10 @@ def main(): if end is None: end = start + 15.0 duration = end - start - item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration), - 'text': whisper_result['text'], - 'stats': (str(end_time - start_time), str(duration)) - } + item = { 'timestamp': (last_transcribed_time, last_transcribed_time + duration), + 'text': whisper_result['text'], + 'stats': (str(end_time - start_time), str(duration)) + } last_transcribed_time = last_transcribed_time + duration transcript_with_timestamp["chunks"].append(item) transcription += whisper_result['text'] From 8deeff588f05fe6d949e851aa380f9add2c488f9 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Tue, 11 Jul 2023 12:41:26 +0530 Subject: [PATCH 4/9] clean up --- constants.py | 0 server_executor_cleaned.py | 17 +---------------- server_multithreaded.py | 16 +++------------- stream_client.py | 3 --- 4 files changed, 4 insertions(+), 32 deletions(-) create mode 100644 constants.py diff --git a/constants.py b/constants.py new file mode 100644 index 00000000..e69de29b diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index 1a40b4af..487bf0f8 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -14,15 +14,11 @@ from aiortc.contrib.media import MediaRelay from av import AudioFifo from loguru import logger from whisper_jax import FlaxWhisperPipline - from utils.run_utils import run_in_executor -transcription = "" - pcs = set() relay = MediaRelay() data_channel = None -total_bytes_handled = 0 pipeline = FlaxWhisperPipline("openai/whisper-tiny", dtype=jnp.float16, batch_size=16) @@ -30,7 +26,6 @@ pipeline = FlaxWhisperPipline("openai/whisper-tiny", CHANNELS = 2 RATE = 48000 audio_buffer = AudioFifo() -start_time = datetime.datetime.now() executor = ThreadPoolExecutor() @@ -40,15 +35,8 @@ def channel_log(channel, t, message): def channel_send(channel, message): # channel_log(channel, ">", message) - global start_time if channel: channel.send(message) - print( - "Bytes handled :", - total_bytes_handled, - " Time : ", - datetime.datetime.now() - start_time, - ) def get_transcription(frames): @@ -61,8 +49,6 @@ def get_transcription(frames): for frame in frames: wf.writeframes(b"".join(frame.to_ndarray())) wf.close() - global total_bytes_handled - total_bytes_handled += sys.getsizeof(wf) whisper_result = pipeline(out_file.getvalue(), return_timestamps=True) with open("test_exec.txt", "a") as f: f.write(whisper_result["text"]) @@ -111,10 +97,9 @@ async def offer(request): @pc.on("datachannel") def on_datachannel(channel): - global data_channel, start_time + global data_channel data_channel = channel channel_log(channel, "-", "created by remote party") - start_time = datetime.datetime.now() @channel.on("message") def on_message(message): diff --git a/server_multithreaded.py b/server_multithreaded.py index 5b1baf88..bf0c371d 100644 --- a/server_multithreaded.py +++ b/server_multithreaded.py @@ -12,37 +12,27 @@ from concurrent.futures import ThreadPoolExecutor import jax.numpy as jnp from aiohttp import web from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription -from aiortc.contrib.media import (MediaRelay) +from aiortc.contrib.media import MediaRelay from av import AudioFifo from sortedcontainers import SortedDict from whisper_jax import FlaxWhisperPipline - from utils.log_utils import logger from utils.run_utils import Mutex -ROOT = os.path.dirname(__file__) 
From 8deeff588f05fe6d949e851aa380f9add2c488f9 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Tue, 11 Jul 2023 12:41:26 +0530 Subject: [PATCH 4/9] clean up --- constants.py | 0 server_executor_cleaned.py | 17 +---------------- server_multithreaded.py | 16 +++------------- stream_client.py | 3 --- 4 files changed, 4 insertions(+), 32 deletions(-) create mode 100644 constants.py diff --git a/constants.py b/constants.py new file mode 100644 index 00000000..e69de29b diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index 1a40b4af..487bf0f8 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -14,15 +14,11 @@ from aiortc.contrib.media import MediaRelay from av import AudioFifo from loguru import logger from whisper_jax import FlaxWhisperPipline - from utils.run_utils import run_in_executor -transcription = "" - pcs = set() relay = MediaRelay() data_channel = None -total_bytes_handled = 0 pipeline = FlaxWhisperPipline("openai/whisper-tiny", dtype=jnp.float16, batch_size=16) @@ -30,7 +26,6 @@ pipeline = FlaxWhisperPipline("openai/whisper-tiny", CHANNELS = 2 RATE = 48000 audio_buffer = AudioFifo() -start_time = datetime.datetime.now() executor = ThreadPoolExecutor() @@ -40,15 +35,8 @@ def channel_log(channel, t, message): def channel_send(channel, message): # channel_log(channel, ">", message) - global start_time if channel: channel.send(message) - print( - "Bytes handled :", - total_bytes_handled, - " Time : ", - datetime.datetime.now() - start_time, - ) def get_transcription(frames): @@ -61,8 +49,6 @@ def get_transcription(frames): for frame in frames: wf.writeframes(b"".join(frame.to_ndarray())) wf.close() - global total_bytes_handled - total_bytes_handled += sys.getsizeof(wf) whisper_result = pipeline(out_file.getvalue(), return_timestamps=True) with open("test_exec.txt", "a") as f: f.write(whisper_result["text"]) @@ -111,10 +97,9 @@ async def offer(request): @pc.on("datachannel") def on_datachannel(channel): - global data_channel, start_time + global data_channel data_channel = channel channel_log(channel, "-", "created by remote party") - start_time = datetime.datetime.now() @channel.on("message") def on_message(message): diff --git a/server_multithreaded.py b/server_multithreaded.py index 5b1baf88..bf0c371d 100644 --- a/server_multithreaded.py +++ b/server_multithreaded.py @@ -12,37 +12,27 @@ from concurrent.futures import ThreadPoolExecutor import jax.numpy as jnp from aiohttp import web from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription -from aiortc.contrib.media import (MediaRelay) +from aiortc.contrib.media import MediaRelay from av import AudioFifo from sortedcontainers import SortedDict from whisper_jax import FlaxWhisperPipline - from utils.log_utils import logger from utils.run_utils import Mutex -ROOT = os.path.dirname(__file__) - -WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] +WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"] pcs = set() relay = MediaRelay() data_channel = None sorted_message_queue = SortedDict() - CHANNELS = 2 RATE = 44100 CHUNK_SIZE = 256 - -audio_buffer = AudioFifo() pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE, dtype=jnp.float16, batch_size=16) - -transcription = "" start_time = datetime.datetime.now() -total_bytes_handled = 0 - executor = ThreadPoolExecutor() - +audio_buffer = AudioFifo() frame_lock = Mutex(audio_buffer) diff --git a/stream_client.py b/stream_client.py index bd0eb159..22177970 100644 --- a/stream_client.py +++ b/stream_client.py @@ -17,7 +17,6 @@ from utils.run_utils import Mutex file_lock = Mutex(open("test_sm_6.txt", "a")) - class StreamClient: def __init__( self, @@ -46,7 +45,6 @@ class StreamClient: self.loop.run_until_complete(self.signaling.close()) self.loop.run_until_complete(self.pc.close()) # self.loop.close() - print("ended") def create_local_tracks(self, play_from): if play_from: @@ -55,7 +53,6 @@ class StreamClient: else: if self.relay is None: self.relay = MediaRelay() - print("Created local track from microphone stream") return self.relay.subscribe(self.player.audio), None def channel_log(self, channel, t, message): From 54de3683b39a9573c12948e94848a8007ecd40cd Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Tue, 11 Jul 2023 12:52:58 +0530 Subject: [PATCH 5/9] fix imports --- constants.py | 0 server_executor_cleaned.py | 3 +-- server_multithreaded.py | 5 ++--- stream_client.py | 3 +-- utils/file_utils.py | 4 ++-- utils/text_utilities.py | 6 ++---- whisjax.py | 1 - whisjax_realtime.py | 2 -- 8 files changed, 8 insertions(+), 16 deletions(-) delete mode 100644 constants.py diff --git a/constants.py b/constants.py deleted file mode 100644 index e69de29b..00000000 diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index 487bf0f8..b9334e52 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -1,8 +1,6 @@ import asyncio -import datetime import io import json -import sys import uuid import wave from concurrent.futures import ThreadPoolExecutor @@ -14,6 +12,7 @@ from aiortc.contrib.media import MediaRelay from av import AudioFifo from loguru import logger from whisper_jax import FlaxWhisperPipline + from utils.run_utils import run_in_executor pcs = set() diff --git a/server_multithreaded.py b/server_multithreaded.py index bf0c371d..b62def09 100644 --- a/server_multithreaded.py +++ b/server_multithreaded.py @@ -1,9 +1,7 @@ import asyncio -from utils.run_utils import config import datetime import io import json -import os import threading import uuid import wave @@ -16,8 +14,9 @@ from aiortc.contrib.media import MediaRelay from av import AudioFifo from sortedcontainers import SortedDict from whisper_jax import FlaxWhisperPipline + from utils.log_utils import logger -from utils.run_utils import Mutex +from utils.run_utils import config, Mutex WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"] pcs = set() diff --git a/stream_client.py b/stream_client.py index 22177970..a6b879e2 100644 --- a/stream_client.py +++ b/stream_client.py @@ -1,6 +1,5 @@ import ast import asyncio -from utils.run_utils import config import time import uuid @@ -12,7 +11,7 @@ import stamina from aiortc import (RTCPeerConnection, RTCSessionDescription) from aiortc.contrib.media import (MediaPlayer, MediaRelay) from utils.log_utils import logger -from utils.run_utils import Mutex +from
utils.run_utils import config, Mutex file_lock = Mutex(open("test_sm_6.txt", "a")) diff --git a/utils/file_utils.py b/utils/file_utils.py index 504f12c5..2c14f00f 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -1,10 +1,10 @@ -import configparser import sys import boto3 import botocore -from run_utils import config + from log_utils import logger +from run_utils import config BUCKET_NAME = config["DEFAULT"]["BUCKET_NAME"] diff --git a/utils/text_utilities.py b/utils/text_utilities.py index 4fc292bb..f41cd800 100644 --- a/utils/text_utilities.py +++ b/utils/text_utilities.py @@ -1,5 +1,3 @@ -import configparser - import nltk import torch from nltk.corpus import stopwords @@ -7,13 +5,13 @@ from nltk.tokenize import word_tokenize from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from transformers import BartForConditionalGeneration, BartTokenizer -from run_utils import config + from log_utils import logger +from run_utils import config nltk.download('punkt', quiet=True) - def preprocess_sentence(sentence): stop_words = set(stopwords.words('english')) tokens = word_tokenize(sentence.lower()) diff --git a/whisjax.py b/whisjax.py index 8f6c7239..8946953f 100644 --- a/whisjax.py +++ b/whisjax.py @@ -5,7 +5,6 @@ # summarize podcast.mp3 summary.txt import argparse -import configparser import os import re import subprocess diff --git a/whisjax_realtime.py b/whisjax_realtime.py index 48deef2a..d8623ddc 100644 --- a/whisjax_realtime.py +++ b/whisjax_realtime.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import configparser import time import wave from datetime import datetime @@ -17,7 +16,6 @@ from utils.run_utils import config from utils.text_utilities import post_process_transcription, summarize from utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud - WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] FRAMES_PER_BUFFER = 8000 From 88af11213109d143908f78f4d80e8ccaaa120c71 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Tue, 11 Jul 2023 13:45:28 +0530 Subject: [PATCH 6/9] reformatting --- client.py | 1 - server_executor_cleaned.py | 2 +- server_multithreaded.py | 31 ++++++++++++++++--------------- stream_client.py | 2 +- utils/text_utilities.py | 2 +- utils/viz_utilities.py | 6 +++--- whisjax_realtime.py | 10 +++++----- 7 files changed, 27 insertions(+), 27 deletions(-) diff --git a/client.py b/client.py index d6393712..5cf8d47d 100644 --- a/client.py +++ b/client.py @@ -64,7 +64,6 @@ async def main(): ping_pong=args.ping_pong ) await sc.start() - print("Stream client started") async for msg in sc.get_reader(): print(msg) diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index b9334e52..ecac6d48 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -126,7 +126,7 @@ async def offer(request): return web.Response( content_type="application/json", text=json.dumps( - { "sdp": pc.localDescription.sdp, "type": pc.localDescription.type } + {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} ), ) diff --git a/server_multithreaded.py b/server_multithreaded.py index b62def09..7382a654 100644 --- a/server_multithreaded.py +++ b/server_multithreaded.py @@ -8,6 +8,7 @@ import wave from concurrent.futures import ThreadPoolExecutor import jax.numpy as jnp +import requests from aiohttp import web from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import MediaRelay @@ -77,10 +78,11 @@ def 
get_transcription(): wf.close() whisper_result = pipeline(out_file.getvalue()) - item = { 'text': whisper_result["text"], - 'start_time': str(frames[0].time), - 'time': str(datetime.datetime.now()) - } + item = { + 'text': whisper_result["text"], + 'start_time': str(frames[0].time), + 'time': str(datetime.datetime.now()) + } sorted_message_queue[frames[0].time] = str(item) start_messaging_thread() except Exception as e: @@ -89,7 +91,7 @@ def get_transcription(): class AudioStreamTrack(MediaStreamTrack): """ - A video stream track that transforms frames from an another track. + An audio stream track to send audio frames. """ kind = "audio" @@ -109,15 +111,13 @@ def start_messaging_thread(): message_thread.start() -def start_transcription_thread(max_threads): - t_threads = [] +def start_transcription_thread(max_threads: int): for i in range(max_threads): t_thread = threading.Thread(target=get_transcription, args=(i,)) - t_threads.append(t_thread) t_thread.start() -async def offer(request): +async def offer(request: requests.Request): params = await request.json() offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) @@ -125,7 +125,7 @@ async def offer(request): pc_id = "PeerConnection(%s)" % uuid.uuid4() pcs.add(pc) - def log_info(msg, *args): + def log_info(msg: str, *args): logger.info(pc_id + " " + msg, *args) log_info("Created for %s", request.remote) @@ -138,7 +138,7 @@ async def offer(request): start_time = datetime.datetime.now() @channel.on("message") - def on_message(message): + def on_message(message: str): channel_log(channel, "<", message) if isinstance(message, str) and message.startswith("ping"): # reply @@ -164,13 +164,14 @@ async def offer(request): await pc.setLocalDescription(answer) return web.Response( content_type="application/json", - text=json.dumps( - { "sdp": pc.localDescription.sdp, "type": pc.localDescription.type } - ), + text=json.dumps({ + "sdp": pc.localDescription.sdp, + "type": pc.localDescription.type + }), ) -async def on_shutdown(app): +async def on_shutdown(app: web.Application): coros = [pc.close() for pc in pcs] await asyncio.gather(*coros) pcs.clear() diff --git a/stream_client.py b/stream_client.py index a6b879e2..d7791e5c 100644 --- a/stream_client.py +++ b/stream_client.py @@ -38,7 +38,7 @@ class StreamClient: self.time_start = None self.queue = asyncio.Queue() self.player = MediaPlayer(':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]), - format='avfoundation', options={ 'channels': '2' }) + format='avfoundation', options={'channels': '2'}) def stop(self): self.loop.run_until_complete(self.signaling.close()) diff --git a/utils/text_utilities.py b/utils/text_utilities.py index f41cd800..900f9194 100644 --- a/utils/text_utilities.py +++ b/utils/text_utilities.py @@ -70,7 +70,7 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences): for sent in nonduplicate_sentences: temp_result = "" - seen = { } + seen = {} words = nltk.word_tokenize(sent) n_gram_filter = 3 for i in range(len(words)): diff --git a/utils/viz_utilities.py b/utils/viz_utilities.py index 77aa556f..fa09144e 100644 --- a/utils/viz_utilities.py +++ b/utils/viz_utilities.py @@ -88,11 +88,11 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): # create df for processing df = pd.DataFrame.from_dict(res["chunks"]) - covered_items = { } + covered_items = {} # ts: timestamp # Map each timestamped chunk with top1 and top2 matched agenda - ts_to_topic_mapping_top_1 = { } - ts_to_topic_mapping_top_2 = { } + ts_to_topic_mapping_top_1 = {} + 
ts_to_topic_mapping_top_2 = {} # Also create a mapping of the different timestamps in which each topic was covered topic_to_ts_mapping_top_1 = collections.defaultdict(list) diff --git a/whisjax_realtime.py b/whisjax_realtime.py index d8623ddc..68dc472a 100644 --- a/whisjax_realtime.py +++ b/whisjax_realtime.py @@ -57,7 +57,7 @@ def main(): global proceed proceed = False - transcript_with_timestamp = { "text": "", "chunks": [] } + transcript_with_timestamp = {"text": "", "chunks": []} last_transcribed_time = 0.0 listener = keyboard.Listener(on_press=on_press) @@ -87,10 +87,10 @@ def main(): if end is None: end = start + 15.0 duration = end - start - item = { 'timestamp': (last_transcribed_time, last_transcribed_time + duration), - 'text': whisper_result['text'], - 'stats': (str(end_time - start_time), str(duration)) - } + item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration), + 'text': whisper_result['text'], + 'stats': (str(end_time - start_time), str(duration)) + } last_transcribed_time = last_transcribed_time + duration transcript_with_timestamp["chunks"].append(item) transcription += whisper_result['text'] From d962ff17126c3af4e0c509fe57d22d77a033a6be Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Tue, 11 Jul 2023 14:06:20 +0530 Subject: [PATCH 7/9] flake8 warnings fix --- client.py | 2 +- server_executor_cleaned.py | 6 ++-- stream_client.py | 11 ++++--- utils/__init__.py | 0 utils/log_utils.py | 4 +-- utils/run_utils.py | 2 +- utils/text_utilities.py | 61 ++++++++++++++++++++++++-------------- utils/viz_utilities.py | 41 +++++++++++++++---------- whisjax.py | 43 +++++++++++++++++---------- whisjax_realtime.py | 22 +++++++++----- 10 files changed, 122 insertions(+), 70 deletions(-) create mode 100644 utils/__init__.py diff --git a/client.py b/client.py index 5cf8d47d..b0fa46a5 100644 --- a/client.py +++ b/client.py @@ -47,7 +47,7 @@ async def main(): logger.info(f"Cancelling {len(tasks)} outstanding tasks") await asyncio.gather(*tasks, return_exceptions=True) - logger.info(f"Flushing metrics") + logger.info(f'{"Flushing metrics"}') loop.stop() signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index ecac6d48..e0fb4cc3 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -74,7 +74,8 @@ class AudioStreamTrack(MediaStreamTrack): get_transcription, local_frames, executor=executor ) whisper_result.add_done_callback( - lambda f: channel_send(data_channel, str(whisper_result.result())) + lambda f: channel_send(data_channel, + str(whisper_result.result())) if (f.result()) else None ) @@ -126,7 +127,8 @@ async def offer(request): return web.Response( content_type="application/json", text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} + {"sdp": pc.localDescription.sdp, + "type": pc.localDescription.type} ), ) diff --git a/stream_client.py b/stream_client.py index d7791e5c..628ee69e 100644 --- a/stream_client.py +++ b/stream_client.py @@ -37,8 +37,10 @@ class StreamClient: self.pcs = set() self.time_start = None self.queue = asyncio.Queue() - self.player = MediaPlayer(':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]), - format='avfoundation', options={'channels': '2'}) + self.player = MediaPlayer( + ':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]), + format='avfoundation', + options={'channels': '2'}) def stop(self): self.loop.run_until_complete(self.signaling.close()) @@ -115,7 +117,8 @@ class StreamClient: 
self.channel_log(channel, "<", message) if isinstance(message, str) and message.startswith("pong"): - elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000 + elapsed_ms = (self.current_stamp() - int(message[5:]))\ + / 1000 print(" RTT %.2f ms" % elapsed_ms) await pc.setLocalDescription(await pc.createOffer()) @@ -135,7 +138,7 @@ class StreamClient: answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) await pc.setRemoteDescription(answer) - self.reader = self.worker(f"worker", self.queue) + self.reader = self.worker(f'{"worker"}', self.queue) def get_reader(self): return self.reader diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/utils/log_utils.py b/utils/log_utils.py index 0cdb30f4..f665f5da 100644 --- a/utils/log_utils.py +++ b/utils/log_utils.py @@ -1,4 +1,4 @@ -from loguru import logger +import loguru class SingletonLogger: @@ -11,7 +11,7 @@ class SingletonLogger: :return: SingletonLogger instance """ if not SingletonLogger.__instance: - SingletonLogger.__instance = logger + SingletonLogger.__instance = loguru.logger return SingletonLogger.__instance diff --git a/utils/run_utils.py b/utils/run_utils.py index 0ccd6942..dca09c87 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -31,7 +31,7 @@ def run_in_executor(func, *args, executor=None, **kwargs): """ callback = partial(func, *args, **kwargs) loop = asyncio.get_event_loop() - return asyncio.get_event_loop().run_in_executor(executor, callback) + return loop.run_in_executor(executor, callback) # Genetic type template diff --git a/utils/text_utilities.py b/utils/text_utilities.py index 900f9194..519990cb 100644 --- a/utils/text_utilities.py +++ b/utils/text_utilities.py @@ -15,7 +15,8 @@ nltk.download('punkt', quiet=True) def preprocess_sentence(sentence): stop_words = set(stopwords.words('english')) tokens = word_tokenize(sentence.lower()) - tokens = [token for token in tokens if token.isalnum() and token not in stop_words] + tokens = [token for token in tokens + if token.isalnum() and token not in stop_words] return ' '.join(tokens) @@ -49,12 +50,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7): sentence1 = preprocess_sentence(sentences[i]) sentence2 = preprocess_sentence(sentences[j]) if len(sentence1) != 0 and len(sentence2) != 0: - similarity = compute_similarity(sentence1, sentence2) + similarity = compute_similarity(sentence1, + sentence2) if similarity >= threshold: removed_indices.add(max(i, j)) - filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices] + filtered_sentences = [sentences[i] for i in range(num_sentences) + if i not in removed_indices] return filtered_sentences @@ -74,11 +77,13 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences): words = nltk.word_tokenize(sent) n_gram_filter = 3 for i in range(len(words)): - if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[ - i + 1:i + n_gram_filter + 2]: + if str(words[i:i + n_gram_filter]) in seen and \ + seen[str(words[i:i + n_gram_filter])] == \ + words[i + 1:i + n_gram_filter + 2]: pass else: - seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2] + seen[str(words[i:i + n_gram_filter])] = \ + words[i + 1:i + n_gram_filter + 2] temp_result += words[i] temp_result += " " chunk_sentences.append(temp_result) @@ -88,9 +93,12 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences): def post_process_transcription(whisper_result): 
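     # Cleanup applied per Whisper chunk, in order: drop outright duplicate
     # sentences, strip repeated n-gram hallucinations (n_gram_filter = 3),
     # then remove near-duplicate sentences via TF-IDF cosine similarity
     # (threshold 0.7) before stitching the chunk text back together.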
transcript_text = "" for chunk in whisper_result["chunks"]: - nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk) - chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences) - similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences) + nonduplicate_sentences = \ + remove_outright_duplicate_sentences_from_chunk(chunk) + chunk_sentences = \ + remove_whisper_repetitive_hallucination(nonduplicate_sentences) + similarity_matched_sentences = \ + remove_almost_alike_sentences(chunk_sentences) chunk["text"] = " ".join(similarity_matched_sentences) transcript_text += chunk["text"] whisper_result["text"] = transcript_text @@ -111,18 +119,23 @@ def summarize_chunks(chunks, tokenizer, model): input_ids = tokenizer.encode(c, return_tensors='pt') input_ids = input_ids.to(device) with torch.no_grad(): - summary_ids = model.generate(input_ids, - num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, - max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) - summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) + summary_ids = \ + model.generate(input_ids, + num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), + length_penalty=2.0, + max_length=int(config["DEFAULT"]["MAX_LENGTH"]), + early_stopping=True) + summary = tokenizer.decode(summary_ids[0], + skip_special_tokens=True) summaries.append(summary) return summaries -def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): +def chunk_text(text, + max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): """ Split text into smaller chunks. - :param txt: Text to be chunked + :param text: Text to be chunked :param max_chunk_length: length of chunk :return: chunked texts """ @@ -140,7 +153,8 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"]) def summarize(transcript_text, timestamp, - real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): + real_time=False, + summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") summary_model = config["DEFAULT"]["SUMMARY_MODEL"] if not summary_model: @@ -157,9 +171,11 @@ def summarize(transcript_text, timestamp, output_filename = "real_time_" + output_filename if summarize_using_chunks != "YES": - inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest', - max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]), - return_tensors='pt') + inputs = tokenizer.\ + batch_encode_plus([transcript_text], truncation=True, + padding='longest', + max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]), + return_tensors='pt') inputs = inputs.to(device) with torch.no_grad(): @@ -167,8 +183,8 @@ def summarize(transcript_text, timestamp, num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) - decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for - summary in summaries] + decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) + for summary in summaries] summary = " ".join(decoded_summaries) with open(output_filename, 'w') as f: f.write(summary.strip() + "\n") @@ -176,7 +192,8 @@ def summarize(transcript_text, timestamp, logger.info("Breaking transcript into smaller chunks") chunks = chunk_text(transcript_text) - 
logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable + logger.info(f"Transcript broken into {len(chunks)} " + f"chunks of at most 500 words") logger.info(f"Writing summary text to: {output_filename}") with open(output_filename, 'w') as f: diff --git a/utils/viz_utilities.py b/utils/viz_utilities.py index fa09144e..e1ab88c9 100644 --- a/utils/viz_utilities.py +++ b/utils/viz_utilities.py @@ -2,7 +2,6 @@ import ast import collections import os import pickle -from pathlib import Path import matplotlib.pyplot as plt import pandas as pd @@ -14,7 +13,8 @@ from wordcloud import STOPWORDS, WordCloud en = spacy.load('en_core_web_md') spacy_stopwords = en.Defaults.stop_words -STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords)) +STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).\ + union(set(spacy_stopwords)) def create_wordcloud(timestamp, real_time=False): @@ -24,7 +24,8 @@ def create_wordcloud(timestamp, real_time=False): """ filename = "transcript" if real_time: - filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + filename = "real_time_" + filename + "_" +\ + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" else: filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" @@ -46,7 +47,8 @@ def create_wordcloud(timestamp, real_time=False): wordcloud_name = "wordcloud" if real_time: - wordcloud_name = "real_time_" + wordcloud_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" + wordcloud_name = "real_time_" + wordcloud_name + "_" +\ + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" else: wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" @@ -66,7 +68,6 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): agenda_topics = [] agenda = [] # Load the agenda - path = Path(__file__) with open(os.path.join(os.getcwd(), "agenda-headers.txt"), "r") as f: for line in f.readlines(): if line.strip(): @@ -76,9 +77,11 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): # Load the transcription with timestamp filename = "" if real_time: - filename = "real_time_transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + filename = "real_time_transcript_with_timestamp_" +\ + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" else: - filename = "transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + filename = "transcript_with_timestamp_" +\ + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" with open(filename) as f: transcription_timestamp_text = f.read() @@ -94,7 +97,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): ts_to_topic_mapping_top_1 = {} ts_to_topic_mapping_top_2 = {} - # Also create a mapping of the different timestamps in which each topic was covered + # Also create a mapping of the different timestamps + # in which each topic was covered topic_to_ts_mapping_top_1 = collections.defaultdict(list) topic_to_ts_mapping_top_2 = collections.defaultdict(list) @@ -105,7 +109,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): topic_similarities = [] for item in range(len(agenda)): item_doc = nlp(agenda[item]) - # if not doc_transcription or not all(token.has_vector for token in doc_transcription): + # if not doc_transcription or not all + # (token.has_vector for token in doc_transcription): if not doc_transcription: continue similarity = doc_transcription.similarity(item_doc) @@ -129,8 +134,10 @@ def 
create_talk_diff_scatter_viz(timestamp, real_time=False): :param record: :return: """ - record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]] - record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]] + record["ts_to_topic_mapping_top_1"] = \ + ts_to_topic_mapping_top_1[record["timestamp"]] + record["ts_to_topic_mapping_top_2"] = \ + ts_to_topic_mapping_top_2[record["timestamp"]] return record df = df.apply(create_new_columns, axis=1) @@ -151,7 +158,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): # Save df, mappings for further experimentation df_name = "df" if real_time: - df_name = "real_time_" + df_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + df_name = "real_time_" + df_name + "_" +\ + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" else: df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" df.to_pickle(df_name) @@ -161,7 +169,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): mappings_name = "mappings" if real_time: - mappings_name = "real_time_" + mappings_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + mappings_name = "real_time_" + mappings_name + "_" +\ + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" else: mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" pickle.dump(my_mappings, open(mappings_name, "wb")) @@ -197,6 +206,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): transform=st.Scalers.dense_rank ) if real_time: - open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) + open('./artefacts/real_time_scatter_' + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) else: - open('./artefacts/scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) + open('./artefacts/scatter_' + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) diff --git a/whisjax.py b/whisjax.py index 8946953f..9e8ce4cf 100644 --- a/whisjax.py +++ b/whisjax.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# summarize https://www.youtube.com/watch?v=imzTxoEDH_g --transcript=transcript.txt summary.txt +# summarize https://www.youtube.com/watch?v=imzTxoEDH_g # summarize https://www.sprocket.org/video/cheesemaking.mp4 summary.txt # summarize podcast.mp3 summary.txt @@ -14,7 +14,6 @@ from urllib.parse import urlparse import jax.numpy as jnp import moviepy.editor -import moviepy.editor import nltk import yt_dlp as youtube_dl from whisper_jax import FlaxWhisperPipline @@ -39,11 +38,16 @@ def init_argparse() -> argparse.ArgumentParser: """ parser = argparse.ArgumentParser( usage="%(prog)s [OPTIONS] ", - description="Creates a transcript of a video or audio file, then summarizes it using ChatGPT." + description="Creates a transcript of a video or audio file, then" + " summarizes it using ChatGPT." 
     )

-    parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
-                        default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
+    parser.add_argument("-l", "--language",
+                        help="Language that the summary should be written in",
+                        type=str,
+                        default="english",
+                        choices=['english', 'spanish', 'french', 'german',
+                                 'romanian'])
     parser.add_argument("location")

     return parser
@@ -61,10 +65,12 @@ def main():
     media_file = ""
     if url.scheme == 'http' or url.scheme == 'https':
-        # Check if we're being asked to retreive a YouTube URL, which is handled
-        # diffrently, as we'll use a secondary site to download the video first.
+        # Check if we're being asked to retrieve a YouTube URL, which is
+        # handled differently, as we'll use a secondary site to download
+        # the video first.
         if re.search('youtube.com', url.netloc, re.IGNORECASE):
-            # Download the lowest resolution YouTube video (since we're just interested in the audio).
+            # Download the lowest resolution YouTube video
+            # (since we're just interested in the audio).
             # It will be saved to the current directory.

             logger.info("Downloading YouTube video at url: " + args.location)
@@ -76,7 +82,7 @@ def main():
                     'preferredcodec': 'mp3',
                     'preferredquality': '192',
                 }],
-                'outtmpl': 'audio',  # Specify the output file path and name
+                'outtmpl': 'audio',  # Specify output file path and name
             }

             # Download the audio
@@ -86,7 +92,8 @@ def main():
             logger.info("Saved downloaded YouTube video to: " + media_file)
         else:
-            # XXX - Download file using urllib, check if file is audio/video using python-magic
+            # XXX - Download file using urllib, check if file is
+            # audio/video using python-magic
             logger.info(f"Downloading file at url: {args.location}")
             logger.info(" XXX - This method hasn't been implemented yet.")
     elif url.scheme == '':
@@ -97,7 +104,7 @@ def main():
         if media_file.endswith(".m4a"):
             subprocess.run(["ffmpeg", "-i", media_file, f"{media_file}.mp4"])
-            input_file = f"{media_file}.mp4"
+            media_file = f"{media_file}.mp4"
     else:
         print("Unsupported URL scheme: " + url.scheme)
         quit()
@@ -106,13 +113,15 @@ def main():
     if not media_file.endswith(".mp3"):
         try:
             video = moviepy.editor.VideoFileClip(media_file)
-            audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
+            audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3",
+                                                         delete=False).name
             video.audio.write_audiofile(audio_filename, logger=None)
             logger.info(f"Extracting audio to: {audio_filename}")
         # Handle audio only file
-        except:
+        except Exception:
             audio = moviepy.editor.AudioFileClip(media_file)
-            audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
+            audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3",
+                                                         delete=False).name
             audio.write_audiofile(audio_filename, logger=None)
     else:
         audio_filename = media_file
@@ -132,10 +141,12 @@ def main():
     for chunk in whisper_result["chunks"]:
         transcript_text += chunk["text"]

-    with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
+    with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") +
+              ".txt", "w") as transcript_file:
         transcript_file.write(transcript_text)

-    with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
+    with open("./artefacts/transcript_with_timestamp_" +
+              NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
               "w") as transcript_file_timestamps:
         transcript_file_timestamps.write(str(whisper_result))

diff --git a/whisjax_realtime.py
b/whisjax_realtime.py index 68dc472a..63eab04d 100644 --- a/whisjax_realtime.py +++ b/whisjax_realtime.py @@ -30,7 +30,8 @@ def main(): p = pyaudio.PyAudio() AUDIO_DEVICE_ID = -1 for i in range(p.get_device_count()): - if p.get_device_info_by_index(i)["name"] == config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]: + if p.get_device_info_by_index(i)["name"] == \ + config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]: AUDIO_DEVICE_ID = i audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID) stream = p.open( @@ -42,7 +43,8 @@ def main(): input_device_index=int(audio_devices['index']) ) - pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"], + pipeline = FlaxWhisperPipline("openai/whisper-" + + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"], dtype=jnp.float16, batch_size=16) @@ -69,7 +71,8 @@ def main(): frames = [] start_time = time.time() for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)): - data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False) + data = stream.read(FRAMES_PER_BUFFER, + exception_on_overflow=False) frames.append(data) end_time = time.time() @@ -87,7 +90,8 @@ def main(): if end is None: end = start + 15.0 duration = end - start - item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration), + item = {'timestamp': (last_transcribed_time, + last_transcribed_time + duration), 'text': whisper_result['text'], 'stats': (str(end_time - start_time), str(duration)) } @@ -97,15 +101,19 @@ def main(): print(colored("", "yellow")) print(colored(whisper_result['text'], 'green')) - print(colored(" Recorded duration: " + str(end_time - start_time) + " | Transcribed duration: " + + print(colored(" Recorded duration: " + + str(end_time - start_time) + + " | Transcribed duration: " + str(duration), "yellow")) except Exception as e: print(e) finally: - with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f: + with open("real_time_transcript_" + + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f: f.write(transcription) - with open("real_time_transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f: + with open("real_time_transcript_with_timestamp_" + + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f: transcript_with_timestamp["text"] = transcription f.write(str(transcript_with_timestamp)) From 71eb277fd7ab43b833b7c76a4b5d94f1a3ed1e30 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Tue, 11 Jul 2023 18:47:21 +0530 Subject: [PATCH 8/9] refactor --- .gitignore | 4 +-- config.ini | 22 --------------- requirements.txt | 1 - scripts/clear_artefacts.sh | 15 ++++++++-- server_multithreaded.py | 9 +++--- utils/file_utils.py | 4 +-- utils/run_utils.py | 14 +++++----- utils/test.py | 0 utils/text_utilities.py | 6 ++-- utils/viz_utilities.py | 57 +++++++++++++++++++------------------- whisjax.py | 15 +++++----- 11 files changed, 67 insertions(+), 80 deletions(-) delete mode 100644 config.ini create mode 100644 utils/test.py diff --git a/.gitignore b/.gitignore index fd3e8b20..c08eb9a3 100644 --- a/.gitignore +++ b/.gitignore @@ -160,9 +160,6 @@ cython_debug/ #.idea/ *.mp4 -summary.txt -transcript.txt -transcript_timestamps.txt *.html *.pkl transcript_*.txt @@ -176,3 +173,4 @@ test_samples/ .DS_Store/ .DS_Store .vscode/ +artefacts/ diff --git a/config.ini b/config.ini deleted file mode 100644 index 0092129f..00000000 --- a/config.ini +++ /dev/null @@ -1,22 +0,0 @@ -[DEFAULT] -# Set exception rule for OpenMP error to 
allow duplicate lib initialization -KMP_DUPLICATE_LIB_OK = TRUE -# Export OpenAI API Key -OPENAI_APIKEY = -# Export Whisper Model Size -WHISPER_MODEL_SIZE = tiny -WHISPER_REAL_TIME_MODEL_SIZE = tiny -# AWS config -AWS_ACCESS_KEY = ***REMOVED*** -AWS_SECRET_KEY = ***REMOVED*** -BUCKET_NAME = 'reflector-bucket' -# Summarizer config -SUMMARY_MODEL = facebook/bart-large-cnn -INPUT_ENCODING_MAX_LENGTH = 1024 -MAX_LENGTH = 2048 -BEAM_SIZE = 6 -MAX_CHUNK_LENGTH = 1024 -SUMMARIZE_USING_CHUNKS = YES -# Audio device -BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME = aggregator -AV_FOUNDATION_DEVICE_ID = 2 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 23b8e38d..21fdd61a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -56,5 +56,4 @@ cached_property==1.5.2 stamina==23.1.0 httpx==0.24.1 sortedcontainers==2.4.0 -openai-whisper@ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8 https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz diff --git a/scripts/clear_artefacts.sh b/scripts/clear_artefacts.sh index c06c4c2c..d6e7722f 100755 --- a/scripts/clear_artefacts.sh +++ b/scripts/clear_artefacts.sh @@ -1,15 +1,24 @@ #!/bin/bash # Directory to search for Python files -directory="." +cwd=$(pwd) +last_component="${cwd##*/}" + +if [ "$last_component" = "reflector" ]; then + directory="./artefacts" +elif [ "$last_component" = "scripts" ]; then + directory="../artefacts" +fi # Pattern to match Python files (e.g., "*.py" for all .py files) -text_file_pattern="transcript_*.txt" +transcript_file_pattern="transcript_*.txt" +summary_file_pattern="summary_*.txt" pickle_file_pattern="*.pkl" html_file_pattern="*.html" png_file_pattern="wordcloud*.png" -find "$directory" -type f -name "$text_file_pattern" -delete +find "$directory" -type f -name "$transcript_file_pattern" -delete +find "$directory" -type f -name "$summary_file_pattern" -delete find "$directory" -type f -name "$pickle_file_pattern" -delete find "$directory" -type f -name "$html_file_pattern" -delete find "$directory" -type f -name "$png_file_pattern" -delete diff --git a/server_multithreaded.py b/server_multithreaded.py index 7382a654..2862fa36 100644 --- a/server_multithreaded.py +++ b/server_multithreaded.py @@ -65,6 +65,7 @@ def get_transcription(): transcribe = True if transcribe: + print("Transcribing..") try: sorted_message_queue[frames[0].time] = None out_file = io.BytesIO() @@ -113,7 +114,7 @@ def start_messaging_thread(): def start_transcription_thread(max_threads: int): for i in range(max_threads): - t_thread = threading.Thread(target=get_transcription, args=(i,)) + t_thread = threading.Thread(target=get_transcription) t_thread.start() @@ -128,7 +129,7 @@ async def offer(request: requests.Request): def log_info(msg: str, *args): logger.info(pc_id + " " + msg, *args) - log_info("Created for %s", request.remote) + log_info("Created for " + request.remote) @pc.on("datachannel") def on_datachannel(channel): @@ -146,14 +147,14 @@ async def offer(request: requests.Request): @pc.on("connectionstatechange") async def on_connectionstatechange(): - log_info("Connection state is %s", pc.connectionState) + log_info("Connection state is " + pc.connectionState) if pc.connectionState == "failed": await pc.close() pcs.discard(pc) @pc.on("track") def on_track(track): - log_info("Track %s received", track.kind) + log_info("Track " + track.kind + " received") pc.addTrack(AudioStreamTrack(relay.subscribe(track))) # handle offer diff --git a/utils/file_utils.py b/utils/file_utils.py index 
2c14f00f..cc9a9ded 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -3,8 +3,8 @@ import sys import boto3 import botocore -from log_utils import logger -from run_utils import config +from .log_utils import logger +from .run_utils import config BUCKET_NAME = config["DEFAULT"]["BUCKET_NAME"] diff --git a/utils/run_utils.py b/utils/run_utils.py index dca09c87..bb2b6348 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -6,18 +6,18 @@ from threading import Lock from typing import ContextManager, Generic, TypeVar -class ConfigParser: - __config = configparser.ConfigParser() - - def __init__(self, config_file='../config.ini'): - self.__config.read(config_file) +class ReflectorConfig: + __config = None @staticmethod def get_config(): - return ConfigParser.__config + if ReflectorConfig.__config is None: + ReflectorConfig.__config = configparser.ConfigParser() + ReflectorConfig.__config.read('utils/config.ini') + return ReflectorConfig.__config -config = ConfigParser.get_config() +config = ReflectorConfig.get_config() def run_in_executor(func, *args, executor=None, **kwargs): diff --git a/utils/test.py b/utils/test.py new file mode 100644 index 00000000..e69de29b diff --git a/utils/text_utilities.py b/utils/text_utilities.py index 519990cb..ef15c7a3 100644 --- a/utils/text_utilities.py +++ b/utils/text_utilities.py @@ -6,8 +6,8 @@ from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from transformers import BartForConditionalGeneration, BartTokenizer -from log_utils import logger -from run_utils import config +from utils.log_utils import logger +from utils.run_utils import config nltk.download('punkt', quiet=True) @@ -186,7 +186,7 @@ def summarize(transcript_text, timestamp, decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for summary in summaries] summary = " ".join(decoded_summaries) - with open(output_filename, 'w') as f: + with open("./artefacts/" + output_filename, 'w') as f: f.write(summary.strip() + "\n") else: logger.info("Breaking transcript into smaller chunks") diff --git a/utils/viz_utilities.py b/utils/viz_utilities.py index e1ab88c9..93a9b56f 100644 --- a/utils/viz_utilities.py +++ b/utils/viz_utilities.py @@ -52,7 +52,7 @@ def create_wordcloud(timestamp, real_time=False): else: wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" - plt.savefig(wordcloud_name) + plt.savefig("./artefacts/" + wordcloud_name) def create_talk_diff_scatter_viz(timestamp, real_time=False): @@ -77,10 +77,10 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): # Load the transcription with timestamp filename = "" if real_time: - filename = "real_time_transcript_with_timestamp_" +\ + filename = "./artefacts/real_time_transcript_with_timestamp_" +\ timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" else: - filename = "transcript_with_timestamp_" +\ + filename = "./artefacts/transcript_with_timestamp_" +\ timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" with open(filename) as f: transcription_timestamp_text = f.read() @@ -162,7 +162,7 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" else: df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" - df.to_pickle(df_name) + df.to_pickle("./artefacts/" + df_name) my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2, topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2] @@ -173,7 +173,7 @@ def 
create_talk_diff_scatter_viz(timestamp, real_time=False): timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" else: mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" - pickle.dump(my_mappings, open(mappings_name, "wb")) + pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb")) # to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") ) @@ -187,27 +187,28 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True) - cat_1 = topic_times[0][0] - cat_1_name = topic_times[0][0] - cat_2_name = topic_times[1][0] + if len(topic_times) > 1: + cat_1 = topic_times[0][0] + cat_1_name = topic_times[0][0] + cat_2_name = topic_times[1][0] - # Scatter plot of topics - df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences)) - corpus = st.CorpusFromParsedDocuments( - df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse' - ).build().get_unigram_corpus().compact(st.AssociationCompactor(2000)) - html = st.produce_scattertext_explorer( - corpus, - category=cat_1, - category_name=cat_1_name, - not_category_name=cat_2_name, - minimum_term_frequency=0, pmi_threshold_coefficient=0, - width_in_pixels=1000, - transform=st.Scalers.dense_rank - ) - if real_time: - open('./artefacts/real_time_scatter_' + - timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) - else: - open('./artefacts/scatter_' + - timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) + # Scatter plot of topics + df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences)) + corpus = st.CorpusFromParsedDocuments( + df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse' + ).build().get_unigram_corpus().compact(st.AssociationCompactor(2000)) + html = st.produce_scattertext_explorer( + corpus, + category=cat_1, + category_name=cat_1_name, + not_category_name=cat_2_name, + minimum_term_frequency=0, pmi_threshold_coefficient=0, + width_in_pixels=1000, + transform=st.Scalers.dense_rank + ) + if real_time: + open('./artefacts/real_time_scatter_' + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) + else: + open('./artefacts/scatter_' + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) diff --git a/whisjax.py b/whisjax.py index 9e8ce4cf..53e16cd3 100644 --- a/whisjax.py +++ b/whisjax.py @@ -127,7 +127,7 @@ def main(): audio_filename = media_file logger.info("Finished extracting audio") - + logger.info("Transcribing") # Convert the audio to text using the OpenAI Whisper model pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE, dtype=jnp.float16, @@ -157,13 +157,14 @@ def main(): create_talk_diff_scatter_viz(NOW) # S3 : Push artefacts to S3 bucket + prefix = "./artefacts/" suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S") - files_to_upload = ["transcript_" + suffix + ".txt", - "transcript_with_timestamp_" + suffix + ".txt", - "df_" + suffix + ".pkl", - "wordcloud_" + suffix + ".png", - "mappings_" + suffix + ".pkl", - "scatter_" + suffix + ".html"] + files_to_upload = [prefix + "transcript_" + suffix + ".txt", + prefix + "transcript_with_timestamp_" + suffix + ".txt", + prefix + "df_" + suffix + ".pkl", + prefix + "wordcloud_" + suffix + ".png", + prefix + "mappings_" + suffix + ".pkl", + prefix + "scatter_" + suffix + ".html"] upload_files(files_to_upload) summarize(transcript_text, NOW, False, False) From 34f2795fa9ca211fd3aa818c666cc24af8fb82ac Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan 
Date: Tue, 11 Jul 2023 18:51:17 +0530 Subject: [PATCH 9/9] update log info --- server_executor_cleaned.py | 7 ++++--- utils/test.py | 0 2 files changed, 4 insertions(+), 3 deletions(-) delete mode 100644 utils/test.py diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index e0fb4cc3..83015983 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -39,6 +39,7 @@ def channel_send(channel, message): def get_transcription(frames): + print("Transcribing..") out_file = io.BytesIO() wf = wave.open(out_file, "wb") wf.setnchannels(CHANNELS) @@ -93,7 +94,7 @@ async def offer(request): def log_info(msg, *args): logger.info(pc_id + " " + msg, *args) - log_info("Created for %s", request.remote) + log_info("Created for " + request.remote) @pc.on("datachannel") def on_datachannel(channel): @@ -110,14 +111,14 @@ async def offer(request): @pc.on("connectionstatechange") async def on_connectionstatechange(): - log_info("Connection state is %s", pc.connectionState) + log_info("Connection state is " + pc.connectionState) if pc.connectionState == "failed": await pc.close() pcs.discard(pc) @pc.on("track") def on_track(track): - log_info("Track %s received", track.kind) + log_info("Track " + track.kind + " received") pc.addTrack(AudioStreamTrack(relay.subscribe(track))) await pc.setRemoteDescription(offer) diff --git a/utils/test.py b/utils/test.py deleted file mode 100644 index e69de29b..00000000
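The series ends with server_executor_cleaned.py settled on a simple hand-off pattern: transcription runs on a thread pool through run_in_executor() from utils/run_utils.py, and a done-callback relays the result over the WebRTC data channel. A minimal sketch of that pattern follows; the callback is lightly simplified (it reads f.result() rather than closing over the future by name), and executor, data_channel, get_transcription, and channel_send are assumed as defined in the module:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    executor = ThreadPoolExecutor()

    def run_in_executor(func, *args, executor=None, **kwargs):
        # partial() carries keyword arguments across the hop, since
        # loop.run_in_executor() itself only forwards positionals.
        callback = partial(func, *args, **kwargs)
        loop = asyncio.get_event_loop()
        return loop.run_in_executor(executor, callback)

    # Inside AudioStreamTrack.recv(), frames are handed off so Whisper
    # never blocks the media loop, and the done-callback relays the text
    # over the data channel:
    #
    #     future = run_in_executor(get_transcription, local_frames,
    #                              executor=executor)
    #     future.add_done_callback(
    #         lambda f: channel_send(data_channel, str(f.result()))
    #         if f.result() else None)

Running Whisper off the event loop keeps recv() responsive under load; the partial() wrapper is what lets keyword arguments reach the pooled function at all.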