mirror of https://github.com/Monadical-SAS/reflector.git
commit: pull from main
@@ -1,211 +0,0 @@
"""
Collection of data classes for streamlining and rigidly structuring
the input and output parameters of functions.
"""

import datetime
from dataclasses import dataclass
from typing import List

import av
from sortedcontainers import SortedDict


@dataclass
class TitleSummaryInput:
    """
    Data class for the input used to generate titles and summaries.
    The outcome is used to query the LLM for processing.
    """

    input_text: str
    transcribed_time: float
    prompt: str
    data: dict
    headers: dict

    def __init__(self, transcribed_time, input_text=""):
        self.input_text = input_text
        self.transcribed_time = transcribed_time
        self.prompt = f"""
        ### Human:
        Create a JSON object as the response. The JSON object must have 2
        fields: i) title and ii) summary. For the title field, generate a
        short title for the given text. For the summary field, summarize
        the given text in three sentences.

        {self.input_text}

        ### Assistant:
        """
        self.data = {"prompt": self.prompt}
        self.headers = {"Content-Type": "application/json"}


@dataclass
class IncrementalResult:
    """
    Data class for the result of generating one title and summary.
    Defines what a single "topic" looks like.
    """

    title: str
    description: str
    transcript: str
    timestamp: str

    def __init__(self, title, desc, transcript, timestamp):
        self.title = title
        self.description = desc
        self.transcript = transcript
        self.timestamp = timestamp


@dataclass
class TitleSummaryOutput:
    """
    Data class for the result of all generated titles and summaries.
    The result is sent back to the client.
    """

    cmd: str
    topics: List[IncrementalResult]

    def __init__(self, inc_responses):
        self.topics = inc_responses
        self.cmd = "UPDATE_TOPICS"

    def get_result(self) -> dict:
        """
        Return the result dict for displaying the topics.
        """
        return {"cmd": self.cmd, "topics": self.topics}


@dataclass
class ParseLLMResult:
    """
    Data class to parse the result returned by the LLM while generating
    titles and summaries. The result is sent back to the client.
    """

    title: str
    description: str
    transcript: str
    timestamp: str

    def __init__(self, param: TitleSummaryInput, output: dict):
        self.title = output["title"]
        self.transcript = param.input_text
        self.description = output.pop("summary")
        self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))

    def get_result(self) -> dict:
        """
        Return the result dict after parsing the response from the LLM.
        """
        return {
            "title": self.title,
            "description": self.description,
            "transcript": self.transcript,
            "timestamp": self.timestamp,
        }


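A quick sanity check of the timestamp formatting used by ParseLLMResult above: a plain str() of a datetime.timedelta renders as hours:minutes:seconds.

import datetime

# round() first, then format: 3725.4 seconds of audio becomes "1:02:05"
assert str(datetime.timedelta(seconds=round(3725.4))) == "1:02:05"
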
@dataclass
class TranscriptionInput:
    """
    Data class to define the input to the transcription function.
    AudioFrames -> input
    """

    frames: List[av.audio.frame.AudioFrame]

    def __init__(self, frames):
        self.frames = frames


@dataclass
class TranscriptionOutput:
    """
    Data class to define the result of the transcription function.
    The result is sent back to the client.
    """

    cmd: str
    result_text: str

    def __init__(self, result_text):
        self.cmd = "SHOW_TRANSCRIPTION"
        self.result_text = result_text

    def get_result(self) -> dict:
        """
        Return the result dict for displaying the transcription.
        """
        return {"cmd": self.cmd, "text": self.result_text}


@dataclass
class FinalSummaryResult:
    """
    Data class to define the result of the final summary function.
    The result is sent back to the client.
    """

    cmd: str
    final_summary: str
    duration: str

    def __init__(self, final_summary, time):
        self.duration = str(datetime.timedelta(seconds=round(time)))
        self.final_summary = final_summary
        self.cmd = "DISPLAY_FINAL_SUMMARY"

    def get_result(self) -> dict:
        """
        Return the result dict for displaying the final summary.
        """
        return {
            "cmd": self.cmd,
            "duration": self.duration,
            "summary": self.final_summary,
        }


class BlackListedMessages:
    """
    Class to hold the blacklisted messages. These messages should be filtered
    out and not sent back to the client as part of the transcription.
    """

    messages = [
        " Thank you.",
        " See you next time!",
        " Thank you for watching!",
        " Bye!",
        " And that's what I'm talking about.",
    ]


@dataclass
class TranscriptionContext:
    transcription_text: str
    last_transcribed_time: float
    incremental_responses: List[IncrementalResult]
    sorted_transcripts: dict
    data_channel: None  # FIXME
    logger: None
    status: str

    def __init__(self, logger):
        self.transcription_text = ""
        self.last_transcribed_time = 0.0
        self.incremental_responses = []
        self.data_channel = None
        self.sorted_transcripts = SortedDict()
        self.status = "idle"
        self.logger = logger
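An illustrative round-trip of the classes above (a sketch, not part of the diff; the LLM reply is invented):

# build an input, pretend the LLM answered, parse it back, wrap it for the client
param = TitleSummaryInput(transcribed_time=12.3, input_text="Hello world")
fake_llm_output = {"title": "Greeting", "summary": "A short greeting."}  # hypothetical
parsed = ParseLLMResult(param, fake_llm_output)
output = TitleSummaryOutput([parsed.get_result()])
# output.get_result() -> {"cmd": "UPDATE_TOPICS", "topics": [{"title": "Greeting", ...}]}
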
@@ -1,5 +1,6 @@
from .base import Processor, ThreadedProcessor, Pipeline  # noqa: F401
from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary  # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor  # noqa: F401
from .audio_chunker import AudioChunkerProcessor  # noqa: F401
from .audio_merge import AudioMergeProcessor  # noqa: F401
from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
server/reflector/processors/audio_file_writer.py (new file, 39 lines)
@@ -0,0 +1,39 @@
from reflector.processors.base import Processor
import av
from pathlib import Path


class AudioFileWriterProcessor(Processor):
    """
    Write audio frames to a file.
    """

    INPUT_TYPE = av.AudioFrame
    OUTPUT_TYPE = av.AudioFrame

    def __init__(self, path: Path | str):
        super().__init__()
        if isinstance(path, str):
            path = Path(path)
        self.path = path
        self.out_container = None
        self.out_stream = None

    async def _push(self, data: av.AudioFrame):
        if not self.out_container:
            self.path.parent.mkdir(parents=True, exist_ok=True)
            self.out_container = av.open(self.path.as_posix(), "w", format="wav")
            self.out_stream = self.out_container.add_stream(
                "pcm_s16le", rate=data.sample_rate
            )
        for packet in self.out_stream.encode(data):
            self.out_container.mux(packet)
        await self.emit(data)

    async def _flush(self):
        if self.out_container:
            for packet in self.out_stream.encode(None):
                self.out_container.mux(packet)
            self.out_container.close()
            self.out_container = None
            self.out_stream = None
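A minimal usage sketch for the new processor (a sketch, not the project's API: it drives the internal _push()/_flush() hooks directly, and the 16 kHz mono format is arbitrary):

import asyncio
import numpy as np
import av
from reflector.processors import AudioFileWriterProcessor

async def main():
    writer = AudioFileWriterProcessor(path="./data/example.wav")
    # one second of silence: packed s16, mono, 16 kHz (hypothetical format)
    samples = np.zeros((1, 16000), dtype=np.int16)
    frame = av.AudioFrame.from_ndarray(samples, format="s16", layout="mono")
    frame.sample_rate = 16000
    await writer._push(frame)   # writer.push(frame) if the base class wraps it
    await writer._flush()       # drain the encoder and close the WAV container

asyncio.run(main())
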
@@ -3,7 +3,6 @@ from reflector.processors.types import AudioFile
from time import monotonic_ns
from uuid import uuid4
import io
import wave
import av
@@ -28,12 +27,16 @@ class AudioMergeProcessor(Processor):
        # create audio file
        uu = uuid4().hex
        fd = io.BytesIO()
        with wave.open(fd, "wb") as wf:
            wf.setnchannels(channels)
            wf.setsampwidth(sample_width)
            wf.setframerate(sample_rate)
            for frame in data:
                wf.writeframes(frame.to_ndarray().tobytes())

        out_container = av.open(fd, "w", format="wav")
        out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
        for frame in data:
            for packet in out_stream.encode(frame):
                out_container.mux(packet)
        for packet in out_stream.encode(None):
            out_container.mux(packet)
        out_container.close()
        fd.seek(0)

        # emit audio file
        audiofile = AudioFile(
@@ -44,4 +47,5 @@ class AudioMergeProcessor(Processor):
            sample_width=sample_width,
            timestamp=data[0].pts * data[0].time_base,
        )

        await self.emit(audiofile)
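The timestamp passed to AudioFile above converts PyAV's integer pts into seconds via the stream time base; a small illustration of that arithmetic (the 16 kHz time base is hypothetical):

from fractions import Fraction

pts = 48000                      # ticks since stream start
time_base = Fraction(1, 16000)   # seconds per tick
assert float(pts * time_base) == 3.0
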
@@ -1,381 +0,0 @@
import argparse
import asyncio
import datetime
import json
import os
import wave
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Union

import aiohttp_cors
import av
import requests
from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from faster_whisper import WhisperModel

from reflector.models import (
    BlackListedMessages,
    FinalSummaryResult,
    ParseLLMResult,
    TitleSummaryInput,
    TitleSummaryOutput,
    TranscriptionInput,
    TranscriptionOutput,
    TranscriptionContext,
)
from reflector.logger import logger
from reflector.utils.run_utils import run_in_executor
from reflector.settings import settings

# WebRTC components
pcs = set()
relay = MediaRelay()
executor = ThreadPoolExecutor()

# Transcription model
model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12)

# LLM
LLM_URL = settings.LLM_URL
if not LLM_URL:
    assert settings.LLM_BACKEND == "oobabooga"
    LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate"
logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}")


def parse_llm_output(
    param: TitleSummaryInput, response: requests.Response
) -> Union[None, ParseLLMResult]:
    """
    Parse the LLM response.

    :param param:
    :param response:
    :return:
    """
    try:
        output = json.loads(response.json()["results"][0]["text"])
        return ParseLLMResult(param, output)
    except Exception:
        logger.exception("Exception while parsing LLM output")
        return None


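For reference, a sketch of the response envelope this parser expects from the oobabooga backend (field values invented; only the results[0].text path is taken from the code above, and the inner string must itself be the JSON requested by the prompt):

example_body = {
    "results": [
        {"text": '{"title": "Standup notes", "summary": "Three sentences here."}'}
    ]
}
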
def get_title_and_summary(
    ctx: TranscriptionContext, param: TitleSummaryInput
) -> Union[None, TitleSummaryOutput]:
    """
    From the input provided (transcript), query the LLM to generate
    topics and summaries.

    :param ctx:
    :param param:
    :return:
    """
    logger.info("Generating title and summary")

    # TODO: Handle unexpected output formats from the model
    try:
        response = requests.post(LLM_URL, headers=param.headers, json=param.data)
        output = parse_llm_output(param, response)
        if output:
            result = output.get_result()
            ctx.incremental_responses.append(result)
        return TitleSummaryOutput(ctx.incremental_responses)
    except Exception:
        logger.exception("Exception while generating title and summary")
        return None


def channel_send(channel, message: str) -> None:
    """
    Send text messages via the data channel.

    :param channel:
    :param message:
    :return:
    """
    if channel:
        channel.send(message)


def channel_send_increment(
    channel, param: Union[FinalSummaryResult, TitleSummaryOutput]
) -> None:
    """
    Send the incremental topics and summaries via the data channel.

    :param channel:
    :param param:
    :return:
    """
    if channel and param:
        message = param.get_result()
        channel.send(json.dumps(message))


def channel_send_transcript(ctx: TranscriptionContext) -> None:
    """
    Send the transcription result via the data channel.

    :param ctx:
    :return:
    """
    if not ctx.data_channel:
        return
    try:
        least_time = next(iter(ctx.sorted_transcripts))
        result = ctx.sorted_transcripts[least_time]
        if result is not None:
            message = result.get_result()
            del ctx.sorted_transcripts[least_time]
            if message["text"] not in BlackListedMessages.messages:
                ctx.data_channel.send(json.dumps(message))
        # If an earlier batch raised and can never deliver a transcript, we
        # don't want to be stuck waiting for its result. Once the backlog
        # reaches the threshold size of 3, pop the first (lost) entry.
        elif len(ctx.sorted_transcripts) >= 3:
            del ctx.sorted_transcripts[least_time]
    except Exception:
        logger.exception("Exception while sending transcript")


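Why a SortedDict: iteration follows key order, so next(iter(...)) always yields the earliest pending batch, regardless of the order in which transcriptions complete. A tiny illustration:

from sortedcontainers import SortedDict

pending = SortedDict()
pending[2.0] = "second batch"
pending[0.5] = "first batch"
assert next(iter(pending)) == 0.5
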
def get_transcription(
    ctx: TranscriptionContext, input_frames: TranscriptionInput
) -> Union[None, TranscriptionOutput]:
    """
    From the collected audio frames, create a transcription by inferring
    with the chosen transcription model.

    :param ctx:
    :param input_frames:
    :return:
    """
    ctx.logger.info("Transcribing..")
    ctx.sorted_transcripts[input_frames.frames[0].time] = None

    # TODO: Find a cleaner way, watch the "no transcription" issue below.
    # Passing IO objects instead of temporary files throws an error.
    # Passing an ndarray (type cast to float) does not give any
    # transcription. Refer to the issue:
    # https://github.com/guillaumekln/faster-whisper/issues/369
    audio_file = "test" + str(datetime.datetime.now())
    wf = wave.open(audio_file, "wb")
    wf.setnchannels(settings.AUDIO_CHANNELS)
    wf.setframerate(settings.AUDIO_SAMPLING_RATE)
    wf.setsampwidth(settings.AUDIO_SAMPLING_WIDTH)

    for frame in input_frames.frames:
        wf.writeframes(b"".join(frame.to_ndarray()))
    wf.close()

    result_text = ""

    try:
        segments, _ = model.transcribe(
            audio_file,
            language="en",
            beam_size=5,
            vad_filter=True,
            vad_parameters={"min_silence_duration_ms": 500},
        )
        os.remove(audio_file)
        segments = list(segments)
        result_text = ""
        duration = 0.0
        for segment in segments:
            result_text += segment.text
            start_time = segment.start
            end_time = segment.end
            if not segment.start:
                start_time = 0.0
            if not segment.end:
                end_time = 5.5
            duration += end_time - start_time

        ctx.last_transcribed_time += duration
        ctx.transcription_text += result_text

    except Exception:
        logger.exception("Exception while transcribing")

    result = TranscriptionOutput(result_text)
    ctx.sorted_transcripts[input_frames.frames[0].time] = result
    return result


def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
    """
    Collate the incremental summaries generated so far and return them as
    the final summary.

    :param ctx:
    :return:
    """
    final_summary = ""

    # Collate incremental summaries
    for topic in ctx.incremental_responses:
        final_summary += topic["description"]

    response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)

    with open(
        "./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8"
    ) as file:
        file.write(json.dumps(ctx.incremental_responses))

    return response


class AudioStreamTrack(MediaStreamTrack):
    """
    An audio stream track.
    """

    kind = "audio"

    def __init__(self, ctx: TranscriptionContext, track):
        super().__init__()
        self.ctx = ctx
        self.track = track
        self.audio_buffer = av.AudioFifo()

    async def recv(self) -> av.audio.frame.AudioFrame:
        ctx = self.ctx
        frame = await self.track.recv()
        self.audio_buffer.write(frame)

        if local_frames := self.audio_buffer.read_many(
            settings.AUDIO_BUFFER_SIZE, partial=False
        ):
            whisper_result = run_in_executor(
                get_transcription,
                ctx,
                TranscriptionInput(local_frames),
                executor=executor,
            )
            whisper_result.add_done_callback(
                lambda f: channel_send_transcript(ctx) if f.result() else None
            )

            if len(ctx.transcription_text) > 25:
                llm_input_text = ctx.transcription_text
                ctx.transcription_text = ""
                param = TitleSummaryInput(
                    input_text=llm_input_text,
                    transcribed_time=ctx.last_transcribed_time,
                )
                llm_result = run_in_executor(
                    get_title_and_summary, ctx, param, executor=executor
                )
                llm_result.add_done_callback(
                    lambda f: channel_send_increment(ctx.data_channel, f.result())
                    if f.result()
                    else None
                )
        return frame


async def offer(request: web.Request) -> web.Response:
    """
    Establish the WebRTC connection with the client.

    :param request:
    :return:
    """
    params = await request.json()
    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    # client identification
    peername = request.transport.get_extra_info("peername")
    if peername is not None:
        clientid = f"{peername[0]}:{peername[1]}"
    else:
        clientid = uuid.uuid4()

    # create a context for the whole rtc transaction
    # add a customised logger to the context
    ctx = TranscriptionContext(logger=logger.bind(client=clientid))

    # handle RTC peer connection
    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("datachannel")
    def on_datachannel(channel) -> None:
        ctx.data_channel = channel
        ctx.logger = ctx.logger.bind(channel=channel.label)
        ctx.logger.info("Channel created by remote party")

        @channel.on("message")
        def on_message(message: str) -> None:
            ctx.logger.info(f"Message: {message}")
            # answer pings before attempting to parse JSON
            if isinstance(message, str) and message.startswith("ping"):
                channel_send(channel, "pong" + message[4:])
                return

            if json.loads(message)["cmd"] == "STOP":
                # Placeholder final summary
                response = get_final_summary_response(ctx)
                channel_send_increment(channel, response)
                # TODO: add code to stop the connection from the server side;
                # it needs one more handshake with the client first.

    @pc.on("connectionstatechange")
    async def on_connectionstatechange() -> None:
        ctx.logger.info(f"Connection state changed: {pc.connectionState}")
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    @pc.on("track")
    def on_track(track) -> None:
        ctx.logger.info(f"Track {track.kind} received")
        pc.addTrack(AudioStreamTrack(ctx, relay.subscribe(track)))

    await pc.setRemoteDescription(offer)

    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)
    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
        ),
    )


async def on_shutdown(application: web.Application) -> None:
    """
    On shutdown, run the coroutines that close the client connections.

    :param application:
    :return:
    """
    coroutines = [pc.close() for pc in pcs]
    await asyncio.gather(*coroutines)
    pcs.clear()


def create_app() -> web.Application:
    """
    Create the web application.
    """
    app = web.Application()
    cors = aiohttp_cors.setup(
        app,
        defaults={
            "*": aiohttp_cors.ResourceOptions(
                allow_credentials=True, expose_headers="*", allow_headers="*"
            )
        },
    )

    offer_resource = cors.add(app.router.add_resource("/offer"))
    cors.add(offer_resource.add_route("POST", offer))
    app.on_shutdown.append(on_shutdown)
    return app


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="WebRTC based server for Reflector")
    parser.add_argument(
        "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
    )
    parser.add_argument(
        "--port", type=int, default=1250, help="Server port (def: 1250)"
    )
    args = parser.parse_args()
    app = create_app()
    web.run_app(app, access_log=None, host=args.host, port=args.port)
@@ -9,6 +9,9 @@ class Settings(BaseSettings):
    # Database
    DATABASE_URL: str = "sqlite:///./reflector.sqlite3"

    # Local data directory (audio for now)
    DATA_DIR: str = "./data"

    # Whisper
    WHISPER_MODEL_SIZE: str = "tiny"
    WHISPER_REAL_TIME_MODEL_SIZE: str = "tiny"
@@ -2,17 +2,18 @@ import asyncio
from fastapi import Request, APIRouter
from reflector.events import subscribers_shutdown
from pydantic import BaseModel
from reflector.models import TranscriptionContext
from reflector.logger import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
from json import loads, dumps
from enum import StrEnum
from pathlib import Path
import av
from reflector.processors import (
    Pipeline,
    AudioChunkerProcessor,
    AudioMergeProcessor,
    AudioTranscriptAutoProcessor,
    AudioFileWriterProcessor,
    TranscriptLinerProcessor,
    TranscriptTopicDetectorProcessor,
    TranscriptFinalSummaryProcessor,
@@ -25,6 +26,15 @@ sessions = []
router = APIRouter()


class TranscriptionContext(object):
    def __init__(self, logger):
        self.logger = logger
        self.pipeline = None
        self.data_channel = None
        self.status = "idle"
        self.topics = []


class AudioStreamTrack(MediaStreamTrack):
    """
    An audio stream track.
@@ -64,7 +74,11 @@ class PipelineEvent(StrEnum):
async def rtc_offer_base(
    params: RtcOffer, request: Request, event_callback=None, event_callback_args=None
    params: RtcOffer,
    request: Request,
    event_callback=None,
    event_callback_args=None,
    audio_filename: Path | None = None,
):
    # build an rtc session
    offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -73,7 +87,6 @@ async def rtc_offer_base(
    peername = request.client
    clientid = f"{peername[0]}:{peername[1]}"
    ctx = TranscriptionContext(logger=logger.bind(client=clientid))
    ctx.topics = []

    async def update_status(status: str):
        changed = ctx.status != status
@@ -151,14 +164,18 @@ async def rtc_offer_base(
    # create a context for the whole rtc transaction
    # add a customised logger to the context
    ctx.pipeline = Pipeline(
    processors = []
    if audio_filename is not None:
        processors += [AudioFileWriterProcessor(path=audio_filename)]
    processors += [
        AudioChunkerProcessor(),
        AudioMergeProcessor(),
        AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
        TranscriptLinerProcessor(),
        TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
        TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
    )
    ]
    ctx.pipeline = Pipeline(*processors)
    # FIXME: warmup is not working well yet
    # await ctx.pipeline.warmup()
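The net effect of the change above, sketched for the case where an audio_filename is given (names taken from the diff; warmup stays disabled):

# hypothetical expansion of the processors list built above
ctx.pipeline = Pipeline(
    AudioFileWriterProcessor(path=audio_filename),   # only when recording to disk
    AudioChunkerProcessor(),
    AudioMergeProcessor(),
    AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
    TranscriptLinerProcessor(),
    TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
    TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
)
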
@@ -5,14 +5,20 @@ from fastapi import (
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import FileResponse
from starlette.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from uuid import uuid4
from datetime import datetime
from fastapi_pagination import Page, paginate
from reflector.logger import logger
from reflector.db import database, transcripts
from reflector.settings import settings
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
from typing import Optional
from pathlib import Path
from tempfile import NamedTemporaryFile
import av


router = APIRouter()
@@ -81,6 +87,44 @@ class Transcript(BaseModel):
    def topics_dump(self, mode="json"):
        return [topic.model_dump(mode=mode) for topic in self.topics]

    def convert_audio_to_mp3(self):
        fn = self.audio_mp3_filename
        if fn.exists():
            return

        logger.info(f"Converting audio to mp3: {self.audio_filename}")
        inp = av.open(self.audio_filename.as_posix(), "r")

        # create a temporary file for the mp3
        with NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            out = av.open(tmp.name, "w")
            stream = out.add_stream("mp3")
            for frame in inp.decode(audio=0):
                frame.pts = None
                for packet in stream.encode(frame):
                    out.mux(packet)
            for packet in stream.encode(None):
                out.mux(packet)
            out.close()

        # move the temporary file to its final location
        Path(tmp.name).rename(fn)

    def unlink(self):
        self.data_path.unlink(missing_ok=True)

    @property
    def data_path(self):
        return Path(settings.DATA_DIR) / self.id

    @property
    def audio_filename(self):
        return self.data_path / "audio.wav"

    @property
    def audio_mp3_filename(self):
        return self.data_path / "audio.mp3"


class TranscriptController:
    async def get_all(self) -> list[Transcript]:
@@ -112,6 +156,10 @@ class TranscriptController:
            setattr(transcript, key, value)

    async def remove_by_id(self, transcript_id: str) -> None:
        transcript = await self.get_by_id(transcript_id)
        if not transcript:
            return
        transcript.unlink()
        query = transcripts.delete().where(transcripts.c.id == transcript_id)
        await database.execute(query)
@@ -202,8 +250,24 @@ async def transcript_get_audio(transcript_id: str):
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    # TODO: Implement audio generation
    return HTTPException(status_code=500, detail="Not implemented")
    if not transcript.audio_filename.exists():
        raise HTTPException(status_code=404, detail="Audio not found")

    return FileResponse(transcript.audio_filename, media_type="audio/wav")


@router.get("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(transcript_id: str):
    transcript = await transcripts_controller.get_by_id(transcript_id)
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    if not transcript.audio_filename.exists():
        raise HTTPException(status_code=404, detail="Audio not found")

    await run_in_threadpool(transcript.convert_audio_to_mp3)

    return FileResponse(transcript.audio_mp3_filename, media_type="audio/mp3")


@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
@@ -371,4 +435,5 @@ async def transcript_record_webrtc(
        request,
        event_callback=handle_rtc_event,
        event_callback_args=transcript_id,
        audio_filename=transcript.audio_filename,
    )
@@ -1,63 +0,0 @@
import pytest
from unittest.mock import patch


@pytest.mark.asyncio
async def test_basic_rtc_server(aiohttp_server, event_loop):
    # goal is to start the server, send rtc audio to it,
    # and validate the events received
    import argparse
    import json
    from pathlib import Path
    from reflector.server import create_app
    from reflector.stream_client import StreamClient
    from reflector.models import TitleSummaryOutput
    from aiortc.contrib.signaling import add_signaling_arguments, create_signaling

    # customize settings to have a mock LLM server
    with patch("reflector.server.get_title_and_summary") as mock_llm:
        # any response from mock_llm will be the test topic
        mock_llm.return_value = TitleSummaryOutput(["topic_test"])

        # create the server
        app = create_app()
        server = await aiohttp_server(app)
        url = f"http://{server.host}:{server.port}/offer"

        # create signaling
        parser = argparse.ArgumentParser()
        add_signaling_arguments(parser)
        args = parser.parse_args(["-s", "tcp-socket"])
        signaling = create_signaling(args)

        # create the client
        path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
        client = StreamClient(signaling, url=url, play_from=path.as_posix())
        await client.start()

        # we just want the first transcription
        # and topic update messages
        marks = {
            "SHOW_TRANSCRIPTION": False,
            "UPDATE_TOPICS": False,
        }

        async for rawmsg in client.get_reader():
            msg = json.loads(rawmsg)
            cmd = msg["cmd"]
            if cmd == "SHOW_TRANSCRIPTION":
                assert "text" in msg
                assert "want to share my incredible experience" in msg["text"]
            elif cmd == "UPDATE_TOPICS":
                assert "topics" in msg
                assert "topic_test" in msg["topics"]
            marks[cmd] = True

            # break if we have all the events we need
            if all(marks.values()):
                break

        # stop the server
        await server.close()
        await client.stop()
@@ -71,11 +71,15 @@ async def dummy_llm():
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm):
    # goal: start the server, exchange RTC, receive websocket events
    # because of that, we need to start the server in a thread
    # to be able to connect with aiortc

    from reflector.settings import settings

    settings.DATA_DIR = Path(tmpdir)

    # start server
    host = "127.0.0.1"
    port = 1255
@@ -189,3 +193,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
    resp = await ac.get(f"/transcripts/{tid}")
    assert resp.status_code == 200
    assert resp.json()["status"] == "ended"

    # check that audio is available
    resp = await ac.get(f"/transcripts/{tid}/audio")
    assert resp.status_code == 200
    assert resp.headers["Content-Type"] == "audio/wav"

    # check that audio/mp3 is available
    resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
    assert resp.status_code == 200
    assert resp.headers["Content-Type"] == "audio/mp3"