mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: remove non-apified server
@@ -1,211 +0,0 @@
"""
Collection of data classes for streamlining and rigidly structuring
the input and output parameters of functions
"""

import datetime
from dataclasses import dataclass
from typing import List

import av
from sortedcontainers import SortedDict


@dataclass
class TitleSummaryInput:
    """
    Data class for the input used to generate titles and summaries.
    The outcome is used to send a query to the LLM for processing.
    """

    input_text: str
    transcribed_time: float
    prompt: str
    data: dict
    headers: dict

    def __init__(self, transcribed_time, input_text=""):
        self.input_text = input_text
        self.transcribed_time = transcribed_time
        self.prompt = f"""
        ### Human:
        Create a JSON object as response. The JSON object must have 2 fields:
        i) title and ii) summary. For the title field, generate a short title
        for the given text. For the summary field, summarize the given text
        in three sentences.

        {self.input_text}

        ### Assistant:
        """
        self.data = {"prompt": self.prompt}
        self.headers = {"Content-Type": "application/json"}


@dataclass
class IncrementalResult:
    """
    Data class for the result of generating one title and summary.
    Defines what a single "topic" looks like.
    """

    title: str
    description: str
    transcript: str
    timestamp: str

    def __init__(self, title, desc, transcript, timestamp):
        self.title = title
        self.description = desc
        self.transcript = transcript
        self.timestamp = timestamp


@dataclass
class TitleSummaryOutput:
    """
    Data class for the result of all generated titles and summaries.
    The result will be sent back to the client.
    """

    cmd: str
    topics: List[IncrementalResult]

    def __init__(self, inc_responses):
        self.topics = inc_responses
        self.cmd = "UPDATE_TOPICS"

    def get_result(self) -> dict:
        """
        Return the result dict for updating the topics on the client
        :return:
        """
        return {"cmd": self.cmd, "topics": self.topics}


@dataclass
class ParseLLMResult:
    """
    Data class to parse the result returned by the LLM while generating titles
    and summaries. The result will be sent back to the client.
    """

    title: str
    description: str
    transcript: str
    timestamp: str

    def __init__(self, param: TitleSummaryInput, output: dict):
        self.title = output["title"]
        self.transcript = param.input_text
        self.description = output.pop("summary")
        self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))

    def get_result(self) -> dict:
        """
        Return the result dict after parsing the response from the LLM
        :return:
        """
        return {
            "title": self.title,
            "description": self.description,
            "transcript": self.transcript,
            "timestamp": self.timestamp,
        }
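The timestamp field above is simply the transcribed offset rendered through datetime.timedelta, which yields an H:MM:SS string. A quick worked example:

import datetime

print(str(datetime.timedelta(seconds=round(75.6))))    # 0:01:16
print(str(datetime.timedelta(seconds=round(3725.2))))  # 1:02:05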
@dataclass
class TranscriptionInput:
    """
    Data class to define the input to the transcription function.
    AudioFrames -> input
    """

    frames: List[av.audio.frame.AudioFrame]

    def __init__(self, frames):
        self.frames = frames


@dataclass
class TranscriptionOutput:
    """
    Data class to define the result of the transcription function.
    The result will be sent back to the client.
    """

    cmd: str
    result_text: str

    def __init__(self, result_text):
        self.cmd = "SHOW_TRANSCRIPTION"
        self.result_text = result_text

    def get_result(self) -> dict:
        """
        Return the result dict for displaying the transcription
        :return:
        """
        return {"cmd": self.cmd, "text": self.result_text}


@dataclass
class FinalSummaryResult:
    """
    Data class to define the result of the final summary function.
    The result will be sent back to the client.
    """

    cmd: str
    final_summary: str
    duration: str

    def __init__(self, final_summary, time):
        self.duration = str(datetime.timedelta(seconds=round(time)))
        self.final_summary = final_summary
        self.cmd = "DISPLAY_FINAL_SUMMARY"

    def get_result(self) -> dict:
        """
        Return the result dict for displaying the final summary
        :return:
        """
        return {
            "cmd": self.cmd,
            "duration": self.duration,
            "summary": self.final_summary,
        }


class BlackListedMessages:
    """
    Class to hold the blacklisted messages. These messages should be filtered
    out and not sent back to the client as part of the transcription.
    """

    messages = [
        " Thank you.",
        " See you next time!",
        " Thank you for watching!",
        " Bye!",
        " And that's what I'm talking about.",
    ]


@dataclass
class TranscriptionContext:
    transcription_text: str
    last_transcribed_time: float
    incremental_responses: List[IncrementalResult]
    sorted_transcripts: dict
    data_channel: None  # FIXME
    logger: None
    status: str

    def __init__(self, logger):
        self.transcription_text = ""
        self.last_transcribed_time = 0.0
        self.incremental_responses = []
        self.data_channel = None
        self.sorted_transcripts = SortedDict()
        self.status = "idle"
        self.logger = logger
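For reference, a minimal sketch of how these data classes were meant to chain together. The LLM payload below is a made-up example, and the import assumes this (now removed) reflector.models module:

from reflector.models import TitleSummaryInput, ParseLLMResult, TitleSummaryOutput

param = TitleSummaryInput(transcribed_time=12.4, input_text="We discussed the release plan ...")
llm_output = {"title": "Release planning", "summary": "The team reviewed the release plan."}

parsed = ParseLLMResult(param, llm_output)
topic = parsed.get_result()           # title / description / transcript / timestamp dict
update = TitleSummaryOutput([topic])  # wrapped with cmd="UPDATE_TOPICS" for the client
print(update.get_result())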
@@ -1,381 +0,0 @@
import argparse
import asyncio
import datetime
import json
import os
import wave
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import NoReturn, Union

import aiohttp_cors
import av
import requests
from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from faster_whisper import WhisperModel

from reflector.models import (
    BlackListedMessages,
    FinalSummaryResult,
    ParseLLMResult,
    TitleSummaryInput,
    TitleSummaryOutput,
    TranscriptionInput,
    TranscriptionOutput,
    TranscriptionContext,
)
from reflector.logger import logger
from reflector.utils.run_utils import run_in_executor
from reflector.settings import settings

# WebRTC components
pcs = set()
relay = MediaRelay()
executor = ThreadPoolExecutor()

# Transcription model
model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12)

# LLM
LLM_URL = settings.LLM_URL
if not LLM_URL:
    assert settings.LLM_BACKEND == "oobagooda"
    LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate"
logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}")


def parse_llm_output(
    param: TitleSummaryInput, response: requests.Response
) -> Union[None, ParseLLMResult]:
    """
    Function to parse the LLM response
    :param param:
    :param response:
    :return:
    """
    try:
        output = json.loads(response.json()["results"][0]["text"])
        return ParseLLMResult(param, output)
    except Exception:
        logger.exception("Exception while parsing LLM output")
        return None
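parse_llm_output expects an oobabooga-style /api/v1/generate response in which the generated text is itself a JSON document, hence the double decode. A small illustration with a made-up payload:

import json

body = {"results": [{"text": '{"title": "Weekly sync", "summary": "A short recap."}'}]}

# first level: the HTTP response body; second level: the model's own JSON answer
output = json.loads(body["results"][0]["text"])
print(output["title"], "/", output["summary"])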
def get_title_and_summary(
    ctx: TranscriptionContext, param: TitleSummaryInput
) -> Union[None, TitleSummaryOutput]:
    """
    From the input provided (transcript), query the LLM to generate
    topics and summaries
    :param ctx:
    :param param:
    :return:
    """
    logger.info("Generating title and summary")

    # TODO: Handle unexpected output formats from the model
    try:
        response = requests.post(LLM_URL, headers=param.headers, json=param.data)
        output = parse_llm_output(param, response)
        if output:
            result = output.get_result()
            ctx.incremental_responses.append(result)
            return TitleSummaryOutput(ctx.incremental_responses)
    except Exception:
        logger.exception("Exception while generating title and summary")
        return None


def channel_send(channel, message: str) -> NoReturn:
    """
    Send text messages via the data channel
    :param channel:
    :param message:
    :return:
    """
    if channel:
        channel.send(message)


def channel_send_increment(
    channel, param: Union[FinalSummaryResult, TitleSummaryOutput]
) -> NoReturn:
    """
    Send the incremental topics and summaries via the data channel
    :param channel:
    :param param:
    :return:
    """
    if channel and param:
        message = param.get_result()
        channel.send(json.dumps(message))


def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn:
    """
    Send the transcription result via the data channel
    :param ctx:
    :return:
    """
    if not ctx.data_channel:
        return
    try:
        least_time = next(iter(ctx.sorted_transcripts))
        message = ctx.sorted_transcripts[least_time].get_result()
        if message:
            del ctx.sorted_transcripts[least_time]
            if message["text"] not in BlackListedMessages.messages:
                ctx.data_channel.send(json.dumps(message))
        # If an earlier batch raised an exception and never returned a
        # transcript, we don't want to be stuck waiting for its result.
        # With a threshold size of 3, we pop the first (lost) element.
        else:
            if len(ctx.sorted_transcripts) >= 3:
                del ctx.sorted_transcripts[least_time]
    except Exception:
        logger.exception("Exception while sending transcript")
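The SortedDict in channel_send_transcript acts as a small reordering buffer: results are keyed by the start time of their audio batch, only the earliest entry is ever flushed, and a stalled entry is dropped once three batches are pending. A stripped-down sketch of that idea, independent of the classes above:

from sortedcontainers import SortedDict

pending = SortedDict()         # start_time -> result, None while still transcribing
pending[0.0] = None            # first batch scheduled but not finished
pending[3.2] = "second batch"  # a later batch finished out of order
pending[6.4] = "third batch"

oldest = next(iter(pending))   # always consider the earliest batch first
if pending[oldest] is not None:
    print("flush:", pending.pop(oldest))
elif len(pending) >= 3:
    pending.pop(oldest)        # assume the oldest batch was lost; stop waiting for it
print(len(pending), "batches still pending")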
def get_transcription(
    ctx: TranscriptionContext, input_frames: TranscriptionInput
) -> Union[None, TranscriptionOutput]:
    """
    From the collected audio frames, create a transcription by inferring from
    the chosen transcription model
    :param ctx:
    :param input_frames:
    :return:
    """
    ctx.logger.info("Transcribing..")
    ctx.sorted_transcripts[input_frames.frames[0].time] = None

    # TODO: Find a cleaner way, see the "no transcription" issue below.
    # Passing IO objects instead of temporary files throws an error, and
    # passing an ndarray (cast to float) does not give any transcription.
    # Refer to https://github.com/guillaumekln/faster-whisper/issues/369
    audio_file = "test" + str(datetime.datetime.now())
    wf = wave.open(audio_file, "wb")
    wf.setnchannels(settings.AUDIO_CHANNELS)
    wf.setframerate(settings.AUDIO_SAMPLING_RATE)
    wf.setsampwidth(settings.AUDIO_SAMPLING_WIDTH)

    for frame in input_frames.frames:
        wf.writeframes(b"".join(frame.to_ndarray()))
    wf.close()

    result_text = ""

    try:
        segments, _ = model.transcribe(
            audio_file,
            language="en",
            beam_size=5,
            vad_filter=True,
            vad_parameters={"min_silence_duration_ms": 500},
        )
        os.remove(audio_file)
        segments = list(segments)
        result_text = ""
        duration = 0.0
        for segment in segments:
            result_text += segment.text
            start_time = segment.start
            end_time = segment.end
            if not segment.start:
                start_time = 0.0
            if not segment.end:
                end_time = 5.5
            duration += end_time - start_time

        ctx.last_transcribed_time += duration
        ctx.transcription_text += result_text

    except Exception:
        logger.exception("Exception while transcribing")

    result = TranscriptionOutput(result_text)
    ctx.sorted_transcripts[input_frames.frames[0].time] = result
    return result


def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
    """
    Collate the incremental summaries generated so far and return them as the
    final summary
    :param ctx:
    :return:
    """
    final_summary = ""

    # Collate incremental summaries
    for topic in ctx.incremental_responses:
        final_summary += topic["description"]

    response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)

    with open(
        "./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8"
    ) as file:
        file.write(json.dumps(ctx.incremental_responses))

    return response
class AudioStreamTrack(MediaStreamTrack):
    """
    An audio stream track.
    """

    kind = "audio"

    def __init__(self, ctx: TranscriptionContext, track):
        super().__init__()
        self.ctx = ctx
        self.track = track
        self.audio_buffer = av.AudioFifo()

    async def recv(self) -> av.audio.frame.AudioFrame:
        ctx = self.ctx
        frame = await self.track.recv()
        self.audio_buffer.write(frame)

        if local_frames := self.audio_buffer.read_many(
            settings.AUDIO_BUFFER_SIZE, partial=False
        ):
            whisper_result = run_in_executor(
                get_transcription,
                ctx,
                TranscriptionInput(local_frames),
                executor=executor,
            )
            whisper_result.add_done_callback(
                lambda f: channel_send_transcript(ctx) if f.result() else None
            )

            if len(ctx.transcription_text) > 25:
                llm_input_text = ctx.transcription_text
                ctx.transcription_text = ""
                param = TitleSummaryInput(
                    input_text=llm_input_text,
                    transcribed_time=ctx.last_transcribed_time,
                )
                llm_result = run_in_executor(
                    get_title_and_summary, ctx, param, executor=executor
                )
                llm_result.add_done_callback(
                    lambda f: channel_send_increment(ctx.data_channel, llm_result.result())
                    if f.result()
                    else None
                )
        return frame
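The av.AudioFifo here accumulates the 20 ms WebRTC frames and only hands a batch to the transcriber once a full buffer is available. A minimal sketch of that buffering pattern; the 48000-sample threshold and the synthesized silent frame are stand-ins for settings.AUDIO_BUFFER_SIZE and track.recv():

import av

fifo = av.AudioFifo()

# stand-in for frames coming from track.recv(): 20 ms of silence, 48 kHz mono s16
frame = av.AudioFrame(format="s16", layout="mono", samples=960)
frame.sample_rate = 48000
frame.planes[0].update(bytes(frame.planes[0].buffer_size))

for _ in range(200):  # roughly 4 s of audio
    fifo.write(frame)
    # read_many only returns frames once a full batch is buffered (partial=False)
    if batch := fifo.read_many(48000, partial=False):
        print(f"batch of {len(batch)} frame(s) ready for transcription")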
async def offer(request: web.Request) -> web.Response:
    """
    Establish the WebRTC connection with the client
    :param request:
    :return:
    """
    params = await request.json()
    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    # client identification
    peername = request.transport.get_extra_info("peername")
    if peername is not None:
        clientid = f"{peername[0]}:{peername[1]}"
    else:
        clientid = uuid.uuid4()

    # create a context for the whole RTC transaction and
    # attach a customised logger to it
    ctx = TranscriptionContext(logger=logger.bind(client=clientid))

    # handle the RTC peer connection
    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("datachannel")
    def on_datachannel(channel) -> NoReturn:
        ctx.data_channel = channel
        ctx.logger = ctx.logger.bind(channel=channel.label)
        ctx.logger.info("Channel created by remote party")

        @channel.on("message")
        def on_message(message: str) -> NoReturn:
            ctx.logger.info(f"Message: {message}")
            if json.loads(message)["cmd"] == "STOP":
                # Placeholder final summary
                response = get_final_summary_response(ctx)
                channel_send_increment(channel, response)
                # TODO: add code to stop the connection from the server side
                # here, but we have to handshake with the client once first

            if isinstance(message, str) and message.startswith("ping"):
                channel_send(channel, "pong" + message[4:])

    @pc.on("connectionstatechange")
    async def on_connectionstatechange() -> NoReturn:
        ctx.logger.info(f"Connection state changed: {pc.connectionState}")
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    @pc.on("track")
    def on_track(track) -> NoReturn:
        ctx.logger.info(f"Track {track.kind} received")
        pc.addTrack(AudioStreamTrack(ctx, relay.subscribe(track)))

    await pc.setRemoteDescription(offer)

    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)
    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
        ),
    )
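For completeness, a rough sketch of a matching client-side handshake against the /offer endpoint above, assuming the server runs on its default port 1250; this is illustrative only, the project itself used its own StreamClient helper:

import asyncio

import aiohttp
from aiortc import RTCPeerConnection, RTCSessionDescription


async def connect(url="http://localhost:1250/offer"):
    pc = RTCPeerConnection()
    channel = pc.createDataChannel("reflector")  # label is arbitrary
    # a real client would also add an audio track before creating the offer

    await pc.setLocalDescription(await pc.createOffer())
    async with aiohttp.ClientSession() as session:
        async with session.post(
            url,
            json={"sdp": pc.localDescription.sdp, "type": pc.localDescription.type},
        ) as resp:
            answer = await resp.json()
    await pc.setRemoteDescription(
        RTCSessionDescription(sdp=answer["sdp"], type=answer["type"])
    )
    return pc, channel


# asyncio.run(connect())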
async def on_shutdown(application: web.Application) -> NoReturn:
    """
    On shutdown, run the coroutines that close the client connections
    :param application:
    :return:
    """
    coroutines = [pc.close() for pc in pcs]
    await asyncio.gather(*coroutines)
    pcs.clear()


def create_app() -> web.Application:
    """
    Create the web application
    """
    app = web.Application()
    cors = aiohttp_cors.setup(
        app,
        defaults={
            "*": aiohttp_cors.ResourceOptions(
                allow_credentials=True, expose_headers="*", allow_headers="*"
            )
        },
    )

    offer_resource = cors.add(app.router.add_resource("/offer"))
    cors.add(offer_resource.add_route("POST", offer))
    app.on_shutdown.append(on_shutdown)
    return app


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="WebRTC based server for Reflector")
    parser.add_argument(
        "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
    )
    parser.add_argument(
        "--port", type=int, default=1250, help="Server port (def: 1250)"
    )
    args = parser.parse_args()
    app = create_app()
    web.run_app(app, access_log=None, host=args.host, port=args.port)
@@ -2,7 +2,6 @@ import asyncio
 from fastapi import Request, APIRouter
 from reflector.events import subscribers_shutdown
 from pydantic import BaseModel
-from reflector.models import TranscriptionContext
 from reflector.logger import logger
 from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
 from json import loads, dumps
@@ -27,6 +26,15 @@ sessions = []
 router = APIRouter()
 
 
+class TranscriptionContext(object):
+    def __init__(self, logger):
+        self.logger = logger
+        self.pipeline = None
+        self.data_channel = None
+        self.status = "idle"
+        self.topics = []
+
+
 class AudioStreamTrack(MediaStreamTrack):
     """
     An audio stream track.
@@ -79,7 +87,6 @@ async def rtc_offer_base(
     peername = request.client
     clientid = f"{peername[0]}:{peername[1]}"
     ctx = TranscriptionContext(logger=logger.bind(client=clientid))
-    ctx.topics = []
 
     async def update_status(status: str):
         changed = ctx.status != status
@@ -1,63 +0,0 @@
import pytest
from unittest.mock import patch


@pytest.mark.asyncio
async def test_basic_rtc_server(aiohttp_server, event_loop):
    # goal is to start the server, and send rtc audio to it
    # validate the events received
    import argparse
    import json
    from pathlib import Path
    from reflector.server import create_app
    from reflector.stream_client import StreamClient
    from reflector.models import TitleSummaryOutput
    from aiortc.contrib.signaling import add_signaling_arguments, create_signaling

    # customize settings to have a mock LLM server
    with patch("reflector.server.get_title_and_summary") as mock_llm:
        # any response from mock_llm will be test topic
        mock_llm.return_value = TitleSummaryOutput(["topic_test"])

        # create the server
        app = create_app()
        server = await aiohttp_server(app)
        url = f"http://{server.host}:{server.port}/offer"

        # create signaling
        parser = argparse.ArgumentParser()
        add_signaling_arguments(parser)
        args = parser.parse_args(["-s", "tcp-socket"])
        signaling = create_signaling(args)

        # create the client
        path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
        client = StreamClient(signaling, url=url, play_from=path.as_posix())
        await client.start()

        # we just want the first transcription
        # and topic update messages
        marks = {
            "SHOW_TRANSCRIPTION": False,
            "UPDATE_TOPICS": False,
        }

        async for rawmsg in client.get_reader():
            msg = json.loads(rawmsg)
            cmd = msg["cmd"]
            if cmd == "SHOW_TRANSCRIPTION":
                assert "text" in msg
                assert "want to share my incredible experience" in msg["text"]
            elif cmd == "UPDATE_TOPICS":
                assert "topics" in msg
                assert "topic_test" in msg["topics"]
            marks[cmd] = True

            # break if we have all the events we need
            if all(marks.values()):
                break

        # stop the server
        await server.close()
        await client.stop()