diff --git a/client.py b/client.py
index b2167d3b..519ccc26 100644
--- a/client.py
+++ b/client.py
@@ -5,11 +5,16 @@ import signal
 from aiortc.contrib.signaling import (add_signaling_arguments,
                                       create_signaling)
-from utils.log_utils import logger
+from utils.log_utils import LOGGER
 from stream_client import StreamClient
+from typing import NoReturn
 
-
-async def main():
+async def main() -> NoReturn:
+    """
+    Reflector's entry point for the Python client used for WebRTC streaming
+    when not running the browser-based UI application. Runs until the
+    process receives a termination signal.
+    """
     parser = argparse.ArgumentParser(description="Data channels ping/pong")
     parser.add_argument(
@@ -37,17 +42,17 @@ async def main():
 
     async def shutdown(signal, loop):
         """Cleanup tasks tied to the service's shutdown."""
-        logger.info(f"Received exit signal {signal.name}...")
-        logger.info("Closing database connections")
-        logger.info("Nacking outstanding messages")
+        LOGGER.info(f"Received exit signal {signal.name}...")
+        LOGGER.info("Closing database connections")
+        LOGGER.info("Nacking outstanding messages")
 
         tasks = [t for t in asyncio.all_tasks() if t is not
                  asyncio.current_task()]
 
         [task.cancel() for task in tasks]
 
-        logger.info(f"Cancelling {len(tasks)} outstanding tasks")
+        LOGGER.info(f"Cancelling {len(tasks)} outstanding tasks")
         await asyncio.gather(*tasks, return_exceptions=True)
-        logger.info(f'{"Flushing metrics"}')
+        LOGGER.info("Flushing metrics")
         loop.stop()
 
     signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
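
[Editor's note: the hunk above ends right after `signals` is defined, so the handler registration is outside this diff's context window. As background only, a minimal sketch of how such signals are typically wired to the `shutdown()` coroutine with asyncio — the exact wiring in client.py may differ, and `loop.add_signal_handler` is Unix-only:

    import asyncio
    import signal

    loop = asyncio.get_event_loop()
    signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
    for sig in signals:
        # Bind sig per iteration; each handler schedules the shared
        # shutdown() coroutine defined in the hunk above.
        loop.add_signal_handler(
            sig, lambda sig=sig: asyncio.create_task(shutdown(sig, loop)))
]
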
diff --git a/reflector_dataclasses.py b/reflector_dataclasses.py
index daac519c..e05396ae 100644
--- a/reflector_dataclasses.py
+++ b/reflector_dataclasses.py
@@ -1,3 +1,8 @@
+"""
+Collection of data classes for streamlining and rigidly structuring
+the input and output parameters of functions
+"""
+
 import datetime
 from dataclasses import dataclass
 from typing import List
@@ -7,6 +12,10 @@ import av
 
 @dataclass
 class TitleSummaryInput:
+    """
+    Data class for the input used to generate titles and summaries.
+    The resulting prompt is sent as the query to the LLM for processing.
+    """
     input_text = str
     transcribed_time = float
     prompt = str
@@ -15,23 +24,28 @@
     def __init__(self, transcribed_time, input_text=""):
         self.input_text = input_text
         self.transcribed_time = transcribed_time
-        self.prompt = f"""
-        ### Human:
-        Create a JSON object as response. The JSON object must have 2 fields:
-        i) title and ii) summary. For the title field,generate a short title
-        for the given text. For the summary field, summarize the given text
-        in three sentences.
-
-        {self.input_text}
-
-        ### Assistant:
-        """
+        self.prompt = \
+            f"""
+            ### Human:
+            Create a JSON object as response. The JSON object must have 2 fields:
+            i) title and ii) summary. For the title field, generate a short title
+            for the given text. For the summary field, summarize the given text
+            in three sentences.
+
+            {self.input_text}
+
+            ### Assistant:
+            """
         self.data = {"data": self.prompt}
         self.headers = {"Content-Type": "application/json"}
 
 
 @dataclass
 class IncrementalResult:
+    """
+    Data class for the result of generating one title and summary.
+    Defines what a single "topic" looks like.
+    """
     title = str
     description = str
     transcript = str
@@ -44,6 +58,10 @@
 
 @dataclass
 class TitleSummaryOutput:
+    """
+    Data class for the result of all generated titles and summaries.
+    The result will be sent back to the client.
+    """
     cmd = str
     topics = List[IncrementalResult]
@@ -59,6 +77,10 @@
 
 @dataclass
 class ParseLLMResult:
+    """
+    Data class to parse the result returned by the LLM while generating
+    titles and summaries. The result will be sent back to the client.
+    """
     description = str
     transcript = str
     timestamp = str
@@ -66,7 +88,8 @@
     def __init__(self, param: TitleSummaryInput, output: dict):
         self.transcript = param.input_text
         self.description = output.pop("summary")
-        self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))
+        self.timestamp = \
+            str(datetime.timedelta(seconds=round(param.transcribed_time)))
 
     def get_result(self):
         return {
@@ -78,6 +101,10 @@
 
 @dataclass
 class TranscriptionInput:
+    """
+    Data class to define the input to the transcription function:
+    a list of AudioFrames.
+    """
     frames = List[av.audio.frame.AudioFrame]
 
     def __init__(self, frames):
@@ -86,6 +113,10 @@
 
 @dataclass
 class TranscriptionOutput:
+    """
+    Data class to define the result of the transcription function.
+    The result will be sent back to the client.
+    """
     cmd = str
     result_text = str
@@ -102,6 +133,10 @@
 
 @dataclass
 class FinalSummaryResult:
+    """
+    Data class to define the result of the final summary function.
+    The result will be sent back to the client.
+    """
     cmd = str
     final_summary = str
     duration = str
@@ -117,3 +152,13 @@
             "duration": self.duration,
             "summary": self.final_summary
         }
+
+
+class BlackListedMessages:
+    """
+    Class to hold the blacklisted messages. These messages should be filtered
+    out and not sent back to the client as part of the transcription.
+    """
+    messages = [" Thank you.", " See you next time!",
+                " Thank you for watching!", " Bye!",
+                " And that's what I'm talking about."]
diff --git a/server.py b/server.py
index 2ba07229..f45148cd 100644
--- a/server.py
+++ b/server.py
@@ -6,7 +6,7 @@ import os
 import uuid
 import wave
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, NoReturn
+from typing import Union, NoReturn
 
 import aiohttp_cors
 import av
@@ -15,13 +15,13 @@
 from aiohttp import web
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
 from aiortc.contrib.media import MediaRelay
 from faster_whisper import WhisperModel
-from loguru import logger
 from sortedcontainers import SortedDict
 
 from reflector_dataclasses import FinalSummaryResult, ParseLLMResult,\
     TitleSummaryInput, TitleSummaryOutput, TranscriptionInput,\
-    TranscriptionOutput
-from utils.run_utils import config, run_in_executor
+    TranscriptionOutput, BlackListedMessages
+from utils.run_utils import CONFIG, run_in_executor
+from utils.log_utils import LOGGER
 
 pcs = set()
 relay = MediaRelay()
@@ -36,24 +36,24 @@
 audio_buffer = av.AudioFifo()
 executor = ThreadPoolExecutor()
 transcription_text = ""
 last_transcribed_time = 0.0
-LLM_MACHINE_IP = config["LLM"]["LLM_MACHINE_IP"]
-LLM_MACHINE_PORT = config["LLM"]["LLM_MACHINE_PORT"]
+LLM_MACHINE_IP = CONFIG["LLM"]["LLM_MACHINE_IP"]
+LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
 LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
 incremental_responses = []
 sorted_transcripts = SortedDict()
 
 
-def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Any[None, ParseLLMResult]:
+def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
     try:
         output = json.loads(response.json()["results"][0]["text"])
         return ParseLLMResult(param, output)
     except Exception as e:
-        logger.info("Exception" + str(e))
+        LOGGER.info("Exception: " + str(e))
         return None
 
 
-def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOutput]:
-    logger.info("Generating title and summary")
+def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
+    LOGGER.info("Generating title and summary")
 
     # TODO : Handle unexpected output formats from the model
     try:
@@ -66,12 +66,12 @@ def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOut
         incremental_responses.append(result)
         return TitleSummaryOutput(incremental_responses)
     except Exception as e:
-        logger.info("Exception" + str(e))
+        LOGGER.info("Exception: " + str(e))
         return None
 
 
 def channel_log(channel, t: str, message: str) -> NoReturn:
-    logger.info("channel(%s) %s %s" % (channel.label, t, message))
+    LOGGER.info("channel(%s) %s %s" % (channel.label, t, message))
 
 
 def channel_send(channel, message: str) -> NoReturn:
@@ -79,7 +79,7 @@
     channel.send(message)
 
 
-def channel_send_increment(channel, param: Any[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn:
+def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn:
     if channel and param:
         message = param.get_result()
         channel.send(json.dumps(message))
@@ -89,11 +89,11 @@
 def channel_send_transcript(channel) -> NoReturn:
     # channel_log(channel, ">", message)
     if channel:
         try:
-            least_time = sorted_transcripts.keys()[0]
+            least_time = next(iter(sorted_transcripts))
             message = sorted_transcripts[least_time].get_result()
             if message:
                 del sorted_transcripts[least_time]
-                if message["text"] not in blacklisted_messages:
+                if message["text"] not in BlackListedMessages.messages:
                     channel.send(json.dumps(message))
             # Due to exceptions if one of the earlier batches can't return
             # a transcript, we don't want to be stuck waiting for the result
@@ -101,22 +101,21 @@
             else:
                 if len(sorted_transcripts) >= 3:
                     del sorted_transcripts[least_time]
-        except Exception as e:
-            logger.info("Exception", str(e))
-            pass
+        except Exception as exception:
+            LOGGER.info("Exception: " + str(exception))
 
 
-def get_transcription(input_frames: TranscriptionInput) -> Any[None, TranscriptionOutput]:
-    logger.info("Transcribing..")
-    sorted_transcripts[input_frames[0].time] = None
+def get_transcription(input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
+    LOGGER.info("Transcribing...")
+    sorted_transcripts[input_frames.frames[0].time] = None
 
     # TODO: Find cleaner way, watch "no transcription" issue below
     # Passing IO objects instead of temporary files throws an error
-    # Passing ndarrays (typecasted with float) does not give any
+    # Passing ndarray (type-cast to float) does not give any
     # transcription. Refer issue,
     # https://github.com/guillaumekln/faster-whisper/issues/369
-    audiofilename = "test" + str(datetime.datetime.now())
-    wf = wave.open(audiofilename, "wb")
+    audio_file = "test" + str(datetime.datetime.now())
+    wf = wave.open(audio_file, "wb")
     wf.setnchannels(CHANNELS)
     wf.setframerate(RATE)
     wf.setsampwidth(2)
@@ -129,12 +128,12 @@ def get_transcription(input_frames: TranscriptionInput) -> Union[None, Transcri
 
     try:
         segments, _ = \
-            model.transcribe(audiofilename,
+            model.transcribe(audio_file,
                              language="en",
                              beam_size=5,
                              vad_filter=True,
-                             vad_parameters=dict(min_silence_duration_ms=500))
-        os.remove(audiofilename)
+                             vad_parameters={"min_silence_duration_ms": 500})
+        os.remove(audio_file)
         segments = list(segments)
         result_text = ""
         duration = 0.0
@@ -152,9 +151,8 @@ def get_transcription(input_frames: TranscriptionInput) -> Union[None, Transcri
         last_transcribed_time += duration
         transcription_text += result_text
 
-    except Exception as e:
-        logger.info("Exception" + str(e))
-        pass
+    except Exception as exception:
+        LOGGER.info("Exception: " + str(exception))
 
     result = TranscriptionOutput(result_text)
     sorted_transcripts[input_frames.frames[0].time] = result
@@ -162,6 +160,11 @@
 
 
 def get_final_summary_response() -> FinalSummaryResult:
+    """
+    Collate the incremental summaries generated so far and return them as
+    the final summary.
+    :return: FinalSummaryResult holding the collated summary
+    """
     final_summary = ""
 
     # Collate inc summaries
@@ -170,8 +173,9 @@
 
     response = FinalSummaryResult(final_summary, last_transcribed_time)
 
-    with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f:
-        f.write(json.dumps(incremental_responses))
+    with open("./artefacts/meeting_titles_and_summaries.txt", "a",
+              encoding="utf-8") as file:
+        file.write(json.dumps(incremental_responses))
 
     return response
@@ -222,6 +226,11 @@
 
 
 async def offer(request: requests.Request) -> web.Response:
+    """
+    Establish the WebRTC connection with the client
+    :param request: HTTP request carrying the client's SDP offer
+    :return: JSON response carrying the local SDP answer
+    """
     params = await request.json()
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
@@ -230,7 +239,7 @@
     pcs.add(pc)
 
     def log_info(msg, *args) -> NoReturn:
-        logger.info(pc_id + " " + msg, *args)
+        LOGGER.info(pc_id + " " + msg, *args)
 
     log_info("Created for " + request.remote)
@@ -272,15 +281,17 @@
     return web.Response(
         content_type="application/json",
         text=json.dumps(
-            {"sdp": pc.localDescription.sdp,
-             "type": pc.localDescription.type}
+            {
+                "sdp": pc.localDescription.sdp,
+                "type": pc.localDescription.type
+            }
         ),
     )
 
 
-async def on_shutdown(app) -> NoReturn:
-    coros = [pc.close() for pc in pcs]
-    await asyncio.gather(*coros)
+async def on_shutdown(application: web.Application) -> None:
+    coroutines = [pc.close() for pc in pcs]
+    await asyncio.gather(*coroutines)
     pcs.clear()
diff --git a/stream_client.py b/stream_client.py
index ae22c6db..57f88229 100644
--- a/stream_client.py
+++ b/stream_client.py
@@ -9,8 +9,8 @@ import stamina
 from aiortc import (RTCPeerConnection,
                     RTCSessionDescription)
 from aiortc.contrib.media import (MediaPlayer,
                                   MediaRelay)
-from utils.log_utils import logger
-from utils.run_utils import config
+from utils.log_utils import LOGGER
+from utils.run_utils import CONFIG
 
 
 class StreamClient:
@@ -35,7 +35,7 @@ class StreamClient:
         self.time_start = None
self.queue = asyncio.Queue() self.player = MediaPlayer( - ':' + str(config['AUDIO']["AV_FOUNDATION_DEVICE_ID"]), + ':' + str(CONFIG['AUDIO']["AV_FOUNDATION_DEVICE_ID"]), format='avfoundation', options={'channels': '2'}) @@ -74,7 +74,7 @@ class StreamClient: self.pcs.add(pc) def log_info(msg, *args): - logger.info(pc_id + " " + msg, *args) + LOGGER.info(pc_id + " " + msg, *args) @pc.on("connectionstatechange") async def on_connectionstatechange(): diff --git a/trials/finetuning/youtube_scraping.py b/trials/finetuning/youtube_scraping.py index b0892f47..be8b7e41 100644 --- a/trials/finetuning/youtube_scraping.py +++ b/trials/finetuning/youtube_scraping.py @@ -93,6 +93,6 @@ def generate_finetuning_dataset(video_ids): video_ids = ["yTnSEZIwnkU"] dataset = generate_finetuning_dataset(video_ids) -with open("finetuning_dataset.jsonl", "w") as f: +with open("finetuning_dataset.jsonl", "w", encoding="utf-8") as file: for example in dataset: - f.write(json.dumps(example) + "\n") + file.write(json.dumps(example) + "\n") diff --git a/trials/server/server_multithreaded.py b/trials/server/server_multithreaded.py index 4f7688a0..6739fbf6 100644 --- a/trials/server/server_multithreaded.py +++ b/trials/server/server_multithreaded.py @@ -16,10 +16,10 @@ from av import AudioFifo from sortedcontainers import SortedDict from whisper_jax import FlaxWhisperPipline -from reflector.utils.log_utils import logger -from reflector.utils.run_utils import config, Mutex +from reflector.utils.log_utils import LOGGER +from reflector.utils.run_utils import CONFIG, Mutex -WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_REAL_TIME_MODEL_SIZE"] +WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_REAL_TIME_MODEL_SIZE"] pcs = set() relay = MediaRelay() data_channel = None @@ -127,7 +127,7 @@ async def offer(request: requests.Request): pcs.add(pc) def log_info(msg: str, *args): - logger.info(pc_id + " " + msg, *args) + LOGGER.info(pc_id + " " + msg, *args) log_info("Created for " + request.remote) diff --git a/trials/title_summary/incsum.py b/trials/title_summary/incsum.py index 5081d16c..571af77f 100644 --- a/trials/title_summary/incsum.py +++ b/trials/title_summary/incsum.py @@ -3,14 +3,14 @@ import sys # Observe the incremental summaries by performing summaries in chunks -with open("transcript.txt") as f: - transcription = f.read() +with open("transcript.txt", "r", encoding="utf-8") as file: + transcription = file.read() def split_text_file(filename, token_count): nlp = spacy.load('en_core_web_md') - with open(filename, 'r') as file: + with open(filename, 'r', encoding="utf-8") as file: text = file.read() doc = nlp(text) @@ -36,9 +36,9 @@ chunks = split_text_file("transcript.txt", MAX_CHUNK_LENGTH) print("Number of chunks", len(chunks)) # Write chunks to file to refer to input vs output, separated by blank lines -with open("chunks" + str(MAX_CHUNK_LENGTH) + ".txt", "a") as f: +with open("chunks" + str(MAX_CHUNK_LENGTH) + ".txt", "a", encoding="utf-8") as file: for c in chunks: - f.write(c + "\n\n") + file.write(c + "\n\n") # If we want to run only a certain model, type the option while running # ex. 
python incsum.py 1 => will run approach 1 @@ -78,9 +78,9 @@ if index == "1" or index is None: summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) summaries.append(summary) - with open("bart-summaries.txt", "a") as f: + with open("bart-summaries.txt", "a", encoding="utf-8") as file: for summary in summaries: - f.write(summary + "\n\n") + file.write(summary + "\n\n") # Approach 2 if index == "2" or index is None: @@ -114,8 +114,8 @@ if index == "2" or index is None: summary_ids = output[0, input_length:] summary = tokenizer.decode(summary_ids, skip_special_tokens=True) summaries.append(summary) - with open("gptneo1.3B-summaries.txt", "a") as f: - f.write(summary + "\n\n") + with open("gptneo1.3B-summaries.txt", "a", encoding="utf-8") as file: + file.write(summary + "\n\n") # Approach 3 if index == "3" or index is None: @@ -152,6 +152,6 @@ if index == "3" or index is None: skip_special_tokens=True) summaries.append(summary) - with open("mpt-7b-summaries.txt", "a") as f: + with open("mpt-7b-summaries.txt", "a", encoding="utf-8") as file: for summary in summaries: - f.write(summary + "\n\n") + file.write(summary + "\n\n") diff --git a/trials/whisper-jax/whisjax.py b/trials/whisper-jax/whisjax.py index 2926fce0..fb4f5e1f 100644 --- a/trials/whisper-jax/whisjax.py +++ b/trials/whisper-jax/whisjax.py @@ -19,15 +19,15 @@ import yt_dlp as youtube_dl from whisper_jax import FlaxWhisperPipline from ...utils.file_utils import download_files, upload_files -from ...utils.log_utils import logger -from ...utils.run_utils import config +from ...utils.log_utils import LOGGER +from ...utils.run_utils import CONFIG from ...utils.text_utils import post_process_transcription, summarize from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud nltk.download('punkt', quiet=True) nltk.download('stopwords', quiet=True) -WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_MODEL_SIZE"] +WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_MODEL_SIZE"] NOW = datetime.now() if not os.path.exists('../../artefacts'): @@ -75,7 +75,7 @@ def main(): # Download the lowest resolution YouTube video # (since we're just interested in the audio). # It will be saved to the current directory. 
- logger.info("Downloading YouTube video at url: " + args.location) + LOGGER.info("Downloading YouTube video at url: " + args.location) # Create options for the download ydl_opts = { @@ -93,12 +93,12 @@ def main(): ydl.download([args.location]) media_file = "../artefacts/audio.mp3" - logger.info("Saved downloaded YouTube video to: " + media_file) + LOGGER.info("Saved downloaded YouTube video to: " + media_file) else: # XXX - Download file using urllib, check if file is # audio/video using python-magic - logger.info(f"Downloading file at url: {args.location}") - logger.info(" XXX - This method hasn't been implemented yet.") + LOGGER.info(f"Downloading file at url: {args.location}") + LOGGER.info(" XXX - This method hasn't been implemented yet.") elif url.scheme == '': media_file = url.path # If file is not present locally, take it from S3 bucket @@ -119,7 +119,7 @@ def main(): audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name video.audio.write_audiofile(audio_filename, logger=None) - logger.info(f"Extracting audio to: {audio_filename}") + LOGGER.info(f"Extracting audio to: {audio_filename}") # Handle audio only file except Exception: audio = moviepy.editor.AudioFileClip(media_file) @@ -129,14 +129,14 @@ def main(): else: audio_filename = media_file - logger.info("Finished extracting audio") - logger.info("Transcribing") + LOGGER.info("Finished extracting audio") + LOGGER.info("Transcribing") # Convert the audio to text using the OpenAI Whisper model pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE, dtype=jnp.float16, batch_size=16) whisper_result = pipeline(audio_filename, return_timestamps=True) - logger.info("Finished transcribing file") + LOGGER.info("Finished transcribing file") whisper_result = post_process_transcription(whisper_result) @@ -153,10 +153,10 @@ def main(): "w") as transcript_file_timestamps: transcript_file_timestamps.write(str(whisper_result)) - logger.info("Creating word cloud") + LOGGER.info("Creating word cloud") create_wordcloud(NOW) - logger.info("Performing talk-diff and talk-diff visualization") + LOGGER.info("Performing talk-diff and talk-diff visualization") create_talk_diff_scatter_viz(NOW) # S3 : Push artefacts to S3 bucket @@ -172,7 +172,7 @@ def main(): summarize(transcript_text, NOW, False, False) - logger.info("Summarization completed") + LOGGER.info("Summarization completed") # Summarization takes a lot of time, so do this separately at the end files_to_upload = [prefix + "summary_" + suffix + ".txt"] diff --git a/trials/whisper-jax/whisjax_realtime.py b/trials/whisper-jax/whisjax_realtime.py index 5f269c18..ec822854 100644 --- a/trials/whisper-jax/whisjax_realtime.py +++ b/trials/whisper-jax/whisjax_realtime.py @@ -11,12 +11,12 @@ from termcolor import colored from whisper_jax import FlaxWhisperPipline from ...utils.file_utils import upload_files -from ...utils.log_utils import logger -from ...utils.run_utils import config +from ...utils.log_utils import LOGGER +from ...utils.run_utils import CONFIG from ...utils.text_utils import post_process_transcription, summarize from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud -WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_MODEL_SIZE"] +WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_MODEL_SIZE"] FRAMES_PER_BUFFER = 8000 FORMAT = pyaudio.paInt16 @@ -31,7 +31,7 @@ def main(): AUDIO_DEVICE_ID = -1 for i in range(p.get_device_count()): if p.get_device_info_by_index(i)["name"] == \ - 
config["AUDIO"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]: + CONFIG["AUDIO"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]: AUDIO_DEVICE_ID = i audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID) stream = p.open( @@ -44,7 +44,7 @@ def main(): ) pipeline = FlaxWhisperPipline("openai/whisper-" + - config["WHISPER"]["WHISPER_REAL_TIME_MODEL_SIZE"], + CONFIG["WHISPER"]["WHISPER_REAL_TIME_MODEL_SIZE"], dtype=jnp.float16, batch_size=16) @@ -106,23 +106,26 @@ def main(): " | Transcribed duration: " + str(duration), "yellow")) - except Exception as e: - print(e) + except Exception as exception: + print(str(exception)) finally: - with open("real_time_transcript_" + - NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f: - f.write(transcription) + with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + + ".txt", "w", encoding="utf-8") as file: + file.write(transcription) + with open("real_time_transcript_with_timestamp_" + - NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f: + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w", + encoding="utf-8") as file: transcript_with_timestamp["text"] = transcription - f.write(str(transcript_with_timestamp)) + file.write(str(transcript_with_timestamp)) - transcript_with_timestamp = post_process_transcription(transcript_with_timestamp) + transcript_with_timestamp = \ + post_process_transcription(transcript_with_timestamp) - logger.info("Creating word cloud") + LOGGER.info("Creating word cloud") create_wordcloud(NOW, True) - logger.info("Performing talk-diff and talk-diff visualization") + LOGGER.info("Performing talk-diff and talk-diff visualization") create_talk_diff_scatter_viz(NOW, True) # S3 : Push artefacts to S3 bucket @@ -137,7 +140,7 @@ def main(): summarize(transcript_with_timestamp["text"], NOW, True, True) - logger.info("Summarization completed") + LOGGER.info("Summarization completed") # Summarization takes a lot of time, so do this separately at the end files_to_upload = ["real_time_summary_" + suffix + ".txt"] diff --git a/utils/file_utils.py b/utils/file_utils.py index 596bfe7f..db294c6e 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -1,16 +1,21 @@ +""" +Utility file for file handling related functions, including file downloads and +uploads to cloud storage +""" + import sys import boto3 import botocore -from .log_utils import logger -from .run_utils import config +from .log_utils import LOGGER +from .run_utils import CONFIG -BUCKET_NAME = config["AWS"]["BUCKET_NAME"] +BUCKET_NAME = CONFIG["AWS"]["BUCKET_NAME"] s3 = boto3.client('s3', - aws_access_key_id=config["AWS"]["AWS_ACCESS_KEY"], - aws_secret_access_key=config["AWS"]["AWS_SECRET_KEY"]) + aws_access_key_id=CONFIG["AWS"]["AWS_ACCESS_KEY"], + aws_secret_access_key=CONFIG["AWS"]["AWS_SECRET_KEY"]) def upload_files(files_to_upload): @@ -19,12 +24,12 @@ def upload_files(files_to_upload): :param files_to_upload: List of files to upload :return: None """ - for KEY in files_to_upload: - logger.info("Uploading file " + KEY) + for key in files_to_upload: + LOGGER.info("Uploading file " + key) try: - s3.upload_file(KEY, BUCKET_NAME, KEY) - except botocore.exceptions.ClientError as e: - print(e.response) + s3.upload_file(key, BUCKET_NAME, key) + except botocore.exceptions.ClientError as exception: + print(exception.response) def download_files(files_to_download): @@ -33,12 +38,12 @@ def download_files(files_to_download): :param files_to_download: List of files to download :return: None """ - for KEY in files_to_download: - logger.info("Downloading file " + KEY) + for 
key in files_to_download: + LOGGER.info("Downloading file " + key) try: - s3.download_file(BUCKET_NAME, KEY, KEY) - except botocore.exceptions.ClientError as e: - if e.response['Error']['Code'] == "404": + s3.download_file(BUCKET_NAME, key, key) + except botocore.exceptions.ClientError as exception: + if exception.response['Error']['Code'] == "404": print("The object does not exist.") else: raise diff --git a/utils/format_output.py b/utils/format_output.py index 4f026ce2..c46b90ba 100644 --- a/utils/format_output.py +++ b/utils/format_output.py @@ -1,13 +1,24 @@ +""" +Utility function to format the artefacts created during Reflector run +""" + import json -with open("../artefacts/meeting_titles_and_summaries.txt", "r") as f: +with open("../artefacts/meeting_titles_and_summaries.txt", "r", + encoding='utf-8') as f: outputs = f.read() outputs = json.loads(outputs) -transcript_file = open("../artefacts/meeting_transcript.txt", "a") -title_desc_file = open("../artefacts/meeting_title_description.txt", "a") -summary_file = open("../artefacts/meeting_summary.txt", "a") +transcript_file = open("../artefacts/meeting_transcript.txt", + "a", + encoding='utf-8') +title_desc_file = open("../artefacts/meeting_title_description.txt", + "a", + encoding='utf-8') +summary_file = open("../artefacts/meeting_summary.txt", + "a", + encoding='utf-8') for item in outputs["topics"]: transcript_file.write(item["transcript"]) diff --git a/utils/log_utils.py b/utils/log_utils.py index f665f5da..84cbe3fe 100644 --- a/utils/log_utils.py +++ b/utils/log_utils.py @@ -1,7 +1,15 @@ +""" +Utility file for logging +""" + import loguru class SingletonLogger: + """ + Use Singleton design pattern to create a logger object and share it + across the entire project + """ __instance = None @staticmethod @@ -15,4 +23,4 @@ class SingletonLogger: return SingletonLogger.__instance -logger = SingletonLogger.get_logger() +LOGGER = SingletonLogger.get_logger() diff --git a/utils/run_utils.py b/utils/run_utils.py index bb2b6348..2271fc19 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -1,3 +1,7 @@ +""" +Utility file for server side asynchronous task running and config objects +""" + import asyncio import configparser import contextlib @@ -7,6 +11,9 @@ from typing import ContextManager, Generic, TypeVar class ReflectorConfig: + """ + Create a single config object to share across the project + """ __config = None @staticmethod @@ -17,7 +24,7 @@ class ReflectorConfig: return ReflectorConfig.__config -config = ReflectorConfig.get_config() +CONFIG = ReflectorConfig.get_config() def run_in_executor(func, *args, executor=None, **kwargs): diff --git a/utils/text_utils.py b/utils/text_utils.py index 9ce7b1c1..8fb5ba10 100644 --- a/utils/text_utils.py +++ b/utils/text_utils.py @@ -1,3 +1,7 @@ +""" +Utility file for all text processing related functionalities +""" + import nltk import torch from nltk.corpus import stopwords @@ -6,8 +10,8 @@ from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from transformers import BartForConditionalGeneration, BartTokenizer -from log_utils import logger -from run_utils import config +from log_utils import LOGGER +from run_utils import CONFIG nltk.download('punkt', quiet=True) @@ -32,6 +36,12 @@ def compute_similarity(sent1, sent2): def remove_almost_alike_sentences(sentences, threshold=0.7): + """ + Filter sentences that are similar beyond a set threshold + :param sentences: + :param threshold: + :return: + """ num_sentences = 
len(sentences) removed_indices = set() @@ -62,6 +72,11 @@ def remove_almost_alike_sentences(sentences, threshold=0.7): def remove_outright_duplicate_sentences_from_chunk(chunk): + """ + Remove repetitive sentences + :param chunk: + :return: + """ chunk_text = chunk["text"] sentences = nltk.sent_tokenize(chunk_text) nonduplicate_sentences = list(dict.fromkeys(sentences)) @@ -69,6 +84,12 @@ def remove_outright_duplicate_sentences_from_chunk(chunk): def remove_whisper_repetitive_hallucination(nonduplicate_sentences): + """ + Remove sentences that are repeated as a result of Whisper + hallucinations + :param nonduplicate_sentences: + :return: + """ chunk_sentences = [] for sent in nonduplicate_sentences: @@ -91,6 +112,11 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences): def post_process_transcription(whisper_result): + """ + Parent function to perform post-processing on the transcription result + :param whisper_result: + :return: + """ transcript_text = "" for chunk in whisper_result["chunks"]: nonduplicate_sentences = \ @@ -121,9 +147,9 @@ def summarize_chunks(chunks, tokenizer, model): with torch.no_grad(): summary_ids = \ model.generate(input_ids, - num_beams=int(config["SUMMARIZER"]["BEAM_SIZE"]), + num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]), length_penalty=2.0, - max_length=int(config["SUMMARIZER"]["MAX_LENGTH"]), + max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]), early_stopping=True) summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) @@ -132,7 +158,7 @@ def summarize_chunks(chunks, tokenizer, model): def chunk_text(text, - max_chunk_length=int(config["SUMMARIZER"]["MAX_CHUNK_LENGTH"])): + max_chunk_length=int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])): """ Split text into smaller chunks. :param text: Text to be chunked @@ -154,14 +180,22 @@ def chunk_text(text, def summarize(transcript_text, timestamp, real_time=False, - chunk_summarize=config["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]): + chunk_summarize=CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]): + """ + Summarize the given text either as a whole or as chunks as needed + :param transcript_text: + :param timestamp: + :param real_time: + :param chunk_summarize: + :return: + """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - summary_model = config["SUMMARIZER"]["SUMMARY_MODEL"] + summary_model = CONFIG["SUMMARIZER"]["SUMMARY_MODEL"] if not summary_model: summary_model = "facebook/bart-large-cnn" # Summarize the generated transcript using the BART model - logger.info(f"Loading BART model: {summary_model}") + LOGGER.info(f"Loading BART model: {summary_model}") tokenizer = BartTokenizer.from_pretrained(summary_model) model = BartForConditionalGeneration.from_pretrained(summary_model) model = model.to(device) @@ -171,7 +205,7 @@ def summarize(transcript_text, timestamp, output_file = "real_time_" + output_file if chunk_summarize != "YES": - max_length = int(config["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"]) + max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"]) inputs = tokenizer. 
\ batch_encode_plus([transcript_text], truncation=True, padding='longest', @@ -180,8 +214,8 @@ def summarize(transcript_text, timestamp, inputs = inputs.to(device) with torch.no_grad(): - num_beans = int(config["SUMMARIZER"]["BEAM_SIZE"]) - max_length = int(config["SUMMARIZER"]["MAX_LENGTH"]) + num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]) + max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]) summaries = model.generate(inputs['input_ids'], num_beams=num_beans, length_penalty=2.0, @@ -194,16 +228,16 @@ def summarize(transcript_text, timestamp, clean_up_tokenization_spaces=False) for summary in summaries] summary = " ".join(decoded_summaries) - with open("./artefacts/" + output_file, 'w') as f: - f.write(summary.strip() + "\n") + with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file: + file.write(summary.strip() + "\n") else: - logger.info("Breaking transcript into smaller chunks") + LOGGER.info("Breaking transcript into smaller chunks") chunks = chunk_text(transcript_text) - logger.info(f"Transcript broken into {len(chunks)} " + LOGGER.info(f"Transcript broken into {len(chunks)} " f"chunks of at most 500 words") - logger.info(f"Writing summary text to: {output_file}") + LOGGER.info(f"Writing summary text to: {output_file}") with open(output_file, 'w') as f: summaries = summarize_chunks(chunks, tokenizer, model) for summary in summaries: diff --git a/utils/viz_utils.py b/utils/viz_utils.py index d7debd0c..498c7cf7 100644 --- a/utils/viz_utils.py +++ b/utils/viz_utils.py @@ -1,3 +1,7 @@ +""" +Utility file for all visualization related functions +""" + import ast import collections import os @@ -81,8 +85,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): else: filename = "./artefacts/transcript_with_timestamp_" + \ timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" - with open(filename) as f: - transcription_timestamp_text = f.read() + with open(filename) as file: + transcription_timestamp_text = file.read() res = ast.literal_eval(transcription_timestamp_text) chunks = res["chunks"]
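
[Editor's note: the subtlest logic in this change is the ordered delivery of transcripts in server.py. `sorted_transcripts` keys each batch by the start time of its first frame; `get_transcription` stores `None` as a placeholder when a batch begins and fills in the result when it finishes; `channel_send_transcript` always drains the smallest key, so transcripts reach the client in time order even when batches finish out of order, and a stalled head entry is dropped once three or more entries pile up behind it. A minimal, self-contained sketch of the same pattern — the names here are illustrative, not the patch's API:

    from sortedcontainers import SortedDict

    pending = SortedDict()  # start_time -> result, or None while transcribing

    def reserve(start_time):
        # Called when a batch starts transcribing: reserve its slot in order.
        pending[start_time] = None

    def complete(start_time, result):
        # Called when a batch finishes: fill in the reserved slot.
        pending[start_time] = result

    def drain():
        """Yield finished results in time order; drop a stalled head entry."""
        while pending:
            head = next(iter(pending))  # smallest start time
            if pending[head] is not None:
                yield pending.pop(head)
            elif len(pending) >= 3:
                # Same heuristic as the patch: don't wait forever on a batch
                # that raised and never produced a transcript.
                del pending[head]
            else:
                break  # head batch still in flight; keep ordering, retry later
]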