diff --git a/server/reflector_dataclasses.py b/server/reflector_dataclasses.py
index 459f5fd0..c417b857 100644
--- a/server/reflector_dataclasses.py
+++ b/server/reflector_dataclasses.py
@@ -6,6 +6,7 @@ the input and output parameters of functions
 import datetime
-from dataclasses import dataclass
-from typing import List
+from dataclasses import dataclass, field
+from typing import Any, List
+from sortedcontainers import SortedDict
 
 import av
 
@@ -184,3 +185,15 @@ class BlackListedMessages:
     messages = [" Thank you.", " See you next time!",
                 " Thank you for watching!", " Bye!",
                 " And that's what I'm talking about."]
+
+
+@dataclass
+class TranscriptionContext:
+    """
+    Mutable transcription state for a single peer connection
+    """
+    transcription_text: str = ""
+    last_transcribed_time: float = 0.0
+    incremental_responses: List[IncrementalResult] = field(default_factory=list)
+    sorted_transcripts: SortedDict = field(default_factory=SortedDict)
+    data_channel: Any = None  # assigned when the remote party opens a channel
diff --git a/server/server.py b/server/server.py
index b8154fbd..c72e28b6 100644
--- a/server/server.py
+++ b/server/server.py
@@ -15,18 +15,16 @@ from aiohttp import web
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
 from aiortc.contrib.media import MediaRelay
 from faster_whisper import WhisperModel
-from sortedcontainers import SortedDict
 
-from reflector_dataclasses import BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput, \
-    TitleSummaryOutput, TranscriptionInput, TranscriptionOutput
+from reflector_dataclasses import (
+    BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput,
+    TitleSummaryOutput, TranscriptionInput, TranscriptionOutput, TranscriptionContext)
 from utils.log_utils import LOGGER
 from utils.run_utils import CONFIG, run_in_executor, SECRETS
 
 # WebRTC components
 pcs = set()
 relay = MediaRelay()
-data_channel = None
-audio_buffer = av.AudioFifo()
 executor = ThreadPoolExecutor()
 
 # Transcription model
@@ -37,22 +35,16 @@ model = WhisperModel("tiny", device="cpu",
 # Audio configurations
 CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"])
 RATE = int(CONFIG["AUDIO"]["SAMPLING_RATE"])
-
-# Global vars
-transcription_text = ""
-last_transcribed_time = 0.0
+AUDIO_BUFFER_SIZE = 256 * 960
 
 # LLM
-LLM_MACHINE_IP = SECRETS["LLM"]["LLM_MACHINE_IP"]
-LLM_MACHINE_PORT = SECRETS["LLM"]["LLM_MACHINE_PORT"]
-LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
-
-# Topic and summary responses
-incremental_responses = []
-
-# To synchronize the thread results before returning to the client
-sorted_transcripts = SortedDict()
-
+LLM_URL = os.environ.get("LLM_URL")
+if LLM_URL:
+    LOGGER.info(f"Using LLM from environment: {LLM_URL}")
+else:
+    LLM_MACHINE_IP = CONFIG["LLM"]["LLM_MACHINE_IP"]
+    LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
+    LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
 
 def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
     """
@@ -69,7 +61,7 @@ def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> U
     return None
 
 
-def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
+def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
     """
     From the input provided (transcript), query the LLM to generate
     topics and summaries
@@ -86,10 +78,10 @@ def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryO
         output = parse_llm_output(param, response)
         if output:
             result = output.get_result()
-            incremental_responses.append(result)
-            return TitleSummaryOutput(incremental_responses)
-    except Exception as e:
-        LOGGER.info("Exception" + str(e))
+            ctx.incremental_responses.append(result)
+            return TitleSummaryOutput(ctx.incremental_responses)
+    except Exception:
+        LOGGER.exception("Exception while generating title and summary")
     return None
 
 
@@ -127,32 +119,33 @@ def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummar
         channel.send(json.dumps(message))
 
 
-def channel_send_transcript(channel) -> NoReturn:
+def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn:
     """
     Send the transcription result via the data channel
     :param channel:
     :return:
     """
     # channel_log(channel, ">", message)
-    if channel:
-        try:
-            least_time = next(iter(sorted_transcripts))
-            message = sorted_transcripts[least_time].get_result()
-            if message:
-                del sorted_transcripts[least_time]
-                if message["text"] not in BlackListedMessages.messages:
-                    channel.send(json.dumps(message))
-            # Due to exceptions if one of the earlier batches can't return
-            # a transcript, we don't want to be stuck waiting for the result
-            # With the threshold size of 3, we pop the first(lost) element
-            else:
-                if len(sorted_transcripts) >= 3:
-                    del sorted_transcripts[least_time]
-        except Exception as exception:
-            LOGGER.info("Exception", str(exception))
+    if not ctx.data_channel:
+        return
+    try:
+        least_time = next(iter(ctx.sorted_transcripts))
+        message = ctx.sorted_transcripts[least_time].get_result()
+        if message:
+            del ctx.sorted_transcripts[least_time]
+            if message["text"] not in BlackListedMessages.messages:
+                ctx.data_channel.send(json.dumps(message))
+        # Due to exceptions if one of the earlier batches can't return
+        # a transcript, we don't want to be stuck waiting for the result
+        # With the threshold size of 3, we pop the first(lost) element
+        else:
+            if len(ctx.sorted_transcripts) >= 3:
+                del ctx.sorted_transcripts[least_time]
+    except Exception:
+        LOGGER.exception("Exception while sending transcript")
 
 
-def get_transcription(input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
+def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
     """
     From the collected audio frames create transcription by inferring from
     the chosen transcription model
@@ -160,7 +153,7 @@ def get_transcription(input_frames: TranscriptionInput) -> Union[None, Transcrip
     :return:
     """
     LOGGER.info("Transcribing..")
-    sorted_transcripts[input_frames.frames[0].time] = None
+    ctx.sorted_transcripts[input_frames.frames[0].time] = None
 
     # TODO: Find cleaner way, watch "no transcription" issue below
     # Passing IO objects instead of temporary files throws an error
@@ -200,19 +193,18 @@ def get_transcription(input_frames: TranscriptionInput) -> Union[None, Transcrip
                 end_time = 5.5
             duration += (end_time - start_time)
 
-        global last_transcribed_time, transcription_text
-        last_transcribed_time += duration
-        transcription_text += result_text
+        ctx.last_transcribed_time += duration
+        ctx.transcription_text += result_text
 
     except Exception as exception:
         LOGGER.info("Exception" + str(exception))
 
     result = TranscriptionOutput(result_text)
-    sorted_transcripts[input_frames.frames[0].time] = result
+    ctx.sorted_transcripts[input_frames.frames[0].time] = result
     return result
 
 
-def get_final_summary_response() -> FinalSummaryResult:
+def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
     """
     Collate the incremental summaries generated so far and return as the final
     summary
@@ -221,14 +213,14 @@ def get_final_summary_response() -> FinalSummaryResult:
     final_summary = ""
 
     # Collate inc summaries
-    for topic in incremental_responses:
+    for topic in ctx.incremental_responses:
         final_summary += topic["description"]
 
-    response = FinalSummaryResult(final_summary, last_transcribed_time)
+    response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)
 
     with open("./artefacts/meeting_titles_and_summaries.txt", "a",
               encoding="utf-8") as file:
-        file.write(json.dumps(incremental_responses))
+        file.write(json.dumps(ctx.incremental_responses))
 
     return response
 
@@ -240,37 +232,41 @@ class AudioStreamTrack(MediaStreamTrack):
 
     kind = "audio"
 
-    def __init__(self, track):
+    def __init__(self, ctx: TranscriptionContext, track):
         super().__init__()
+        self.ctx = ctx
         self.track = track
+        self.audio_buffer = av.AudioFifo()
 
     async def recv(self) -> av.audio.frame.AudioFrame:
-        global transcription_text
+        ctx = self.ctx
         frame = await self.track.recv()
-        audio_buffer.write(frame)
+        self.audio_buffer.write(frame)
 
-        if local_frames := audio_buffer.read_many(256 * 960, partial=False):
+        if local_frames := self.audio_buffer.read_many(AUDIO_BUFFER_SIZE, partial=False):
             whisper_result = run_in_executor(
                     get_transcription,
+                    ctx,
                     TranscriptionInput(local_frames),
                     executor=executor
             )
             whisper_result.add_done_callback(
-                    lambda f: channel_send_transcript(data_channel)
+                    lambda f: channel_send_transcript(ctx)
                     if f.result()
                     else None
             )
 
-            if len(transcription_text) > 25:
-                llm_input_text = transcription_text
-                transcription_text = ""
+            if len(ctx.transcription_text) > 25:
+                llm_input_text = ctx.transcription_text
+                ctx.transcription_text = ""
                 param = TitleSummaryInput(input_text=llm_input_text,
-                                          transcribed_time=last_transcribed_time)
+                                          transcribed_time=ctx.last_transcribed_time)
                 llm_result = run_in_executor(get_title_and_summary,
+                                             ctx,
                                              param,
                                              executor=executor)
                 llm_result.add_done_callback(
-                        lambda f: channel_send_increment(data_channel,
+                        lambda f: channel_send_increment(ctx.data_channel,
                                                          llm_result.result())
                         if f.result()
                         else None
@@ -287,6 +283,7 @@ async def offer(request: requests.Request) -> web.Response:
     params = await request.json()
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
 
+    ctx = TranscriptionContext()
     pc = RTCPeerConnection()
     pc_id = "PeerConnection(%s)" % uuid.uuid4()
     pcs.add(pc)
@@ -298,8 +295,7 @@ async def offer(request: requests.Request) -> web.Response:
 
     @pc.on("datachannel")
     def on_datachannel(channel) -> NoReturn:
-        global data_channel
-        data_channel = channel
+        ctx.data_channel = channel
         channel_log(channel, "-", "created by remote party")
 
         @channel.on("message")
@@ -308,7 +304,7 @@ async def offer(request: requests.Request) -> web.Response:
             if json.loads(message)["cmd"] == "STOP":
                 # Placeholder final summary
-                response = get_final_summary_response()
-                channel_send_increment(data_channel, response)
+                response = get_final_summary_response(ctx)
+                channel_send_increment(channel, response)
                 # To-do Add code to stop connection from server side here
                 # But have to handshake with client once
 
@@ -325,7 +321,7 @@ async def offer(request: requests.Request) -> web.Response:
     @pc.on("track")
     def on_track(track) -> NoReturn:
         log_info("Track " + track.kind + " received")
-        pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
+        pc.addTrack(AudioStreamTrack(ctx, relay.subscribe(track)))
 
     await pc.setRemoteDescription(offer)
 