From b892fc056267af26606e886e27355049b69a7b8f Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Tue, 25 Jul 2023 22:55:17 +0530 Subject: [PATCH 1/3] add data classes and typing --- reflector_dataclasses.py | 119 +++++++++++++++++++++++++++++++++++++++ server.py | 98 +++++++++++++------------------- 2 files changed, 158 insertions(+), 59 deletions(-) create mode 100644 reflector_dataclasses.py diff --git a/reflector_dataclasses.py b/reflector_dataclasses.py new file mode 100644 index 00000000..cfc4da6a --- /dev/null +++ b/reflector_dataclasses.py @@ -0,0 +1,119 @@ +import datetime +from dataclasses import dataclass +from typing import List + +import av + + +@dataclass +class TitleSummaryInput: + input_text = str + transcribed_time = float + prompt = str + data = dict + + def __init__(self, transcribed_time, input_text=""): + self.input_text = input_text + self.transcribed_time = transcribed_time + self.prompt = f""" + ### Human: + Create a JSON object as response. The JSON object must have 2 fields: + i) title and ii) summary. For the title field,generate a short title + for the given text. For the summary field, summarize the given text + in three sentences. + + {self.input_text} + + ### Assistant: + """ + self.data = {"data": self.prompt} + self.headers = {"Content-Type": "application/json"} + + +@dataclass +class IncrementalResponse: + title = str + description = str + transcript = str + + def __init__(self, title, desc, transcript): + self.title = title + self.description = desc + self.transcript = transcript + + +@dataclass +class TitleSummaryOutput: + cmd = str + topics = List[IncrementalResponse] + + def __init__(self, inc_responses): + self.topics = inc_responses + + def get_response(self): + return { + "cmd": self.cmd, + "topics": self.topics + } + + +@dataclass +class ParseLLMResult: + description = str + transcript = str + timestamp = str + + def __init__(self, param: TitleSummaryInput, output: dict): + self.transcript = param.input_text + self.description = output.pop("summary") + self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time))) + + def get_result(self): + return { + "description": self.description, + "transcript": self.transcript, + "timestamp": self.timestamp + } + + +@dataclass +class TranscriptionInput: + frames = List[av.audio.frame.AudioFrame] + + def __init__(self, frames): + self.frames = frames + + +@dataclass +class TranscriptionOutput: + cmd = str + result_text = str + + def __init__(self, result_text): + self.cmd = "SHOW_TRANSCRIPTION" + self.result_text = result_text + + def get_response(self): + return { + "cmd": self.cmd, + "text": self.result_text + } + + +@dataclass +class FinalSummaryResponse: + cmd = str + final_summary = str + duration = str + + def __init__(self, final_summary, time): + self.duration = str(datetime.timedelta(seconds=round(time))) + self.final_summary = final_summary + self.cmd = "" + + def get_response(self): + return { + "cmd": self.cmd, + "duration": self.duration, + "summary": self.final_summary + } diff --git a/server.py b/server.py index 55066eef..1c556a93 100644 --- a/server.py +++ b/server.py @@ -6,6 +6,7 @@ import os import uuid import wave from concurrent.futures import ThreadPoolExecutor +from typing import Any import aiohttp_cors import requests @@ -17,7 +18,9 @@ from faster_whisper import WhisperModel from loguru import logger from sortedcontainers import SortedDict -from utils.run_utils import run_in_executor, config +from reflector_dataclasses import FinalSummaryResponse, ParseLLMResult, 
TitleSummaryInput, TitleSummaryOutput, \ + TranscriptionInput, TranscriptionOutput +from utils.run_utils import config, run_in_executor pcs = set() relay = MediaRelay() @@ -43,49 +46,30 @@ blacklisted_messages = [" Thank you.", " See you next time!", " And that's what I'm talking about."] -def get_title_and_summary(llm_input_text, last_timestamp): +def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Any[None, ParseLLMResult]: + try: + output = json.loads(response.json()["results"][0]["text"]) + return ParseLLMResult(param, output).get_result() + except Exception as e: + logger.info("Exception" + str(e)) + return None + + +def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOutput]: logger.info("Generating title and summary") - # output = llm.generate(prompt) - - # Use monadical-ml to fire this query to an LLM and get result - headers = { - "Content-Type": "application/json" - } - - prompt = f""" - ### Human: - Create a JSON object as response. The JSON object must have 2 fields: - i) title and ii) summary. For the title field,generate a short title - for the given text. For the summary field, summarize the given text - in three sentences. - - {llm_input_text} - - ### Assistant: - """ - - data = { - "prompt": prompt - } # TODO : Handle unexpected output formats from the model try: - response = requests.post(LLM_URL, headers=headers, json=data) - output = json.loads(response.json()["results"][0]["text"]) - output["description"] = output.pop("summary") - output["transcript"] = llm_input_text - output["timestamp"] = \ - str(datetime.timedelta(seconds=round(last_timestamp))) - incremental_responses.append(output) - result = { - "cmd": "UPDATE_TOPICS", - "topics": incremental_responses, - } - + response = requests.post(LLM_URL, + headers=param.headers, + json=param.data) + output = parse_llm_output(param, response) + if output: + incremental_responses.append(output) + return TitleSummaryOutput(incremental_responses).get_response() except Exception as e: logger.info("Exception" + str(e)) - result = None - return result + return None def channel_log(channel, t, message): @@ -123,11 +107,11 @@ def channel_send_transcript(channel): pass -def get_transcription(frames): +def get_transcription(input_frames: TranscriptionInput) -> Any[None, TranscriptionOutput]: logger.info("Transcribing..") - sorted_transcripts[frames[0].time] = None + sorted_transcripts[input_frames[0].time] = None - # TODO: + # TODO: Find cleaner way, watch "no transcription" issue below # Passing IO objects instead of temporary files throws an error # Passing ndarrays (typecasted with float) does not give any # transcription. 
Refer issue, @@ -138,7 +122,7 @@ def get_transcription(frames): wf.setframerate(RATE) wf.setsampwidth(2) - for frame in frames: + for frame in input_frames.frames: wf.writeframes(b"".join(frame.to_ndarray())) wf.close() @@ -173,30 +157,23 @@ def get_transcription(frames): logger.info("Exception" + str(e)) pass - result = { - "cmd": "SHOW_TRANSCRIPTION", - "text": result_text - } - sorted_transcripts[frames[0].time] = result + result = TranscriptionOutput(result_text).get_response() + sorted_transcripts[input_frames.frames[0].time] = result return result -def get_final_summary_response(): +def get_final_summary_response() -> Any[None, FinalSummaryResponse]: final_summary = "" # Collate inc summaries for topic in incremental_responses: final_summary += topic["description"] - response = { - "cmd": "DISPLAY_FINAL_SUMMARY", - "duration": str(datetime.timedelta( - seconds=round(last_transcribed_time))), - "summary": final_summary - } + response = FinalSummaryResponse(final_summary, last_transcribed_time).get_response() with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f: f.write(json.dumps(incremental_responses)) + return response @@ -218,7 +195,9 @@ class AudioStreamTrack(MediaStreamTrack): if local_frames := audio_buffer.read_many(256 * 960, partial=False): whisper_result = run_in_executor( - get_transcription, local_frames, executor=executor + get_transcription, + TranscriptionInput(local_frames), + executor=executor ) whisper_result.add_done_callback( lambda f: channel_send_transcript(data_channel) @@ -226,12 +205,13 @@ class AudioStreamTrack(MediaStreamTrack): else None ) - if len(transcription_text) > 750: + if len(transcription_text) > 25: llm_input_text = transcription_text transcription_text = "" + param = TitleSummaryInput(input_text=llm_input_text, + transcribed_time=last_transcribed_time) llm_result = run_in_executor(get_title_and_summary, - llm_input_text, - last_transcribed_time, + param, executor=executor) llm_result.add_done_callback( lambda f: channel_send_increment(data_channel, @@ -332,4 +312,4 @@ if __name__ == "__main__": offer_resource = cors.add(app.router.add_resource("/offer")) cors.add(offer_resource.add_route("POST", offer)) app.on_shutdown.append(on_shutdown) - web.run_app(app, access_log=None, host=args.host, port=args.port) + web.run_app(app, access_log=None, host=args.host, port=args.port) From c970fc89dd5e36fc98151b5bff84720760bc0e39 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Wed, 26 Jul 2023 09:59:25 +0530 Subject: [PATCH 2/3] code style updates --- reflector_dataclasses.py | 12 ++--- server.py | 64 +++++++++++++------------- stream_client.py | 2 +- trials/server/server_multithreaded.py | 2 +- trials/whisper-jax/whisjax.py | 2 +- trials/whisper-jax/whisjax_realtime.py | 6 +-- utils/file_utils.py | 6 +-- utils/text_utils.py | 16 +++---- 8 files changed, 54 insertions(+), 56 deletions(-) diff --git a/reflector_dataclasses.py b/reflector_dataclasses.py index cfc4da6a..daac519c 100644 --- a/reflector_dataclasses.py +++ b/reflector_dataclasses.py @@ -31,7 +31,7 @@ class TitleSummaryInput: @dataclass -class IncrementalResponse: +class IncrementalResult: title = str description = str transcript = str @@ -45,12 +45,12 @@ class IncrementalResponse: @dataclass class TitleSummaryOutput: cmd = str - topics = List[IncrementalResponse] + topics = List[IncrementalResult] def __init__(self, inc_responses): self.topics = inc_responses - def get_response(self): + def get_result(self): return { "cmd": self.cmd, "topics": self.topics @@ -93,7 +93,7 @@ 
class TranscriptionOutput: self.cmd = "SHOW_TRANSCRIPTION" self.result_text = result_text - def get_response(self): + def get_result(self): return { "cmd": self.cmd, "text": self.result_text @@ -101,7 +101,7 @@ class TranscriptionOutput: @dataclass -class FinalSummaryResponse: +class FinalSummaryResult: cmd = str final_summary = str duration = str @@ -111,7 +111,7 @@ class FinalSummaryResponse: self.final_summary = final_summary self.cmd = "" - def get_response(self): + def get_result(self): return { "cmd": self.cmd, "duration": self.duration, diff --git a/server.py b/server.py index 1c556a93..2ba07229 100644 --- a/server.py +++ b/server.py @@ -6,20 +6,21 @@ import os import uuid import wave from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import Any, NoReturn import aiohttp_cors +import av import requests from aiohttp import web from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import MediaRelay -from av import AudioFifo from faster_whisper import WhisperModel from loguru import logger from sortedcontainers import SortedDict -from reflector_dataclasses import FinalSummaryResponse, ParseLLMResult, TitleSummaryInput, TitleSummaryOutput, \ - TranscriptionInput, TranscriptionOutput +from reflector_dataclasses import FinalSummaryResult, ParseLLMResult,\ + TitleSummaryInput, TitleSummaryOutput, TranscriptionInput,\ + TranscriptionOutput from utils.run_utils import config, run_in_executor pcs = set() @@ -31,25 +32,21 @@ model = WhisperModel("tiny", device="cpu", CHANNELS = 2 RATE = 48000 -audio_buffer = AudioFifo() +audio_buffer = av.AudioFifo() executor = ThreadPoolExecutor() transcription_text = "" last_transcribed_time = 0.0 -LLM_MACHINE_IP = config["DEFAULT"]["LLM_MACHINE_IP"] -LLM_MACHINE_PORT = config["DEFAULT"]["LLM_MACHINE_PORT"] +LLM_MACHINE_IP = config["LLM"]["LLM_MACHINE_IP"] +LLM_MACHINE_PORT = config["LLM"]["LLM_MACHINE_PORT"] LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate" incremental_responses = [] sorted_transcripts = SortedDict() -blacklisted_messages = [" Thank you.", " See you next time!", - " Thank you for watching!", " Bye!", - " And that's what I'm talking about."] - def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Any[None, ParseLLMResult]: try: output = json.loads(response.json()["results"][0]["text"]) - return ParseLLMResult(param, output).get_result() + return ParseLLMResult(param, output) except Exception as e: logger.info("Exception" + str(e)) return None @@ -65,33 +62,35 @@ def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOut json=param.data) output = parse_llm_output(param, response) if output: - incremental_responses.append(output) - return TitleSummaryOutput(incremental_responses).get_response() + result = output.get_result() + incremental_responses.append(result) + return TitleSummaryOutput(incremental_responses) except Exception as e: logger.info("Exception" + str(e)) return None -def channel_log(channel, t, message): +def channel_log(channel, t: str, message: str) -> NoReturn: logger.info("channel(%s) %s %s" % (channel.label, t, message)) -def channel_send(channel, message): +def channel_send(channel, message: str) -> NoReturn: if channel: channel.send(message) -def channel_send_increment(channel, message): - if channel and message: +def channel_send_increment(channel, param: Any[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn: + if channel and param: + message = param.get_result() 
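# A hedged aside on the payload built in the line above: the classes in
# reflector_dataclasses.py assign types to class attributes
# (`cmd = str`) instead of annotating fields (`cmd: str`), and they
# hand-write __init__, so the @dataclass decorator generates nothing.
# TitleSummaryOutput also never sets self.cmd, so "cmd" resolves to the
# builtin `str` type itself, which json.dumps cannot encode. A sketch of
# the idiomatic equivalent ("UPDATE_TOPICS" is the command string the
# removed dict in patch 1 used; everything else mirrors this diff):
#
#     from dataclasses import dataclass, field
#     from typing import List
#
#     @dataclass
#     class TitleSummaryOutput:
#         topics: List[dict] = field(default_factory=list)
#         cmd: str = "UPDATE_TOPICS"
#
#         def get_result(self) -> dict:
#             return {"cmd": self.cmd, "topics": self.topics}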
channel.send(json.dumps(message)) -def channel_send_transcript(channel): +def channel_send_transcript(channel) -> NoReturn: # channel_log(channel, ">", message) if channel: try: least_time = sorted_transcripts.keys()[0] - message = sorted_transcripts[least_time] + message = sorted_transcripts[least_time].get_result() if message: del sorted_transcripts[least_time] if message["text"] not in blacklisted_messages: @@ -157,19 +156,19 @@ def get_transcription(input_frames: TranscriptionInput) -> Any[None, Transcripti logger.info("Exception" + str(e)) pass - result = TranscriptionOutput(result_text).get_response() + result = TranscriptionOutput(result_text) sorted_transcripts[input_frames.frames[0].time] = result return result -def get_final_summary_response() -> Any[None, FinalSummaryResponse]: +def get_final_summary_response() -> FinalSummaryResult: final_summary = "" # Collate inc summaries for topic in incremental_responses: final_summary += topic["description"] - response = FinalSummaryResponse(final_summary, last_transcribed_time).get_response() + response = FinalSummaryResult(final_summary, last_transcribed_time) with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f: f.write(json.dumps(incremental_responses)) @@ -188,7 +187,7 @@ class AudioStreamTrack(MediaStreamTrack): super().__init__() self.track = track - async def recv(self): + async def recv(self) -> av.audio.frame.AudioFrame: global transcription_text frame = await self.track.recv() audio_buffer.write(frame) @@ -222,7 +221,7 @@ class AudioStreamTrack(MediaStreamTrack): return frame -async def offer(request): +async def offer(request: requests.Request) -> web.Response: params = await request.json() offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) @@ -230,40 +229,39 @@ async def offer(request): pc_id = "PeerConnection(%s)" % uuid.uuid4() pcs.add(pc) - def log_info(msg, *args): + def log_info(msg, *args) -> NoReturn: logger.info(pc_id + " " + msg, *args) log_info("Created for " + request.remote) @pc.on("datachannel") - def on_datachannel(channel): + def on_datachannel(channel) -> NoReturn: global data_channel data_channel = channel channel_log(channel, "-", "created by remote party") @channel.on("message") - def on_message(message): + def on_message(message: str) -> NoReturn: channel_log(channel, "<", message) if json.loads(message)["cmd"] == "STOP": - # Place holder final summary + # Placeholder final summary response = get_final_summary_response() channel_send_increment(data_channel, response) # To-do Add code to stop connection from server side here # But have to handshake with client once - # pc.close() if isinstance(message, str) and message.startswith("ping"): channel_send(channel, "pong" + message[4:]) @pc.on("connectionstatechange") - async def on_connectionstatechange(): + async def on_connectionstatechange() -> NoReturn: log_info("Connection state is " + pc.connectionState) if pc.connectionState == "failed": await pc.close() pcs.discard(pc) @pc.on("track") - def on_track(track): + def on_track(track) -> NoReturn: log_info("Track " + track.kind + " received") pc.addTrack(AudioStreamTrack(relay.subscribe(track))) @@ -280,7 +278,7 @@ async def offer(request): ) -async def on_shutdown(app): +async def on_shutdown(app) -> NoReturn: coros = [pc.close() for pc in pcs] await asyncio.gather(*coros) pcs.clear() diff --git a/stream_client.py b/stream_client.py index c2238ee5..ae22c6db 100644 --- a/stream_client.py +++ b/stream_client.py @@ -35,7 +35,7 @@ class StreamClient: self.time_start = None 
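# The config lookups in this patch move from the flat [DEFAULT] section
# to named sections, which keeps unrelated keys out of each other's way
# (configparser merges [DEFAULT] keys into every section). A sketch of
# the config.ini layout the new lookups imply — section and key names
# are taken from this diff, values are illustrative placeholders:
#
#     [LLM]
#     LLM_MACHINE_IP = 127.0.0.1
#     LLM_MACHINE_PORT = 5000
#
#     [AUDIO]
#     AV_FOUNDATION_DEVICE_ID = 0
#     BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME = BlackHole 2ch
#
#     [WHISPER]
#     WHISPER_MODEL_SIZE = large-v2
#     WHISPER_REAL_TIME_MODEL_SIZE = tiny
#
#     [AWS]
#     BUCKET_NAME = reflector-artefacts
#     AWS_ACCESS_KEY = ...
#     AWS_SECRET_KEY = ...
#
#     [SUMMARIZER]
#     SUMMARY_MODEL = facebook/bart-large-cnn
#     BEAM_SIZE = 4
#     MAX_LENGTH = 142
#     MAX_CHUNK_LENGTH = 500
#     INPUT_ENCODING_MAX_LENGTH = 1024
#     SUMMARIZE_USING_CHUNKS = YES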
self.queue = asyncio.Queue() self.player = MediaPlayer( - ':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]), + ':' + str(config['AUDIO']["AV_FOUNDATION_DEVICE_ID"]), format='avfoundation', options={'channels': '2'}) diff --git a/trials/server/server_multithreaded.py b/trials/server/server_multithreaded.py index 1c5e75d7..4f7688a0 100644 --- a/trials/server/server_multithreaded.py +++ b/trials/server/server_multithreaded.py @@ -19,7 +19,7 @@ from whisper_jax import FlaxWhisperPipline from reflector.utils.log_utils import logger from reflector.utils.run_utils import config, Mutex -WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"] +WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_REAL_TIME_MODEL_SIZE"] pcs = set() relay = MediaRelay() data_channel = None diff --git a/trials/whisper-jax/whisjax.py b/trials/whisper-jax/whisjax.py index eb87629d..2926fce0 100644 --- a/trials/whisper-jax/whisjax.py +++ b/trials/whisper-jax/whisjax.py @@ -27,7 +27,7 @@ from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud nltk.download('punkt', quiet=True) nltk.download('stopwords', quiet=True) -WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] +WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_MODEL_SIZE"] NOW = datetime.now() if not os.path.exists('../../artefacts'): diff --git a/trials/whisper-jax/whisjax_realtime.py b/trials/whisper-jax/whisjax_realtime.py index efb39461..5f269c18 100644 --- a/trials/whisper-jax/whisjax_realtime.py +++ b/trials/whisper-jax/whisjax_realtime.py @@ -16,7 +16,7 @@ from ...utils.run_utils import config from ...utils.text_utils import post_process_transcription, summarize from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud -WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] +WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_MODEL_SIZE"] FRAMES_PER_BUFFER = 8000 FORMAT = pyaudio.paInt16 @@ -31,7 +31,7 @@ def main(): AUDIO_DEVICE_ID = -1 for i in range(p.get_device_count()): if p.get_device_info_by_index(i)["name"] == \ - config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]: + config["AUDIO"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]: AUDIO_DEVICE_ID = i audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID) stream = p.open( @@ -44,7 +44,7 @@ def main(): ) pipeline = FlaxWhisperPipline("openai/whisper-" + - config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"], + config["WHISPER"]["WHISPER_REAL_TIME_MODEL_SIZE"], dtype=jnp.float16, batch_size=16) diff --git a/utils/file_utils.py b/utils/file_utils.py index cc9a9ded..596bfe7f 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -6,11 +6,11 @@ import botocore from .log_utils import logger from .run_utils import config -BUCKET_NAME = config["DEFAULT"]["BUCKET_NAME"] +BUCKET_NAME = config["AWS"]["BUCKET_NAME"] s3 = boto3.client('s3', - aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"], - aws_secret_access_key=config["DEFAULT"]["AWS_SECRET_KEY"]) + aws_access_key_id=config["AWS"]["AWS_ACCESS_KEY"], + aws_secret_access_key=config["AWS"]["AWS_SECRET_KEY"]) def upload_files(files_to_upload): diff --git a/utils/text_utils.py b/utils/text_utils.py index 25126b34..9ce7b1c1 100644 --- a/utils/text_utils.py +++ b/utils/text_utils.py @@ -121,9 +121,9 @@ def summarize_chunks(chunks, tokenizer, model): with torch.no_grad(): summary_ids = \ model.generate(input_ids, - num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), + num_beams=int(config["SUMMARIZER"]["BEAM_SIZE"]), length_penalty=2.0, - max_length=int(config["DEFAULT"]["MAX_LENGTH"]), + 
max_length=int(config["SUMMARIZER"]["MAX_LENGTH"]), early_stopping=True) summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) @@ -132,7 +132,7 @@ def summarize_chunks(chunks, tokenizer, model): def chunk_text(text, - max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): + max_chunk_length=int(config["SUMMARIZER"]["MAX_CHUNK_LENGTH"])): """ Split text into smaller chunks. :param text: Text to be chunked @@ -154,9 +154,9 @@ def chunk_text(text, def summarize(transcript_text, timestamp, real_time=False, - chunk_summarize=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): + chunk_summarize=config["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - summary_model = config["DEFAULT"]["SUMMARY_MODEL"] + summary_model = config["SUMMARIZER"]["SUMMARY_MODEL"] if not summary_model: summary_model = "facebook/bart-large-cnn" @@ -171,7 +171,7 @@ def summarize(transcript_text, timestamp, output_file = "real_time_" + output_file if chunk_summarize != "YES": - max_length = int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]) + max_length = int(config["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"]) inputs = tokenizer. \ batch_encode_plus([transcript_text], truncation=True, padding='longest', @@ -180,8 +180,8 @@ def summarize(transcript_text, timestamp, inputs = inputs.to(device) with torch.no_grad(): - num_beans = int(config["DEFAULT"]["BEAM_SIZE"]) - max_length = int(config["DEFAULT"]["MAX_LENGTH"]) + num_beans = int(config["SUMMARIZER"]["BEAM_SIZE"]) + max_length = int(config["SUMMARIZER"]["MAX_LENGTH"]) summaries = model.generate(inputs['input_ids'], num_beams=num_beans, length_penalty=2.0, From e512b4dca51983acde40799d59d5c27648dc3cb2 Mon Sep 17 00:00:00 2001 From: Gokul Mohanarangan Date: Wed, 26 Jul 2023 11:28:14 +0530 Subject: [PATCH 3/3] flake8 / pylint updates --- client.py | 21 ++++--- reflector_dataclasses.py | 65 ++++++++++++++++--- server.py | 87 +++++++++++++++----------- stream_client.py | 8 +-- trials/finetuning/youtube_scraping.py | 4 +- trials/server/server_multithreaded.py | 8 +-- trials/title_summary/incsum.py | 22 +++---- trials/whisper-jax/whisjax.py | 28 ++++----- trials/whisper-jax/whisjax_realtime.py | 35 ++++++----- utils/file_utils.py | 35 ++++++----- utils/format_output.py | 19 ++++-- utils/log_utils.py | 10 ++- utils/run_utils.py | 9 ++- utils/text_utils.py | 66 ++++++++++++++----- utils/viz_utils.py | 8 ++- 15 files changed, 279 insertions(+), 146 deletions(-) diff --git a/client.py b/client.py index b2167d3b..519ccc26 100644 --- a/client.py +++ b/client.py @@ -5,11 +5,16 @@ import signal from aiortc.contrib.signaling import (add_signaling_arguments, create_signaling) -from utils.log_utils import logger +from utils.log_utils import LOGGER from stream_client import StreamClient +from typing import NoReturn - -async def main(): +async def main() -> NoReturn: + """ + Reflector's entry point to the python client for WebRTC streaming if not + using the browser based UI-application + :return: + """ parser = argparse.ArgumentParser(description="Data channels ping/pong") parser.add_argument( @@ -37,17 +42,17 @@ async def main(): async def shutdown(signal, loop): """Cleanup tasks tied to the service's shutdown.""" - logger.info(f"Received exit signal {signal.name}...") - logger.info("Closing database connections") - logger.info("Nacking outstanding messages") + LOGGER.info(f"Received exit signal {signal.name}...") + LOGGER.info("Closing database connections") + LOGGER.info("Nacking outstanding messages") tasks = 
[t for t in asyncio.all_tasks() if t is not asyncio.current_task()] [task.cancel() for task in tasks] - logger.info(f"Cancelling {len(tasks)} outstanding tasks") + LOGGER.info(f"Cancelling {len(tasks)} outstanding tasks") await asyncio.gather(*tasks, return_exceptions=True) - logger.info(f'{"Flushing metrics"}') + LOGGER.info(f'{"Flushing metrics"}') loop.stop() signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) diff --git a/reflector_dataclasses.py b/reflector_dataclasses.py index daac519c..e05396ae 100644 --- a/reflector_dataclasses.py +++ b/reflector_dataclasses.py @@ -1,3 +1,8 @@ +""" +Collection of data classes for streamlining and rigidly structuring +the input and output parameters of functions +""" + import datetime from dataclasses import dataclass from typing import List @@ -7,6 +12,10 @@ import av @dataclass class TitleSummaryInput: + """ + Data class for the input to generate title and summaries. + The outcome will be used to send query to the LLM for processing. + """ input_text = str transcribed_time = float prompt = str @@ -15,23 +24,28 @@ class TitleSummaryInput: def __init__(self, transcribed_time, input_text=""): self.input_text = input_text self.transcribed_time = transcribed_time - self.prompt = f""" - ### Human: - Create a JSON object as response. The JSON object must have 2 fields: - i) title and ii) summary. For the title field,generate a short title - for the given text. For the summary field, summarize the given text - in three sentences. + self.prompt = \ + f""" + ### Human: + Create a JSON object as response.The JSON object must have 2 fields: + i) title and ii) summary.For the title field,generate a short title + for the given text. For the summary field, summarize the given text + in three sentences. - {self.input_text} + {self.input_text} - ### Assistant: - """ + ### Assistant: + """ self.data = {"data": self.prompt} self.headers = {"Content-Type": "application/json"} @dataclass class IncrementalResult: + """ + Data class for the result of generating one title and summaries. + Defines how a single "topic" looks like. + """ title = str description = str transcript = str @@ -44,6 +58,10 @@ class IncrementalResult: @dataclass class TitleSummaryOutput: + """ + Data class for the result of all generated titles and summaries. + The result will be sent back to the client + """ cmd = str topics = List[IncrementalResult] @@ -59,6 +77,10 @@ class TitleSummaryOutput: @dataclass class ParseLLMResult: + """ + Data class to parse the result returned by the LLM while generating title + and summaries. The result will be sent back to the client. + """ description = str transcript = str timestamp = str @@ -66,7 +88,8 @@ class ParseLLMResult: def __init__(self, param: TitleSummaryInput, output: dict): self.transcript = param.input_text self.description = output.pop("summary") - self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time))) + self.timestamp = \ + str(datetime.timedelta(seconds=round(param.transcribed_time))) def get_result(self): return { @@ -78,6 +101,10 @@ class ParseLLMResult: @dataclass class TranscriptionInput: + """ + Data class to define the input to the transcription function + AudioFrames -> input + """ frames = List[av.audio.frame.AudioFrame] def __init__(self, frames): @@ -86,6 +113,10 @@ class TranscriptionInput: @dataclass class TranscriptionOutput: + """ + Dataclass to define the result of the transcription function. 
+ The result will be sent back to the client + """ cmd = str result_text = str @@ -102,6 +133,10 @@ class TranscriptionOutput: @dataclass class FinalSummaryResult: + """ + Dataclass to define the result of the final summary function. + The result will be sent back to the client. + """ cmd = str final_summary = str duration = str @@ -117,3 +152,13 @@ class FinalSummaryResult: "duration": self.duration, "summary": self.final_summary } + + +class BlackListedMessages: + """ + Class to hold the blacklisted messages. These messages should be filtered + out and not sent back to the client as part of the transcription. + """ + messages = [" Thank you.", " See you next time!", + " Thank you for watching!", " Bye!", + " And that's what I'm talking about."] diff --git a/server.py b/server.py index 2ba07229..f45148cd 100644 --- a/server.py +++ b/server.py @@ -6,7 +6,7 @@ import os import uuid import wave from concurrent.futures import ThreadPoolExecutor -from typing import Any, NoReturn +from typing import Union, NoReturn import aiohttp_cors import av @@ -15,13 +15,13 @@ from aiohttp import web from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc.contrib.media import MediaRelay from faster_whisper import WhisperModel -from loguru import logger from sortedcontainers import SortedDict from reflector_dataclasses import FinalSummaryResult, ParseLLMResult,\ TitleSummaryInput, TitleSummaryOutput, TranscriptionInput,\ - TranscriptionOutput -from utils.run_utils import config, run_in_executor + TranscriptionOutput, BlackListedMessages +from utils.run_utils import CONFIG, run_in_executor +from utils.log_utils import LOGGER pcs = set() relay = MediaRelay() @@ -36,24 +36,24 @@ audio_buffer = av.AudioFifo() executor = ThreadPoolExecutor() transcription_text = "" last_transcribed_time = 0.0 -LLM_MACHINE_IP = config["LLM"]["LLM_MACHINE_IP"] -LLM_MACHINE_PORT = config["LLM"]["LLM_MACHINE_PORT"] +LLM_MACHINE_IP = CONFIG["LLM"]["LLM_MACHINE_IP"] +LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"] LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate" incremental_responses = [] sorted_transcripts = SortedDict() -def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Any[None, ParseLLMResult]: +def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]: try: output = json.loads(response.json()["results"][0]["text"]) return ParseLLMResult(param, output) except Exception as e: - logger.info("Exception" + str(e)) + LOGGER.info("Exception" + str(e)) return None -def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOutput]: - logger.info("Generating title and summary") +def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]: + LOGGER.info("Generating title and summary") # TODO : Handle unexpected output formats from the model try: @@ -66,12 +66,12 @@ def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOut incremental_responses.append(result) return TitleSummaryOutput(incremental_responses) except Exception as e: - logger.info("Exception" + str(e)) + LOGGER.info("Exception" + str(e)) return None def channel_log(channel, t: str, message: str) -> NoReturn: - logger.info("channel(%s) %s %s" % (channel.label, t, message)) + LOGGER.info("channel(%s) %s %s" % (channel.label, t, message)) def channel_send(channel, message: str) -> NoReturn: @@ -79,7 +79,7 @@ def channel_send(channel, message: str) -> NoReturn: 
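# A hedged typing note for the annotations in this file: NoReturn
# declares a function that never returns normally (it always raises or
# loops forever), but these handlers fall off the end and implicitly
# return None, so the conventional spelling is "-> None". Likewise
# Union[None, X] is usually written Optional[X], and the Any[None, X]
# form it replaces from patch 1 raises a TypeError the moment the
# module is imported, because typing.Any cannot be subscripted. Sketch:
#
#     from typing import Optional
#
#     def channel_send(channel, message: str) -> None:
#         if channel:
#             channel.send(message)
#
#     def parse_llm_output(param: TitleSummaryInput,
#                          response: requests.Response
#                          ) -> Optional[ParseLLMResult]:
#         ...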
channel.send(message) -def channel_send_increment(channel, param: Any[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn: +def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn: if channel and param: message = param.get_result() channel.send(json.dumps(message)) @@ -89,11 +89,11 @@ def channel_send_transcript(channel) -> NoReturn: # channel_log(channel, ">", message) if channel: try: - least_time = sorted_transcripts.keys()[0] + least_time = next(iter(sorted_transcripts)) message = sorted_transcripts[least_time].get_result() if message: del sorted_transcripts[least_time] - if message["text"] not in blacklisted_messages: + if message["text"] not in BlackListedMessages.messages: channel.send(json.dumps(message)) # Due to exceptions if one of the earlier batches can't return # a transcript, we don't want to be stuck waiting for the result @@ -101,22 +101,21 @@ def channel_send_transcript(channel) -> NoReturn: else: if len(sorted_transcripts) >= 3: del sorted_transcripts[least_time] - except Exception as e: - logger.info("Exception", str(e)) - pass + except Exception as exception: + LOGGER.info("Exception", str(exception)) -def get_transcription(input_frames: TranscriptionInput) -> Any[None, TranscriptionOutput]: - logger.info("Transcribing..") - sorted_transcripts[input_frames[0].time] = None +def get_transcription(input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]: + LOGGER.info("Transcribing..") + sorted_transcripts[input_frames.frames[0].time] = None # TODO: Find cleaner way, watch "no transcription" issue below # Passing IO objects instead of temporary files throws an error - # Passing ndarrays (typecasted with float) does not give any + # Passing ndarray (type casted with float) does not give any # transcription. 
Refer issue, # https://github.com/guillaumekln/faster-whisper/issues/369 - audiofilename = "test" + str(datetime.datetime.now()) - wf = wave.open(audiofilename, "wb") + audio_file = "test" + str(datetime.datetime.now()) + wf = wave.open(audio_file, "wb") wf.setnchannels(CHANNELS) wf.setframerate(RATE) wf.setsampwidth(2) @@ -129,12 +128,12 @@ def get_transcription(input_frames: TranscriptionInput) -> Any[None, Transcripti try: segments, _ = \ - model.transcribe(audiofilename, + model.transcribe(audio_file, language="en", beam_size=5, vad_filter=True, - vad_parameters=dict(min_silence_duration_ms=500)) - os.remove(audiofilename) + vad_parameters={"min_silence_duration_ms": 500}) + os.remove(audio_file) segments = list(segments) result_text = "" duration = 0.0 @@ -152,9 +151,8 @@ def get_transcription(input_frames: TranscriptionInput) -> Any[None, Transcripti last_transcribed_time += duration transcription_text += result_text - except Exception as e: - logger.info("Exception" + str(e)) - pass + except Exception as exception: + LOGGER.info("Exception" + str(exception)) result = TranscriptionOutput(result_text) sorted_transcripts[input_frames.frames[0].time] = result @@ -162,6 +160,11 @@ def get_transcription(input_frames: TranscriptionInput) -> Any[None, Transcripti def get_final_summary_response() -> FinalSummaryResult: + """ + Collate the incremental summaries generated so far and return as the final + summary + :return: + """ final_summary = "" # Collate inc summaries @@ -170,8 +173,9 @@ def get_final_summary_response() -> FinalSummaryResult: response = FinalSummaryResult(final_summary, last_transcribed_time) - with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f: - f.write(json.dumps(incremental_responses)) + with open("./artefacts/meeting_titles_and_summaries.txt", "a", + encoding="utf-8") as file: + file.write(json.dumps(incremental_responses)) return response @@ -222,6 +226,11 @@ class AudioStreamTrack(MediaStreamTrack): async def offer(request: requests.Request) -> web.Response: + """ + Establish the WebRTC connection with the client + :param request: + :return: + """ params = await request.json() offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) @@ -230,7 +239,7 @@ async def offer(request: requests.Request) -> web.Response: pcs.add(pc) def log_info(msg, *args) -> NoReturn: - logger.info(pc_id + " " + msg, *args) + LOGGER.info(pc_id + " " + msg, *args) log_info("Created for " + request.remote) @@ -272,15 +281,17 @@ async def offer(request: requests.Request) -> web.Response: return web.Response( content_type="application/json", text=json.dumps( - {"sdp": pc.localDescription.sdp, - "type": pc.localDescription.type} + { + "sdp": pc.localDescription.sdp, + "type": pc.localDescription.type + } ), ) -async def on_shutdown(app) -> NoReturn: - coros = [pc.close() for pc in pcs] - await asyncio.gather(*coros) +async def on_shutdown(application: web.Application) -> NoReturn: + coroutines = [pc.close() for pc in pcs] + await asyncio.gather(*coroutines) pcs.clear() diff --git a/stream_client.py b/stream_client.py index ae22c6db..57f88229 100644 --- a/stream_client.py +++ b/stream_client.py @@ -9,8 +9,8 @@ import stamina from aiortc import (RTCPeerConnection, RTCSessionDescription) from aiortc.contrib.media import (MediaPlayer, MediaRelay) -from utils.log_utils import logger -from utils.run_utils import config +from utils.log_utils import LOGGER +from utils.run_utils import CONFIG class StreamClient: @@ -35,7 +35,7 @@ class StreamClient: self.time_start = None 
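# LOGGER and CONFIG are the module-level singletons from this patch's
# utils/log_utils.py and utils/run_utils.py; the uppercase renames
# follow pylint's constant-naming convention, per the commit message.
# A minimal usage sketch (the config key comes from this diff):
#
#     from utils.log_utils import LOGGER
#     from utils.run_utils import CONFIG
#
#     device_id = CONFIG['AUDIO']["AV_FOUNDATION_DEVICE_ID"]
#     LOGGER.info("avfoundation device id: " + str(device_id))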
self.queue = asyncio.Queue() self.player = MediaPlayer( - ':' + str(config['AUDIO']["AV_FOUNDATION_DEVICE_ID"]), + ':' + str(CONFIG['AUDIO']["AV_FOUNDATION_DEVICE_ID"]), format='avfoundation', options={'channels': '2'}) @@ -74,7 +74,7 @@ class StreamClient: self.pcs.add(pc) def log_info(msg, *args): - logger.info(pc_id + " " + msg, *args) + LOGGER.info(pc_id + " " + msg, *args) @pc.on("connectionstatechange") async def on_connectionstatechange(): diff --git a/trials/finetuning/youtube_scraping.py b/trials/finetuning/youtube_scraping.py index b0892f47..be8b7e41 100644 --- a/trials/finetuning/youtube_scraping.py +++ b/trials/finetuning/youtube_scraping.py @@ -93,6 +93,6 @@ def generate_finetuning_dataset(video_ids): video_ids = ["yTnSEZIwnkU"] dataset = generate_finetuning_dataset(video_ids) -with open("finetuning_dataset.jsonl", "w") as f: +with open("finetuning_dataset.jsonl", "w", encoding="utf-8") as file: for example in dataset: - f.write(json.dumps(example) + "\n") + file.write(json.dumps(example) + "\n") diff --git a/trials/server/server_multithreaded.py b/trials/server/server_multithreaded.py index 4f7688a0..6739fbf6 100644 --- a/trials/server/server_multithreaded.py +++ b/trials/server/server_multithreaded.py @@ -16,10 +16,10 @@ from av import AudioFifo from sortedcontainers import SortedDict from whisper_jax import FlaxWhisperPipline -from reflector.utils.log_utils import logger -from reflector.utils.run_utils import config, Mutex +from reflector.utils.log_utils import LOGGER +from reflector.utils.run_utils import CONFIG, Mutex -WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_REAL_TIME_MODEL_SIZE"] +WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_REAL_TIME_MODEL_SIZE"] pcs = set() relay = MediaRelay() data_channel = None @@ -127,7 +127,7 @@ async def offer(request: requests.Request): pcs.add(pc) def log_info(msg: str, *args): - logger.info(pc_id + " " + msg, *args) + LOGGER.info(pc_id + " " + msg, *args) log_info("Created for " + request.remote) diff --git a/trials/title_summary/incsum.py b/trials/title_summary/incsum.py index 5081d16c..571af77f 100644 --- a/trials/title_summary/incsum.py +++ b/trials/title_summary/incsum.py @@ -3,14 +3,14 @@ import sys # Observe the incremental summaries by performing summaries in chunks -with open("transcript.txt") as f: - transcription = f.read() +with open("transcript.txt", "r", encoding="utf-8") as file: + transcription = file.read() def split_text_file(filename, token_count): nlp = spacy.load('en_core_web_md') - with open(filename, 'r') as file: + with open(filename, 'r', encoding="utf-8") as file: text = file.read() doc = nlp(text) @@ -36,9 +36,9 @@ chunks = split_text_file("transcript.txt", MAX_CHUNK_LENGTH) print("Number of chunks", len(chunks)) # Write chunks to file to refer to input vs output, separated by blank lines -with open("chunks" + str(MAX_CHUNK_LENGTH) + ".txt", "a") as f: +with open("chunks" + str(MAX_CHUNK_LENGTH) + ".txt", "a", encoding="utf-8") as file: for c in chunks: - f.write(c + "\n\n") + file.write(c + "\n\n") # If we want to run only a certain model, type the option while running # ex. 
python incsum.py 1 => will run approach 1 @@ -78,9 +78,9 @@ if index == "1" or index is None: summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) summaries.append(summary) - with open("bart-summaries.txt", "a") as f: + with open("bart-summaries.txt", "a", encoding="utf-8") as file: for summary in summaries: - f.write(summary + "\n\n") + file.write(summary + "\n\n") # Approach 2 if index == "2" or index is None: @@ -114,8 +114,8 @@ if index == "2" or index is None: summary_ids = output[0, input_length:] summary = tokenizer.decode(summary_ids, skip_special_tokens=True) summaries.append(summary) - with open("gptneo1.3B-summaries.txt", "a") as f: - f.write(summary + "\n\n") + with open("gptneo1.3B-summaries.txt", "a", encoding="utf-8") as file: + file.write(summary + "\n\n") # Approach 3 if index == "3" or index is None: @@ -152,6 +152,6 @@ if index == "3" or index is None: skip_special_tokens=True) summaries.append(summary) - with open("mpt-7b-summaries.txt", "a") as f: + with open("mpt-7b-summaries.txt", "a", encoding="utf-8") as file: for summary in summaries: - f.write(summary + "\n\n") + file.write(summary + "\n\n") diff --git a/trials/whisper-jax/whisjax.py b/trials/whisper-jax/whisjax.py index 2926fce0..fb4f5e1f 100644 --- a/trials/whisper-jax/whisjax.py +++ b/trials/whisper-jax/whisjax.py @@ -19,15 +19,15 @@ import yt_dlp as youtube_dl from whisper_jax import FlaxWhisperPipline from ...utils.file_utils import download_files, upload_files -from ...utils.log_utils import logger -from ...utils.run_utils import config +from ...utils.log_utils import LOGGER +from ...utils.run_utils import CONFIG from ...utils.text_utils import post_process_transcription, summarize from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud nltk.download('punkt', quiet=True) nltk.download('stopwords', quiet=True) -WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_MODEL_SIZE"] +WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_MODEL_SIZE"] NOW = datetime.now() if not os.path.exists('../../artefacts'): @@ -75,7 +75,7 @@ def main(): # Download the lowest resolution YouTube video # (since we're just interested in the audio). # It will be saved to the current directory. 
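# The surrounding ydl_opts dict is elided from this hunk; a hedged
# sketch of options that would yield the ../artefacts/audio.mp3 file the
# code below expects — the exact keys the repository uses are an
# assumption, only ydl.download([args.location]) appears in the patch:
#
#     ydl_opts = {
#         "format": "worst",                       # lowest-resolution stream
#         "outtmpl": "../artefacts/audio.%(ext)s",
#         "postprocessors": [{"key": "FFmpegExtractAudio",
#                             "preferredcodec": "mp3"}],
#     }
#     with youtube_dl.YoutubeDL(ydl_opts) as ydl:
#         ydl.download([args.location])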
- logger.info("Downloading YouTube video at url: " + args.location) + LOGGER.info("Downloading YouTube video at url: " + args.location) # Create options for the download ydl_opts = { @@ -93,12 +93,12 @@ def main(): ydl.download([args.location]) media_file = "../artefacts/audio.mp3" - logger.info("Saved downloaded YouTube video to: " + media_file) + LOGGER.info("Saved downloaded YouTube video to: " + media_file) else: # XXX - Download file using urllib, check if file is # audio/video using python-magic - logger.info(f"Downloading file at url: {args.location}") - logger.info(" XXX - This method hasn't been implemented yet.") + LOGGER.info(f"Downloading file at url: {args.location}") + LOGGER.info(" XXX - This method hasn't been implemented yet.") elif url.scheme == '': media_file = url.path # If file is not present locally, take it from S3 bucket @@ -119,7 +119,7 @@ def main(): audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name video.audio.write_audiofile(audio_filename, logger=None) - logger.info(f"Extracting audio to: {audio_filename}") + LOGGER.info(f"Extracting audio to: {audio_filename}") # Handle audio only file except Exception: audio = moviepy.editor.AudioFileClip(media_file) @@ -129,14 +129,14 @@ def main(): else: audio_filename = media_file - logger.info("Finished extracting audio") - logger.info("Transcribing") + LOGGER.info("Finished extracting audio") + LOGGER.info("Transcribing") # Convert the audio to text using the OpenAI Whisper model pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE, dtype=jnp.float16, batch_size=16) whisper_result = pipeline(audio_filename, return_timestamps=True) - logger.info("Finished transcribing file") + LOGGER.info("Finished transcribing file") whisper_result = post_process_transcription(whisper_result) @@ -153,10 +153,10 @@ def main(): "w") as transcript_file_timestamps: transcript_file_timestamps.write(str(whisper_result)) - logger.info("Creating word cloud") + LOGGER.info("Creating word cloud") create_wordcloud(NOW) - logger.info("Performing talk-diff and talk-diff visualization") + LOGGER.info("Performing talk-diff and talk-diff visualization") create_talk_diff_scatter_viz(NOW) # S3 : Push artefacts to S3 bucket @@ -172,7 +172,7 @@ def main(): summarize(transcript_text, NOW, False, False) - logger.info("Summarization completed") + LOGGER.info("Summarization completed") # Summarization takes a lot of time, so do this separately at the end files_to_upload = [prefix + "summary_" + suffix + ".txt"] diff --git a/trials/whisper-jax/whisjax_realtime.py b/trials/whisper-jax/whisjax_realtime.py index 5f269c18..ec822854 100644 --- a/trials/whisper-jax/whisjax_realtime.py +++ b/trials/whisper-jax/whisjax_realtime.py @@ -11,12 +11,12 @@ from termcolor import colored from whisper_jax import FlaxWhisperPipline from ...utils.file_utils import upload_files -from ...utils.log_utils import logger -from ...utils.run_utils import config +from ...utils.log_utils import LOGGER +from ...utils.run_utils import CONFIG from ...utils.text_utils import post_process_transcription, summarize from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud -WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_MODEL_SIZE"] +WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_MODEL_SIZE"] FRAMES_PER_BUFFER = 8000 FORMAT = pyaudio.paInt16 @@ -31,7 +31,7 @@ def main(): AUDIO_DEVICE_ID = -1 for i in range(p.get_device_count()): if p.get_device_info_by_index(i)["name"] == \ - 
config["AUDIO"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]: + CONFIG["AUDIO"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]: AUDIO_DEVICE_ID = i audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID) stream = p.open( @@ -44,7 +44,7 @@ def main(): ) pipeline = FlaxWhisperPipline("openai/whisper-" + - config["WHISPER"]["WHISPER_REAL_TIME_MODEL_SIZE"], + CONFIG["WHISPER"]["WHISPER_REAL_TIME_MODEL_SIZE"], dtype=jnp.float16, batch_size=16) @@ -106,23 +106,26 @@ def main(): " | Transcribed duration: " + str(duration), "yellow")) - except Exception as e: - print(e) + except Exception as exception: + print(str(exception)) finally: - with open("real_time_transcript_" + - NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f: - f.write(transcription) + with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + + ".txt", "w", encoding="utf-8") as file: + file.write(transcription) + with open("real_time_transcript_with_timestamp_" + - NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f: + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w", + encoding="utf-8") as file: transcript_with_timestamp["text"] = transcription - f.write(str(transcript_with_timestamp)) + file.write(str(transcript_with_timestamp)) - transcript_with_timestamp = post_process_transcription(transcript_with_timestamp) + transcript_with_timestamp = \ + post_process_transcription(transcript_with_timestamp) - logger.info("Creating word cloud") + LOGGER.info("Creating word cloud") create_wordcloud(NOW, True) - logger.info("Performing talk-diff and talk-diff visualization") + LOGGER.info("Performing talk-diff and talk-diff visualization") create_talk_diff_scatter_viz(NOW, True) # S3 : Push artefacts to S3 bucket @@ -137,7 +140,7 @@ def main(): summarize(transcript_with_timestamp["text"], NOW, True, True) - logger.info("Summarization completed") + LOGGER.info("Summarization completed") # Summarization takes a lot of time, so do this separately at the end files_to_upload = ["real_time_summary_" + suffix + ".txt"] diff --git a/utils/file_utils.py b/utils/file_utils.py index 596bfe7f..db294c6e 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -1,16 +1,21 @@ +""" +Utility file for file handling related functions, including file downloads and +uploads to cloud storage +""" + import sys import boto3 import botocore -from .log_utils import logger -from .run_utils import config +from .log_utils import LOGGER +from .run_utils import CONFIG -BUCKET_NAME = config["AWS"]["BUCKET_NAME"] +BUCKET_NAME = CONFIG["AWS"]["BUCKET_NAME"] s3 = boto3.client('s3', - aws_access_key_id=config["AWS"]["AWS_ACCESS_KEY"], - aws_secret_access_key=config["AWS"]["AWS_SECRET_KEY"]) + aws_access_key_id=CONFIG["AWS"]["AWS_ACCESS_KEY"], + aws_secret_access_key=CONFIG["AWS"]["AWS_SECRET_KEY"]) def upload_files(files_to_upload): @@ -19,12 +24,12 @@ def upload_files(files_to_upload): :param files_to_upload: List of files to upload :return: None """ - for KEY in files_to_upload: - logger.info("Uploading file " + KEY) + for key in files_to_upload: + LOGGER.info("Uploading file " + key) try: - s3.upload_file(KEY, BUCKET_NAME, KEY) - except botocore.exceptions.ClientError as e: - print(e.response) + s3.upload_file(key, BUCKET_NAME, key) + except botocore.exceptions.ClientError as exception: + print(exception.response) def download_files(files_to_download): @@ -33,12 +38,12 @@ def download_files(files_to_download): :param files_to_download: List of files to download :return: None """ - for KEY in files_to_download: - logger.info("Downloading file " + KEY) + for 
key in files_to_download: + LOGGER.info("Downloading file " + key) try: - s3.download_file(BUCKET_NAME, KEY, KEY) - except botocore.exceptions.ClientError as e: - if e.response['Error']['Code'] == "404": + s3.download_file(BUCKET_NAME, key, key) + except botocore.exceptions.ClientError as exception: + if exception.response['Error']['Code'] == "404": print("The object does not exist.") else: raise diff --git a/utils/format_output.py b/utils/format_output.py index 4f026ce2..c46b90ba 100644 --- a/utils/format_output.py +++ b/utils/format_output.py @@ -1,13 +1,24 @@ +""" +Utility function to format the artefacts created during Reflector run +""" + import json -with open("../artefacts/meeting_titles_and_summaries.txt", "r") as f: +with open("../artefacts/meeting_titles_and_summaries.txt", "r", + encoding='utf-8') as f: outputs = f.read() outputs = json.loads(outputs) -transcript_file = open("../artefacts/meeting_transcript.txt", "a") -title_desc_file = open("../artefacts/meeting_title_description.txt", "a") -summary_file = open("../artefacts/meeting_summary.txt", "a") +transcript_file = open("../artefacts/meeting_transcript.txt", + "a", + encoding='utf-8') +title_desc_file = open("../artefacts/meeting_title_description.txt", + "a", + encoding='utf-8') +summary_file = open("../artefacts/meeting_summary.txt", + "a", + encoding='utf-8') for item in outputs["topics"]: transcript_file.write(item["transcript"]) diff --git a/utils/log_utils.py b/utils/log_utils.py index f665f5da..84cbe3fe 100644 --- a/utils/log_utils.py +++ b/utils/log_utils.py @@ -1,7 +1,15 @@ +""" +Utility file for logging +""" + import loguru class SingletonLogger: + """ + Use Singleton design pattern to create a logger object and share it + across the entire project + """ __instance = None @staticmethod @@ -15,4 +23,4 @@ class SingletonLogger: return SingletonLogger.__instance -logger = SingletonLogger.get_logger() +LOGGER = SingletonLogger.get_logger() diff --git a/utils/run_utils.py b/utils/run_utils.py index bb2b6348..2271fc19 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -1,3 +1,7 @@ +""" +Utility file for server side asynchronous task running and config objects +""" + import asyncio import configparser import contextlib @@ -7,6 +11,9 @@ from typing import ContextManager, Generic, TypeVar class ReflectorConfig: + """ + Create a single config object to share across the project + """ __config = None @staticmethod @@ -17,7 +24,7 @@ class ReflectorConfig: return ReflectorConfig.__config -config = ReflectorConfig.get_config() +CONFIG = ReflectorConfig.get_config() def run_in_executor(func, *args, executor=None, **kwargs): diff --git a/utils/text_utils.py b/utils/text_utils.py index 9ce7b1c1..8fb5ba10 100644 --- a/utils/text_utils.py +++ b/utils/text_utils.py @@ -1,3 +1,7 @@ +""" +Utility file for all text processing related functionalities +""" + import nltk import torch from nltk.corpus import stopwords @@ -6,8 +10,8 @@ from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from transformers import BartForConditionalGeneration, BartTokenizer -from log_utils import logger -from run_utils import config +from log_utils import LOGGER +from run_utils import CONFIG nltk.download('punkt', quiet=True) @@ -32,6 +36,12 @@ def compute_similarity(sent1, sent2): def remove_almost_alike_sentences(sentences, threshold=0.7): + """ + Filter sentences that are similar beyond a set threshold + :param sentences: + :param threshold: + :return: + """ num_sentences = 
len(sentences) removed_indices = set() @@ -62,6 +72,11 @@ def remove_almost_alike_sentences(sentences, threshold=0.7): def remove_outright_duplicate_sentences_from_chunk(chunk): + """ + Remove repetitive sentences + :param chunk: + :return: + """ chunk_text = chunk["text"] sentences = nltk.sent_tokenize(chunk_text) nonduplicate_sentences = list(dict.fromkeys(sentences)) @@ -69,6 +84,12 @@ def remove_outright_duplicate_sentences_from_chunk(chunk): def remove_whisper_repetitive_hallucination(nonduplicate_sentences): + """ + Remove sentences that are repeated as a result of Whisper + hallucinations + :param nonduplicate_sentences: + :return: + """ chunk_sentences = [] for sent in nonduplicate_sentences: @@ -91,6 +112,11 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences): def post_process_transcription(whisper_result): + """ + Parent function to perform post-processing on the transcription result + :param whisper_result: + :return: + """ transcript_text = "" for chunk in whisper_result["chunks"]: nonduplicate_sentences = \ @@ -121,9 +147,9 @@ def summarize_chunks(chunks, tokenizer, model): with torch.no_grad(): summary_ids = \ model.generate(input_ids, - num_beams=int(config["SUMMARIZER"]["BEAM_SIZE"]), + num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]), length_penalty=2.0, - max_length=int(config["SUMMARIZER"]["MAX_LENGTH"]), + max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]), early_stopping=True) summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) @@ -132,7 +158,7 @@ def summarize_chunks(chunks, tokenizer, model): def chunk_text(text, - max_chunk_length=int(config["SUMMARIZER"]["MAX_CHUNK_LENGTH"])): + max_chunk_length=int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])): """ Split text into smaller chunks. :param text: Text to be chunked @@ -154,14 +180,22 @@ def chunk_text(text, def summarize(transcript_text, timestamp, real_time=False, - chunk_summarize=config["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]): + chunk_summarize=CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]): + """ + Summarize the given text either as a whole or as chunks as needed + :param transcript_text: + :param timestamp: + :param real_time: + :param chunk_summarize: + :return: + """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - summary_model = config["SUMMARIZER"]["SUMMARY_MODEL"] + summary_model = CONFIG["SUMMARIZER"]["SUMMARY_MODEL"] if not summary_model: summary_model = "facebook/bart-large-cnn" # Summarize the generated transcript using the BART model - logger.info(f"Loading BART model: {summary_model}") + LOGGER.info(f"Loading BART model: {summary_model}") tokenizer = BartTokenizer.from_pretrained(summary_model) model = BartForConditionalGeneration.from_pretrained(summary_model) model = model.to(device) @@ -171,7 +205,7 @@ def summarize(transcript_text, timestamp, output_file = "real_time_" + output_file if chunk_summarize != "YES": - max_length = int(config["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"]) + max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"]) inputs = tokenizer. 
\ batch_encode_plus([transcript_text], truncation=True, padding='longest', @@ -180,8 +214,8 @@ def summarize(transcript_text, timestamp, inputs = inputs.to(device) with torch.no_grad(): - num_beans = int(config["SUMMARIZER"]["BEAM_SIZE"]) - max_length = int(config["SUMMARIZER"]["MAX_LENGTH"]) + num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]) + max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]) summaries = model.generate(inputs['input_ids'], num_beams=num_beans, length_penalty=2.0, @@ -194,16 +228,16 @@ def summarize(transcript_text, timestamp, clean_up_tokenization_spaces=False) for summary in summaries] summary = " ".join(decoded_summaries) - with open("./artefacts/" + output_file, 'w') as f: - f.write(summary.strip() + "\n") + with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file: + file.write(summary.strip() + "\n") else: - logger.info("Breaking transcript into smaller chunks") + LOGGER.info("Breaking transcript into smaller chunks") chunks = chunk_text(transcript_text) - logger.info(f"Transcript broken into {len(chunks)} " + LOGGER.info(f"Transcript broken into {len(chunks)} " f"chunks of at most 500 words") - logger.info(f"Writing summary text to: {output_file}") + LOGGER.info(f"Writing summary text to: {output_file}") with open(output_file, 'w') as f: summaries = summarize_chunks(chunks, tokenizer, model) for summary in summaries: diff --git a/utils/viz_utils.py b/utils/viz_utils.py index d7debd0c..498c7cf7 100644 --- a/utils/viz_utils.py +++ b/utils/viz_utils.py @@ -1,3 +1,7 @@ +""" +Utility file for all visualization related functions +""" + import ast import collections import os @@ -81,8 +85,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): else: filename = "./artefacts/transcript_with_timestamp_" + \ timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" - with open(filename) as f: - transcription_timestamp_text = f.read() + with open(filename) as file: + transcription_timestamp_text = file.read() res = ast.literal_eval(transcription_timestamp_text) chunks = res["chunks"]
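A closing note on the transcript-ordering change in patch 3: server.py now takes the oldest pending transcript with next(iter(sorted_transcripts)) instead of sorted_transcripts.keys()[0]. Both expressions yield the smallest key of a sortedcontainers SortedDict (transcripts are keyed by the start time of their first audio frame), but the iterator form works for any mapping and avoids indexing a keys view. A minimal standalone sketch, with illustrative sample timestamps:

    from sortedcontainers import SortedDict

    sorted_transcripts = SortedDict()
    sorted_transcripts[3.84] = {"cmd": "SHOW_TRANSCRIPTION", "text": "later"}
    sorted_transcripts[1.92] = {"cmd": "SHOW_TRANSCRIPTION", "text": "earlier"}

    least_time = next(iter(sorted_transcripts))   # 1.92, the oldest frame time
    message = sorted_transcripts.pop(least_time)  # delivered to the client first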