server: refactor to prevent using global variables

- allow LLM_URL to be passed directly by env, otherwise fall back to the current config.ini
- prevent usage of global, shared variables are now passed through a context
- can now have multiple meetings at the same time
This commit is contained in:
Mathieu Virbel
2023-07-27 11:54:12 +02:00
parent 4e67ba5782
commit 0e56d051bd
2 changed files with 78 additions and 65 deletions

View File

@@ -6,6 +6,7 @@ the input and output parameters of functions
import datetime import datetime
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List
from sortedcontainers import SortedDict
import av import av
@@ -184,3 +185,19 @@ class BlackListedMessages:
messages = [" Thank you.", " See you next time!", messages = [" Thank you.", " See you next time!",
" Thank you for watching!", " Bye!", " Thank you for watching!", " Bye!",
" And that's what I'm talking about."] " And that's what I'm talking about."]
@dataclass
class TranscriptionContext:
    """Per-connection transcription state.

    Replaces the former module-level globals so that one instance can be
    created per WebRTC offer, letting several meetings be transcribed
    concurrently without sharing state.
    """
    # Transcript text accumulated since the last LLM title/summary request
    # (the caller clears it once it has been handed to the LLM).
    transcription_text: str
    # Running total of transcribed audio duration, in seconds.
    last_transcribed_time: float
    # Incremental title/summary results collected from the LLM so far.
    incremental_responses: List[IncrementalResult]
    # Frame-time -> transcription result; actually a SortedDict (see
    # __init__) so results can be flushed to the client in time order.
    sorted_transcripts: dict
    # aiortc RTCDataChannel assigned on the "datachannel" event; None
    # until the remote party opens the channel.
    data_channel: object
    def __init__(self):
        """Initialize an empty context (no-argument constructor).

        Note: this hand-written __init__ takes precedence over the one
        @dataclass would generate.
        """
        self.transcription_text = ""
        self.last_transcribed_time = 0.0
        self.incremental_responses = []
        self.data_channel = None
        self.sorted_transcripts = SortedDict()

View File

@@ -15,18 +15,16 @@ from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay from aiortc.contrib.media import MediaRelay
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
from sortedcontainers import SortedDict
from reflector_dataclasses import BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput, \ from reflector_dataclasses import (
TitleSummaryOutput, TranscriptionInput, TranscriptionOutput BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput,
TitleSummaryOutput, TranscriptionInput, TranscriptionOutput, TranscriptionContext)
from utils.log_utils import LOGGER from utils.log_utils import LOGGER
from utils.run_utils import CONFIG, run_in_executor, SECRETS from utils.run_utils import CONFIG, run_in_executor, SECRETS
# WebRTC components # WebRTC components
pcs = set() pcs = set()
relay = MediaRelay() relay = MediaRelay()
data_channel = None
audio_buffer = av.AudioFifo()
executor = ThreadPoolExecutor() executor = ThreadPoolExecutor()
# Transcription model # Transcription model
@@ -37,22 +35,16 @@ model = WhisperModel("tiny", device="cpu",
# Audio configurations # Audio configurations
CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"]) CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"])
RATE = int(CONFIG["AUDIO"]["SAMPLING_RATE"]) RATE = int(CONFIG["AUDIO"]["SAMPLING_RATE"])
AUDIO_BUFFER_SIZE = 256 * 960
# Global vars
transcription_text = ""
last_transcribed_time = 0.0
# LLM # LLM
LLM_MACHINE_IP = SECRETS["LLM"]["LLM_MACHINE_IP"] LLM_URL = os.environ.get("LLM_URL")
LLM_MACHINE_PORT = SECRETS["LLM"]["LLM_MACHINE_PORT"] if LLM_URL:
LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate" LOGGER.info(f"Using LLM from environment: {LLM_URL}")
else:
# Topic and summary responses LLM_MACHINE_IP = CONFIG["LLM"]["LLM_MACHINE_IP"]
incremental_responses = [] LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
# To synchronize the thread results before returning to the client
sorted_transcripts = SortedDict()
def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]: def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
""" """
@@ -69,7 +61,7 @@ def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> U
return None return None
def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]: def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
""" """
From the input provided (transcript), query the LLM to generate From the input provided (transcript), query the LLM to generate
topics and summaries topics and summaries
@@ -86,10 +78,10 @@ def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryO
output = parse_llm_output(param, response) output = parse_llm_output(param, response)
if output: if output:
result = output.get_result() result = output.get_result()
incremental_responses.append(result) ctx.incremental_responses.append(result)
return TitleSummaryOutput(incremental_responses) return TitleSummaryOutput(ctx.incremental_responses)
except Exception as e: except Exception:
LOGGER.info("Exception" + str(e)) LOGGER.exception("Exception while generating title and summary")
return None return None
@@ -127,32 +119,33 @@ def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummar
channel.send(json.dumps(message)) channel.send(json.dumps(message))
def channel_send_transcript(channel) -> NoReturn: def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn:
""" """
Send the transcription result via the data channel Send the transcription result via the data channel
:param channel: :param channel:
:return: :return:
""" """
# channel_log(channel, ">", message) # channel_log(channel, ">", message)
if channel: if not ctx.data_channel:
try: return
least_time = next(iter(sorted_transcripts)) try:
message = sorted_transcripts[least_time].get_result() least_time = next(iter(ctx.sorted_transcripts))
if message: message = ctx.sorted_transcripts[least_time].get_result()
del sorted_transcripts[least_time] if message:
if message["text"] not in BlackListedMessages.messages: del ctx.sorted_transcripts[least_time]
channel.send(json.dumps(message)) if message["text"] not in BlackListedMessages.messages:
# Due to exceptions if one of the earlier batches can't return ctx.data_channel.send(json.dumps(message))
# a transcript, we don't want to be stuck waiting for the result # Due to exceptions if one of the earlier batches can't return
# With the threshold size of 3, we pop the first(lost) element # a transcript, we don't want to be stuck waiting for the result
else: # With the threshold size of 3, we pop the first(lost) element
if len(sorted_transcripts) >= 3: else:
del sorted_transcripts[least_time] if len(ctx.sorted_transcripts) >= 3:
except Exception as exception: del ctx.sorted_transcripts[least_time]
LOGGER.info("Exception", str(exception)) except Exception as exception:
LOGGER.info("Exception", str(exception))
def get_transcription(input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]: def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
""" """
From the collected audio frames create transcription by inferring from From the collected audio frames create transcription by inferring from
the chosen transcription model the chosen transcription model
@@ -160,7 +153,7 @@ def get_transcription(input_frames: TranscriptionInput) -> Union[None, Transcrip
:return: :return:
""" """
LOGGER.info("Transcribing..") LOGGER.info("Transcribing..")
sorted_transcripts[input_frames.frames[0].time] = None ctx.sorted_transcripts[input_frames.frames[0].time] = None
# TODO: Find cleaner way, watch "no transcription" issue below # TODO: Find cleaner way, watch "no transcription" issue below
# Passing IO objects instead of temporary files throws an error # Passing IO objects instead of temporary files throws an error
@@ -200,19 +193,18 @@ def get_transcription(input_frames: TranscriptionInput) -> Union[None, Transcrip
end_time = 5.5 end_time = 5.5
duration += (end_time - start_time) duration += (end_time - start_time)
global last_transcribed_time, transcription_text ctx.last_transcribed_time += duration
last_transcribed_time += duration ctx.transcription_text += result_text
transcription_text += result_text
except Exception as exception: except Exception as exception:
LOGGER.info("Exception" + str(exception)) LOGGER.info("Exception" + str(exception))
result = TranscriptionOutput(result_text) result = TranscriptionOutput(result_text)
sorted_transcripts[input_frames.frames[0].time] = result ctx.sorted_transcripts[input_frames.frames[0].time] = result
return result return result
def get_final_summary_response() -> FinalSummaryResult: def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
""" """
Collate the incremental summaries generated so far and return as the final Collate the incremental summaries generated so far and return as the final
summary summary
@@ -221,14 +213,14 @@ def get_final_summary_response() -> FinalSummaryResult:
final_summary = "" final_summary = ""
# Collate inc summaries # Collate inc summaries
for topic in incremental_responses: for topic in ctx.incremental_responses:
final_summary += topic["description"] final_summary += topic["description"]
response = FinalSummaryResult(final_summary, last_transcribed_time) response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)
with open("./artefacts/meeting_titles_and_summaries.txt", "a", with open("./artefacts/meeting_titles_and_summaries.txt", "a",
encoding="utf-8") as file: encoding="utf-8") as file:
file.write(json.dumps(incremental_responses)) file.write(json.dumps(ctx.incremental_responses))
return response return response
@@ -240,37 +232,41 @@ class AudioStreamTrack(MediaStreamTrack):
kind = "audio" kind = "audio"
def __init__(self, track): def __init__(self, ctx: TranscriptionContext, track):
super().__init__() super().__init__()
self.ctx = ctx
self.track = track self.track = track
self.audio_buffer = av.AudioFifo()
async def recv(self) -> av.audio.frame.AudioFrame: async def recv(self) -> av.audio.frame.AudioFrame:
global transcription_text ctx = self.ctx
frame = await self.track.recv() frame = await self.track.recv()
audio_buffer.write(frame) self.audio_buffer.write(frame)
if local_frames := audio_buffer.read_many(256 * 960, partial=False): if local_frames := self.audio_buffer.read_many(AUDIO_BUFFER_SIZE, partial=False):
whisper_result = run_in_executor( whisper_result = run_in_executor(
get_transcription, get_transcription,
ctx,
TranscriptionInput(local_frames), TranscriptionInput(local_frames),
executor=executor executor=executor
) )
whisper_result.add_done_callback( whisper_result.add_done_callback(
lambda f: channel_send_transcript(data_channel) lambda f: channel_send_transcript(ctx)
if f.result() if f.result()
else None else None
) )
if len(transcription_text) > 25: if len(ctx.transcription_text) > 25:
llm_input_text = transcription_text llm_input_text = ctx.transcription_text
transcription_text = "" ctx.transcription_text = ""
param = TitleSummaryInput(input_text=llm_input_text, param = TitleSummaryInput(input_text=llm_input_text,
transcribed_time=last_transcribed_time) transcribed_time=ctx.last_transcribed_time)
llm_result = run_in_executor(get_title_and_summary, llm_result = run_in_executor(get_title_and_summary,
ctx,
param, param,
executor=executor) executor=executor)
llm_result.add_done_callback( llm_result.add_done_callback(
lambda f: channel_send_increment(data_channel, lambda f: channel_send_increment(ctx.data_channel,
llm_result.result()) llm_result.result())
if f.result() if f.result()
else None else None
@@ -287,6 +283,7 @@ async def offer(request: requests.Request) -> web.Response:
params = await request.json() params = await request.json()
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
ctx = TranscriptionContext()
pc = RTCPeerConnection() pc = RTCPeerConnection()
pc_id = "PeerConnection(%s)" % uuid.uuid4() pc_id = "PeerConnection(%s)" % uuid.uuid4()
pcs.add(pc) pcs.add(pc)
@@ -298,8 +295,7 @@ async def offer(request: requests.Request) -> web.Response:
@pc.on("datachannel") @pc.on("datachannel")
def on_datachannel(channel) -> NoReturn: def on_datachannel(channel) -> NoReturn:
global data_channel ctx.data_channel = channel
data_channel = channel
channel_log(channel, "-", "created by remote party") channel_log(channel, "-", "created by remote party")
@channel.on("message") @channel.on("message")
@@ -308,7 +304,7 @@ async def offer(request: requests.Request) -> web.Response:
if json.loads(message)["cmd"] == "STOP": if json.loads(message)["cmd"] == "STOP":
# Placeholder final summary # Placeholder final summary
response = get_final_summary_response() response = get_final_summary_response()
channel_send_increment(data_channel, response) channel_send_increment(channel, response)
# To-do Add code to stop connection from server side here # To-do Add code to stop connection from server side here
# But have to handshake with client once # But have to handshake with client once
@@ -325,7 +321,7 @@ async def offer(request: requests.Request) -> web.Response:
@pc.on("track") @pc.on("track")
def on_track(track) -> NoReturn: def on_track(track) -> NoReturn:
log_info("Track " + track.kind + " received") log_info("Track " + track.kind + " received")
pc.addTrack(AudioStreamTrack(relay.subscribe(track))) pc.addTrack(AudioStreamTrack(ctx, relay.subscribe(track)))
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)