From b892fc056267af26606e886e27355049b69a7b8f Mon Sep 17 00:00:00 2001
From: Gokul Mohanarangan
Date: Tue, 25 Jul 2023 22:55:17 +0530
Subject: [PATCH] add data classes and typing

---
 reflector_dataclasses.py | 119 +++++++++++++++++++++++++++++++++++++++
 server.py                |  98 +++++++++++++-------------------
 2 files changed, 158 insertions(+), 59 deletions(-)
 create mode 100644 reflector_dataclasses.py

diff --git a/reflector_dataclasses.py b/reflector_dataclasses.py
new file mode 100644
index 00000000..cfc4da6a
--- /dev/null
+++ b/reflector_dataclasses.py
@@ -0,0 +1,119 @@
+import datetime
+from dataclasses import dataclass
+from typing import List
+
+import av
+
+
+@dataclass
+class TitleSummaryInput:
+    input_text: str
+    transcribed_time: float
+    prompt: str
+    data: dict
+
+    def __init__(self, transcribed_time, input_text=""):
+        self.input_text = input_text
+        self.transcribed_time = transcribed_time
+        self.prompt = f"""
+        ### Human:
+        Create a JSON object as response. The JSON object must have 2 fields:
+        i) title and ii) summary. For the title field,generate a short title
+        for the given text. For the summary field, summarize the given text
+        in three sentences.
+
+        {self.input_text}
+
+        ### Assistant:
+        """
+        self.data = {"prompt": self.prompt}
+        self.headers = {"Content-Type": "application/json"}
+
+
+@dataclass
+class IncrementalResponse:
+    title: str
+    description: str
+    transcript: str
+
+    def __init__(self, title, desc, transcript):
+        self.title = title
+        self.description = desc
+        self.transcript = transcript
+
+
+@dataclass
+class TitleSummaryOutput:
+    cmd: str
+    topics: List[IncrementalResponse]
+
+    def __init__(self, inc_responses):
+        self.cmd, self.topics = "UPDATE_TOPICS", inc_responses
+
+    def get_response(self):
+        return {
+            "cmd": self.cmd,
+            "topics": self.topics
+        }
+
+
+@dataclass
+class ParseLLMResult:
+    description: str
+    transcript: str
+    timestamp: str
+
+    def __init__(self, param: TitleSummaryInput, output: dict):
+        self.transcript, self.title = param.input_text, output.pop("title", "")
+        self.description = output.pop("summary")
+        self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))
+
+    def get_result(self):
+        return {
+            "title": self.title, "description": self.description,
+            "transcript": self.transcript,
+            "timestamp": self.timestamp
+        }
+
+
+@dataclass
+class TranscriptionInput:
+    frames: List[av.audio.frame.AudioFrame]
+
+    def __init__(self, frames):
+        self.frames = frames
+
+
+@dataclass
+class TranscriptionOutput:
+    cmd: str
+    result_text: str
+
+    def __init__(self, result_text):
+        self.cmd = "SHOW_TRANSCRIPTION"
+        self.result_text = result_text
+
+    def get_response(self):
+        return {
+            "cmd": self.cmd,
+            "text": self.result_text
+        }
+
+
+@dataclass
+class FinalSummaryResponse:
+    cmd: str
+    final_summary: str
+    duration: str
+
+    def __init__(self, final_summary, time):
+        self.duration = str(datetime.timedelta(seconds=round(time)))
+        self.final_summary = final_summary
+        self.cmd = "DISPLAY_FINAL_SUMMARY"
+
+    def get_response(self):
+        return {
+            "cmd": self.cmd,
+            "duration": self.duration,
+            "summary": self.final_summary
+        }
diff --git a/server.py b/server.py
index 55066eef..1c556a93 100644
--- a/server.py
+++ b/server.py
@@ -6,6 +6,7 @@ import os
 import uuid
 import wave
 from concurrent.futures import ThreadPoolExecutor
+from typing import Optional
 
 import aiohttp_cors
 import requests
@@ -17,7 +18,9 @@ from faster_whisper import WhisperModel
 from loguru import logger
 from sortedcontainers import SortedDict
 
-from utils.run_utils import run_in_executor, config
+from reflector_dataclasses import FinalSummaryResponse, ParseLLMResult, TitleSummaryInput, TitleSummaryOutput, \
+    TranscriptionInput, TranscriptionOutput
+from utils.run_utils import config, run_in_executor
 
 pcs = set()
 relay = MediaRelay()
@@ -43,49 +46,30 @@ blacklisted_messages = [" Thank you.",
                         " See you next time!",
                         " And that's what I'm talking about."]
 
 
-def get_title_and_summary(llm_input_text, last_timestamp):
+def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Optional[dict]:
+    try:
+        output = json.loads(response.json()["results"][0]["text"])
+        return ParseLLMResult(param, output).get_result()
+    except Exception as e:
+        logger.info("Exception" + str(e))
+        return None
+
+
+def get_title_and_summary(param: TitleSummaryInput) -> Optional[dict]:
     logger.info("Generating title and summary")
-    # output = llm.generate(prompt)
-
-    # Use monadical-ml to fire this query to an LLM and get result
-    headers = {
-        "Content-Type": "application/json"
-    }
-
-    prompt = f"""
-    ### Human:
-    Create a JSON object as response. The JSON object must have 2 fields:
-    i) title and ii) summary. For the title field,generate a short title
-    for the given text. For the summary field, summarize the given text
-    in three sentences.
-
-    {llm_input_text}
-
-    ### Assistant:
-    """
-
-    data = {
-        "prompt": prompt
-    }
 
     # TODO : Handle unexpected output formats from the model
     try:
-        response = requests.post(LLM_URL, headers=headers, json=data)
-        output = json.loads(response.json()["results"][0]["text"])
-        output["description"] = output.pop("summary")
-        output["transcript"] = llm_input_text
-        output["timestamp"] = \
-            str(datetime.timedelta(seconds=round(last_timestamp)))
-        incremental_responses.append(output)
-        result = {
-            "cmd": "UPDATE_TOPICS",
-            "topics": incremental_responses,
-        }
-
+        response = requests.post(LLM_URL,
+                                 headers=param.headers,
+                                 json=param.data)
+        output = parse_llm_output(param, response)
+        if output:
+            incremental_responses.append(output)
+        return TitleSummaryOutput(incremental_responses).get_response()
     except Exception as e:
         logger.info("Exception" + str(e))
-        result = None
-
-    return result
+    return None
 
 
 def channel_log(channel, t, message):
@@ -123,11 +107,11 @@ def channel_send_transcript(channel):
         pass
 
 
-def get_transcription(frames):
+def get_transcription(input_frames: TranscriptionInput) -> dict:
     logger.info("Transcribing..")
-    sorted_transcripts[frames[0].time] = None
+    sorted_transcripts[input_frames.frames[0].time] = None
 
-    # TODO:
+    # TODO: Find cleaner way, watch "no transcription" issue below
     # Passing IO objects instead of temporary files throws an error
     # Passing ndarrays (typecasted with float) does not give any
     # transcription. Refer issue,
@@ -138,7 +122,7 @@
     wf.setframerate(RATE)
     wf.setsampwidth(2)
 
-    for frame in frames:
+    for frame in input_frames.frames:
         wf.writeframes(b"".join(frame.to_ndarray()))
 
     wf.close()
@@ -173,30 +157,23 @@
         logger.info("Exception" + str(e))
         pass
 
-    result = {
-        "cmd": "SHOW_TRANSCRIPTION",
-        "text": result_text
-    }
-    sorted_transcripts[frames[0].time] = result
+    result = TranscriptionOutput(result_text).get_response()
+    sorted_transcripts[input_frames.frames[0].time] = result
 
     return result
 
 
-def get_final_summary_response():
+def get_final_summary_response() -> dict:
     final_summary = ""
 
     # Collate inc summaries
     for topic in incremental_responses:
         final_summary += topic["description"]
 
-    response = {
-        "cmd": "DISPLAY_FINAL_SUMMARY",
-        "duration": str(datetime.timedelta(
-            seconds=round(last_transcribed_time))),
-        "summary": final_summary
-    }
+    response = FinalSummaryResponse(final_summary, last_transcribed_time).get_response()
 
     with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f:
         f.write(json.dumps(incremental_responses))
 
+    return response
@@ -218,7 +195,9 @@
 
         if local_frames := audio_buffer.read_many(256 * 960, partial=False):
             whisper_result = run_in_executor(
-                get_transcription, local_frames, executor=executor
+                get_transcription,
+                TranscriptionInput(local_frames),
+                executor=executor
             )
             whisper_result.add_done_callback(
                 lambda f: channel_send_transcript(data_channel)
@@ -226,12 +205,13 @@
                 )
                 else None
             )
 
-            if len(transcription_text) > 750:
+            if len(transcription_text) > 25:
                 llm_input_text = transcription_text
                 transcription_text = ""
+                param = TitleSummaryInput(input_text=llm_input_text,
+                                          transcribed_time=last_transcribed_time)
                 llm_result = run_in_executor(get_title_and_summary,
-                                             llm_input_text,
-                                             last_transcribed_time,
+                                             param,
                                              executor=executor)
                 llm_result.add_done_callback(
lambda f: channel_send_increment(data_channel, @@ -332,4 +312,4 @@ if __name__ == "__main__": offer_resource = cors.add(app.router.add_resource("/offer")) cors.add(offer_resource.add_route("POST", offer)) app.on_shutdown.append(on_shutdown) - web.run_app(app, access_log=None, host=args.host, port=args.port) + web.run_app(app, access_log=None, host=args.host, port=args.port)