diff --git a/server/reflector/models.py b/server/reflector/models.py deleted file mode 100644 index d1aaaa1e..00000000 --- a/server/reflector/models.py +++ /dev/null @@ -1,211 +0,0 @@ -""" -Collection of data classes for streamlining and rigidly structuring -the input and output parameters of functions -""" - -import datetime -from dataclasses import dataclass -from typing import List -from sortedcontainers import SortedDict - -import av - - -@dataclass -class TitleSummaryInput: - """ - Data class for the input to generate title and summaries. - The outcome will be used to send query to the LLM for processing. - """ - - input_text = str - transcribed_time = float - prompt = str - data = dict - - def __init__(self, transcribed_time, input_text=""): - self.input_text = input_text - self.transcribed_time = transcribed_time - self.prompt = f""" - ### Human: - Create a JSON object as response.The JSON object must have 2 fields: - i) title and ii) summary.For the title field,generate a short title - for the given text. For the summary field, summarize the given text - in three sentences. - - {self.input_text} - - ### Assistant: - """ - self.data = {"prompt": self.prompt} - self.headers = {"Content-Type": "application/json"} - - -@dataclass -class IncrementalResult: - """ - Data class for the result of generating one title and summaries. - Defines how a single "topic" looks like. - """ - - title = str - description = str - transcript = str - timestamp = str - - def __init__(self, title, desc, transcript, timestamp): - self.title = title - self.description = desc - self.transcript = transcript - self.timestamp = timestamp - - -@dataclass -class TitleSummaryOutput: - """ - Data class for the result of all generated titles and summaries. - The result will be sent back to the client - """ - - cmd = str - topics = List[IncrementalResult] - - def __init__(self, inc_responses): - self.topics = inc_responses - self.cmd = "UPDATE_TOPICS" - - def get_result(self) -> dict: - """ - Return the result dict for displaying the transcription - :return: - """ - return {"cmd": self.cmd, "topics": self.topics} - - -@dataclass -class ParseLLMResult: - """ - Data class to parse the result returned by the LLM while generating title - and summaries. The result will be sent back to the client. - """ - - title = str - description = str - transcript = str - timestamp = str - - def __init__(self, param: TitleSummaryInput, output: dict): - self.title = output["title"] - self.transcript = param.input_text - self.description = output.pop("summary") - self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time))) - - def get_result(self) -> dict: - """ - Return the result dict after parsing the response from LLM - :return: - """ - return { - "title": self.title, - "description": self.description, - "transcript": self.transcript, - "timestamp": self.timestamp, - } - - -@dataclass -class TranscriptionInput: - """ - Data class to define the input to the transcription function - AudioFrames -> input - """ - - frames = List[av.audio.frame.AudioFrame] - - def __init__(self, frames): - self.frames = frames - - -@dataclass -class TranscriptionOutput: - """ - Dataclass to define the result of the transcription function. - The result will be sent back to the client - """ - - cmd = str - result_text = str - - def __init__(self, result_text): - self.cmd = "SHOW_TRANSCRIPTION" - self.result_text = result_text - - def get_result(self) -> dict: - """ - Return the result dict for displaying the transcription - :return: - """ - return {"cmd": self.cmd, "text": self.result_text} - - -@dataclass -class FinalSummaryResult: - """ - Dataclass to define the result of the final summary function. - The result will be sent back to the client. - """ - - cmd = str - final_summary = str - duration = str - - def __init__(self, final_summary, time): - self.duration = str(datetime.timedelta(seconds=round(time))) - self.final_summary = final_summary - self.cmd = "DISPLAY_FINAL_SUMMARY" - - def get_result(self) -> dict: - """ - Return the result dict for displaying the final summary - :return: - """ - return { - "cmd": self.cmd, - "duration": self.duration, - "summary": self.final_summary, - } - - -class BlackListedMessages: - """ - Class to hold the blacklisted messages. These messages should be filtered - out and not sent back to the client as part of the transcription. - """ - - messages = [ - " Thank you.", - " See you next time!", - " Thank you for watching!", - " Bye!", - " And that's what I'm talking about.", - ] - - -@dataclass -class TranscriptionContext: - transcription_text: str - last_transcribed_time: float - incremental_responses: List[IncrementalResult] - sorted_transcripts: dict - data_channel: None # FIXME - logger: None - status: str - - def __init__(self, logger): - self.transcription_text = "" - self.last_transcribed_time = 0.0 - self.incremental_responses = [] - self.data_channel = None - self.sorted_transcripts = SortedDict() - self.status = "idle" - self.logger = logger diff --git a/server/reflector/server.py b/server/reflector/server.py deleted file mode 100644 index 8e28b583..00000000 --- a/server/reflector/server.py +++ /dev/null @@ -1,381 +0,0 @@ -import argparse -import asyncio -import datetime -import json -import os -import wave -import uuid -from concurrent.futures import ThreadPoolExecutor -from typing import NoReturn, Union - -import aiohttp_cors -import av -import requests -from aiohttp import web -from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription -from aiortc.contrib.media import MediaRelay -from faster_whisper import WhisperModel - -from reflector.models import ( - BlackListedMessages, - FinalSummaryResult, - ParseLLMResult, - TitleSummaryInput, - TitleSummaryOutput, - TranscriptionInput, - TranscriptionOutput, - TranscriptionContext, -) -from reflector.logger import logger -from reflector.utils.run_utils import run_in_executor -from reflector.settings import settings - -# WebRTC components -pcs = set() -relay = MediaRelay() -executor = ThreadPoolExecutor() - -# Transcription model -model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12) - -# LLM -LLM_URL = settings.LLM_URL -if not LLM_URL: - assert settings.LLM_BACKEND == "oobagooda" - LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate" -logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}") - - -def parse_llm_output( - param: TitleSummaryInput, response: requests.Response -) -> Union[None, ParseLLMResult]: - """ - Function to parse the LLM response - :param param: - :param response: - :return: - """ - try: - output = json.loads(response.json()["results"][0]["text"]) - return ParseLLMResult(param, output) - except Exception: - logger.exception("Exception while parsing LLM output") - return None - - -def get_title_and_summary( - ctx: TranscriptionContext, param: TitleSummaryInput -) -> Union[None, TitleSummaryOutput]: - """ - From the input provided (transcript), query the LLM to generate - topics and summaries - :param param: - :return: - """ - logger.info("Generating title and summary") - - # TODO : Handle unexpected output formats from the model - try: - response = requests.post(LLM_URL, headers=param.headers, json=param.data) - output = parse_llm_output(param, response) - if output: - result = output.get_result() - ctx.incremental_responses.append(result) - return TitleSummaryOutput(ctx.incremental_responses) - except Exception: - logger.exception("Exception while generating title and summary") - return None - - -def channel_send(channel, message: str) -> NoReturn: - """ - Send text messages via the data channel - :param channel: - :param message: - :return: - """ - if channel: - channel.send(message) - - -def channel_send_increment( - channel, param: Union[FinalSummaryResult, TitleSummaryOutput] -) -> NoReturn: - """ - Send the incremental topics and summaries via the data channel - :param channel: - :param param: - :return: - """ - if channel and param: - message = param.get_result() - channel.send(json.dumps(message)) - - -def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn: - """ - Send the transcription result via the data channel - :param channel: - :return: - """ - if not ctx.data_channel: - return - try: - least_time = next(iter(ctx.sorted_transcripts)) - message = ctx.sorted_transcripts[least_time].get_result() - if message: - del ctx.sorted_transcripts[least_time] - if message["text"] not in BlackListedMessages.messages: - ctx.data_channel.send(json.dumps(message)) - # Due to exceptions if one of the earlier batches can't return - # a transcript, we don't want to be stuck waiting for the result - # With the threshold size of 3, we pop the first(lost) element - else: - if len(ctx.sorted_transcripts) >= 3: - del ctx.sorted_transcripts[least_time] - except Exception: - logger.exception("Exception while sending transcript") - - -def get_transcription( - ctx: TranscriptionContext, input_frames: TranscriptionInput -) -> Union[None, TranscriptionOutput]: - """ - From the collected audio frames create transcription by inferring from - the chosen transcription model - :param input_frames: - :return: - """ - ctx.logger.info("Transcribing..") - ctx.sorted_transcripts[input_frames.frames[0].time] = None - - # TODO: Find cleaner way, watch "no transcription" issue below - # Passing IO objects instead of temporary files throws an error - # Passing ndarray (type casted with float) does not give any - # transcription. Refer issue, - # https://github.com/guillaumekln/faster-whisper/issues/369 - audio_file = "test" + str(datetime.datetime.now()) - wf = wave.open(audio_file, "wb") - wf.setnchannels(settings.AUDIO_CHANNELS) - wf.setframerate(settings.AUDIO_SAMPLING_RATE) - wf.setsampwidth(settings.AUDIO_SAMPLING_WIDTH) - - for frame in input_frames.frames: - wf.writeframes(b"".join(frame.to_ndarray())) - wf.close() - - result_text = "" - - try: - segments, _ = model.transcribe( - audio_file, - language="en", - beam_size=5, - vad_filter=True, - vad_parameters={"min_silence_duration_ms": 500}, - ) - os.remove(audio_file) - segments = list(segments) - result_text = "" - duration = 0.0 - for segment in segments: - result_text += segment.text - start_time = segment.start - end_time = segment.end - if not segment.start: - start_time = 0.0 - if not segment.end: - end_time = 5.5 - duration += end_time - start_time - - ctx.last_transcribed_time += duration - ctx.transcription_text += result_text - - except Exception: - logger.exception("Exception while transcribing") - - result = TranscriptionOutput(result_text) - ctx.sorted_transcripts[input_frames.frames[0].time] = result - return result - - -def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult: - """ - Collate the incremental summaries generated so far and return as the final - summary - :return: - """ - final_summary = "" - - # Collate inc summaries - for topic in ctx.incremental_responses: - final_summary += topic["description"] - - response = FinalSummaryResult(final_summary, ctx.last_transcribed_time) - - with open( - "./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8" - ) as file: - file.write(json.dumps(ctx.incremental_responses)) - - return response - - -class AudioStreamTrack(MediaStreamTrack): - """ - An audio stream track. - """ - - kind = "audio" - - def __init__(self, ctx: TranscriptionContext, track): - super().__init__() - self.ctx = ctx - self.track = track - self.audio_buffer = av.AudioFifo() - - async def recv(self) -> av.audio.frame.AudioFrame: - ctx = self.ctx - frame = await self.track.recv() - self.audio_buffer.write(frame) - - if local_frames := self.audio_buffer.read_many( - settings.AUDIO_BUFFER_SIZE, partial=False - ): - whisper_result = run_in_executor( - get_transcription, - ctx, - TranscriptionInput(local_frames), - executor=executor, - ) - whisper_result.add_done_callback( - lambda f: channel_send_transcript(ctx) if f.result() else None - ) - - if len(ctx.transcription_text) > 25: - llm_input_text = ctx.transcription_text - ctx.transcription_text = "" - param = TitleSummaryInput( - input_text=llm_input_text, transcribed_time=ctx.last_transcribed_time - ) - llm_result = run_in_executor( - get_title_and_summary, ctx, param, executor=executor - ) - llm_result.add_done_callback( - lambda f: channel_send_increment(ctx.data_channel, llm_result.result()) - if f.result() - else None - ) - return frame - - -async def offer(request: requests.Request) -> web.Response: - """ - Establish the WebRTC connection with the client - :param request: - :return: - """ - params = await request.json() - offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) - - # client identification - peername = request.transport.get_extra_info("peername") - if peername is not None: - clientid = f"{peername[0]}:{peername[1]}" - else: - clientid = uuid.uuid4() - - # create a context for the whole rtc transaction - # add a customised logger to the context - ctx = TranscriptionContext(logger=logger.bind(client=clientid)) - - # handle RTC peer connection - pc = RTCPeerConnection() - pcs.add(pc) - - @pc.on("datachannel") - def on_datachannel(channel) -> NoReturn: - ctx.data_channel = channel - ctx.logger = ctx.logger.bind(channel=channel.label) - ctx.logger.info("Channel created by remote party") - - @channel.on("message") - def on_message(message: str) -> NoReturn: - ctx.logger.info(f"Message: {message}") - if json.loads(message)["cmd"] == "STOP": - # Placeholder final summary - response = get_final_summary_response() - channel_send_increment(channel, response) - # To-do Add code to stop connection from server side here - # But have to handshake with client once - - if isinstance(message, str) and message.startswith("ping"): - channel_send(channel, "pong" + message[4:]) - - @pc.on("connectionstatechange") - async def on_connectionstatechange() -> NoReturn: - ctx.logger.info(f"Connection state changed: {pc.connectionState}") - if pc.connectionState == "failed": - await pc.close() - pcs.discard(pc) - - @pc.on("track") - def on_track(track) -> NoReturn: - ctx.logger.info(f"Track {track.kind} received") - pc.addTrack(AudioStreamTrack(ctx, relay.subscribe(track))) - - await pc.setRemoteDescription(offer) - - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - return web.Response( - content_type="application/json", - text=json.dumps( - {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} - ), - ) - - -async def on_shutdown(application: web.Application) -> NoReturn: - """ - On shutdown, the coroutines that shutdown client connections are - executed - :param application: - :return: - """ - coroutines = [pc.close() for pc in pcs] - await asyncio.gather(*coroutines) - pcs.clear() - - -def create_app() -> web.Application: - """ - Create the web application - """ - app = web.Application() - cors = aiohttp_cors.setup( - app, - defaults={ - "*": aiohttp_cors.ResourceOptions( - allow_credentials=True, expose_headers="*", allow_headers="*" - ) - }, - ) - - offer_resource = cors.add(app.router.add_resource("/offer")) - cors.add(offer_resource.add_route("POST", offer)) - app.on_shutdown.append(on_shutdown) - return app - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="WebRTC based server for Reflector") - parser.add_argument( - "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=1250, help="Server port (def: 1250)" - ) - args = parser.parse_args() - app = create_app() - web.run_app(app, access_log=None, host=args.host, port=args.port) diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index c0944a82..f28eb021 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -2,7 +2,6 @@ import asyncio from fastapi import Request, APIRouter from reflector.events import subscribers_shutdown from pydantic import BaseModel -from reflector.models import TranscriptionContext from reflector.logger import logger from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack from json import loads, dumps @@ -27,6 +26,15 @@ sessions = [] router = APIRouter() +class TranscriptionContext(object): + def __init__(self, logger): + self.logger = logger + self.pipeline = None + self.data_channel = None + self.status = "idle" + self.topics = [] + + class AudioStreamTrack(MediaStreamTrack): """ An audio stream track. @@ -79,7 +87,6 @@ async def rtc_offer_base( peername = request.client clientid = f"{peername[0]}:{peername[1]}" ctx = TranscriptionContext(logger=logger.bind(client=clientid)) - ctx.topics = [] async def update_status(status: str): changed = ctx.status != status diff --git a/server/tests/test_basic_rtc.py b/server/tests/test_basic_rtc.py deleted file mode 100644 index 93f33648..00000000 --- a/server/tests/test_basic_rtc.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -from unittest.mock import patch - - -@pytest.mark.asyncio -async def test_basic_rtc_server(aiohttp_server, event_loop): - # goal is to start the server, and send rtc audio to it - # validate the events received - import argparse - import json - from pathlib import Path - from reflector.server import create_app - from reflector.stream_client import StreamClient - from reflector.models import TitleSummaryOutput - from aiortc.contrib.signaling import add_signaling_arguments, create_signaling - - # customize settings to have a mock LLM server - with patch("reflector.server.get_title_and_summary") as mock_llm: - # any response from mock_llm will be test topic - mock_llm.return_value = TitleSummaryOutput(["topic_test"]) - - # create the server - app = create_app() - server = await aiohttp_server(app) - url = f"http://{server.host}:{server.port}/offer" - - # create signaling - parser = argparse.ArgumentParser() - add_signaling_arguments(parser) - args = parser.parse_args(["-s", "tcp-socket"]) - signaling = create_signaling(args) - - # create the client - path = Path(__file__).parent / "records" / "test_mathieu_hello.wav" - client = StreamClient(signaling, url=url, play_from=path.as_posix()) - await client.start() - - # we just want the first transcription - # and topic update messages - - marks = { - "SHOW_TRANSCRIPTION": False, - "UPDATE_TOPICS": False, - } - - async for rawmsg in client.get_reader(): - msg = json.loads(rawmsg) - cmd = msg["cmd"] - if cmd == "SHOW_TRANSCRIPTION": - assert "text" in msg - assert "want to share my incredible experience" in msg["text"] - elif cmd == "UPDATE_TOPICS": - assert "topics" in msg - assert "topic_test" in msg["topics"] - marks[cmd] = True - - # break if we have all the events we need - if all(marks.values()): - break - - # stop the server - await server.close() - await client.stop()