mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Merge branch 'server-api' of https://github.com/Monadical-SAS/reflector into server-api
This commit is contained in:
36
server/poetry.lock
generated
36
server/poetry.lock
generated
@@ -2199,6 +2199,26 @@ files = [
|
||||
{file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stamina"
|
||||
version = "23.1.0"
|
||||
description = "Production-grade retries made easy."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "stamina-23.1.0-py3-none-any.whl", hash = "sha256:850de8c2c2469aabf42a4c02e7372eaa12c2eced78f2bfa34162b8676c2846e5"},
|
||||
{file = "stamina-23.1.0.tar.gz", hash = "sha256:b16ce3d52d658aa75db813fc6a6661b770abfea915f72cda48e325f2a7854786"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tenacity = "*"
|
||||
|
||||
[package.extras]
|
||||
dev = ["nox", "prometheus-client", "stamina[tests,typing]", "structlog", "tomli"]
|
||||
docs = ["furo", "myst-parser", "prometheus-client", "sphinx", "sphinx-notfound-page", "structlog"]
|
||||
tests = ["pytest", "pytest-asyncio"]
|
||||
typing = ["mypy (>=1.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "starlette"
|
||||
version = "0.27.0"
|
||||
@@ -2247,6 +2267,20 @@ files = [
|
||||
[package.dependencies]
|
||||
mpmath = ">=0.19"
|
||||
|
||||
[[package]]
|
||||
name = "tenacity"
|
||||
version = "8.2.2"
|
||||
description = "Retry code until it succeeds"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "tenacity-8.2.2-py3-none-any.whl", hash = "sha256:2f277afb21b851637e8f52e6a613ff08734c347dc19ade928e519d7d2d8569b0"},
|
||||
{file = "tenacity-8.2.2.tar.gz", hash = "sha256:43af037822bd0029025877f3b2d97cc4d7bb0c2991000a3d59d71517c5c969e0"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
doc = ["reno", "sphinx", "tornado (>=4.5)"]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.13.3"
|
||||
@@ -2732,4 +2766,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "0c81bfdc623dc7a55ac16a0948bfb5b2d9391abd32bad0e665b0251169c7f7de"
|
||||
content-hash = "75afc46634677cd9afdf2ae66b320a8eaaa36d360d0ba187e5974b90810df44f"
|
||||
|
||||
@@ -27,6 +27,7 @@ fastapi-pagination = "^0.12.6"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.7.0"
|
||||
stamina = "^23.1.0"
|
||||
|
||||
|
||||
[tool.poetry.group.client.dependencies]
|
||||
|
||||
@@ -199,6 +199,7 @@ class TranscriptionContext:
|
||||
sorted_transcripts: dict
|
||||
data_channel: None # FIXME
|
||||
logger: None
|
||||
status: str
|
||||
|
||||
def __init__(self, logger):
|
||||
self.transcription_text = ""
|
||||
@@ -206,4 +207,5 @@ class TranscriptionContext:
|
||||
self.incremental_responses = []
|
||||
self.data_channel = None
|
||||
self.sorted_transcripts = SortedDict()
|
||||
self.status = "idle"
|
||||
self.logger = logger
|
||||
|
||||
@@ -14,6 +14,7 @@ class Processor:
|
||||
if callback:
|
||||
self.on(callback)
|
||||
self.uid = uuid4().hex
|
||||
self.flushed = False
|
||||
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
||||
|
||||
def set_pipeline(self, pipeline: "Pipeline"):
|
||||
@@ -65,6 +66,7 @@ class Processor:
|
||||
"""
|
||||
# logger.debug(f"{self.__class__.__name__} push")
|
||||
try:
|
||||
self.flushed = False
|
||||
return await self._push(data)
|
||||
except Exception:
|
||||
self.logger.exception("Error in push")
|
||||
@@ -72,8 +74,12 @@ class Processor:
|
||||
async def flush(self):
|
||||
"""
|
||||
Flush data to this processor
|
||||
Works only one time, until another push is called
|
||||
"""
|
||||
if self.flushed:
|
||||
return
|
||||
# logger.debug(f"{self.__class__.__name__} flush")
|
||||
self.flushed = True
|
||||
return await self._flush()
|
||||
|
||||
def describe(self, level=0):
|
||||
|
||||
@@ -72,7 +72,7 @@ class StreamClient:
|
||||
async def on_connectionstatechange():
|
||||
self.logger.info(f"Connection state is {pc.connectionState}")
|
||||
if pc.connectionState == "failed":
|
||||
await pc.close()
|
||||
await self.stop()
|
||||
self.pcs.discard(pc)
|
||||
|
||||
@pc.on("track")
|
||||
@@ -87,7 +87,7 @@ class StreamClient:
|
||||
self.pc.addTrack(audio)
|
||||
self.track_audio = audio
|
||||
|
||||
channel = pc.createDataChannel("data-channel")
|
||||
self.channel = channel = pc.createDataChannel("data-channel")
|
||||
self.logger = self.logger.bind(channel=channel.label)
|
||||
self.logger.info("Created by local party")
|
||||
|
||||
|
||||
@@ -52,10 +52,15 @@ class RtcOffer(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class StrValue(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class PipelineEvent(StrEnum):
|
||||
TRANSCRIPT = "TRANSCRIPT"
|
||||
TOPIC = "TOPIC"
|
||||
FINAL_SUMMARY = "FINAL_SUMMARY"
|
||||
STATUS = "STATUS"
|
||||
|
||||
|
||||
async def rtc_offer_base(
|
||||
@@ -70,15 +75,30 @@ async def rtc_offer_base(
|
||||
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
||||
ctx.topics = []
|
||||
|
||||
async def update_status(status: str):
|
||||
changed = ctx.status != status
|
||||
if changed:
|
||||
ctx.status = status
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.STATUS,
|
||||
args=event_callback_args,
|
||||
data=StrValue(value=status),
|
||||
)
|
||||
|
||||
# build pipeline callback
|
||||
async def on_transcript(transcript: Transcript):
|
||||
ctx.logger.info("Transcript", transcript=transcript)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "SHOW_TRANSCRIPTION",
|
||||
"text": transcript.text,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.TRANSCRIPT,
|
||||
@@ -86,9 +106,7 @@ async def rtc_offer_base(
|
||||
data=transcript,
|
||||
)
|
||||
|
||||
async def on_topic(
|
||||
summary: TitleSummary, event_callback=None, event_callback_args=None
|
||||
):
|
||||
async def on_topic(summary: TitleSummary):
|
||||
# FIXME: make it incremental with the frontend, not send everything
|
||||
ctx.logger.info("Summary", summary=summary)
|
||||
ctx.topics.append(
|
||||
@@ -99,18 +117,23 @@ async def rtc_offer_base(
|
||||
"desc": summary.summary,
|
||||
}
|
||||
)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
|
||||
)
|
||||
|
||||
async def on_final_summary(
|
||||
summary: FinalSummary, event_callback=None, event_callback_args=None
|
||||
):
|
||||
async def on_final_summary(summary: FinalSummary):
|
||||
ctx.logger.info("FinalSummary", final_summary=summary)
|
||||
|
||||
# send to RTC
|
||||
if ctx.data_channel.readyState == "open":
|
||||
result = {
|
||||
"cmd": "DISPLAY_FINAL_SUMMARY",
|
||||
"summary": summary.summary,
|
||||
@@ -118,9 +141,12 @@ async def rtc_offer_base(
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
# send to callback (eg. websocket)
|
||||
if event_callback:
|
||||
await event_callback(
|
||||
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
|
||||
event=PipelineEvent.FINAL_SUMMARY,
|
||||
args=event_callback_args,
|
||||
data=summary,
|
||||
)
|
||||
|
||||
# create a context for the whole rtc transaction
|
||||
@@ -137,11 +163,13 @@ async def rtc_offer_base(
|
||||
# handle RTC peer connection
|
||||
pc = RTCPeerConnection()
|
||||
|
||||
async def flush_pipeline_and_quit():
|
||||
ctx.logger.info("Flushing pipeline")
|
||||
async def flush_pipeline_and_quit(close=True):
|
||||
await update_status("processing")
|
||||
await ctx.pipeline.flush()
|
||||
if close:
|
||||
ctx.logger.debug("Closing peer connection")
|
||||
await pc.close()
|
||||
await update_status("ended")
|
||||
|
||||
@pc.on("datachannel")
|
||||
def on_datachannel(channel):
|
||||
@@ -164,11 +192,14 @@ async def rtc_offer_base(
|
||||
ctx.logger.info(f"Connection state: {pc.connectionState}")
|
||||
if pc.connectionState == "failed":
|
||||
await pc.close()
|
||||
elif pc.connectionState == "closed":
|
||||
await flush_pipeline_and_quit(close=False)
|
||||
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
ctx.logger.info(f"Track {track.kind} received")
|
||||
pc.addTrack(AudioStreamTrack(ctx, track))
|
||||
asyncio.get_event_loop().create_task(update_status("recording"))
|
||||
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
|
||||
@@ -21,6 +21,10 @@ def generate_transcript_name():
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
|
||||
class TranscriptText(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class TranscriptTopic(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
title: str
|
||||
@@ -29,6 +33,10 @@ class TranscriptTopic(BaseModel):
|
||||
timestamp: float
|
||||
|
||||
|
||||
class TranscriptFinalSummary(BaseModel):
|
||||
summary: str
|
||||
|
||||
|
||||
class TranscriptEvent(BaseModel):
|
||||
event: str
|
||||
data: dict
|
||||
@@ -45,9 +53,10 @@ class Transcript(BaseModel):
|
||||
topics: list[TranscriptTopic] = []
|
||||
events: list[TranscriptEvent] = []
|
||||
|
||||
def add_event(self, event: str, data):
|
||||
self.events.append(TranscriptEvent(event=event, data=data))
|
||||
return {"event": event, "data": data}
|
||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||
self.events.append(ev)
|
||||
return ev
|
||||
|
||||
def upsert_topic(self, topic: TranscriptTopic):
|
||||
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
|
||||
@@ -219,7 +228,7 @@ async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket)
|
||||
|
||||
# on first connection, send all events
|
||||
for event in transcript.events:
|
||||
await websocket.send_json(event.model_dump())
|
||||
await websocket.send_json(event.model_dump(mode="json"))
|
||||
|
||||
# XXX if transcript is final (locked=True and status=ended)
|
||||
# XXX send a final event to the client and close the connection
|
||||
@@ -254,24 +263,33 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
||||
|
||||
# FIXME don't do copy
|
||||
if event == PipelineEvent.TRANSCRIPT:
|
||||
resp = transcript.add_event(event=event, data={
|
||||
"text": data.text,
|
||||
})
|
||||
resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
|
||||
|
||||
elif event == PipelineEvent.TOPIC:
|
||||
topic = TranscriptTopic(
|
||||
title=data.title,
|
||||
summary=data.summary,
|
||||
transcript=data.transcript,
|
||||
transcript=data.transcript.text,
|
||||
timestamp=data.timestamp,
|
||||
)
|
||||
resp = transcript.add_event(event=event, data=topic.model_dump())
|
||||
resp = transcript.add_event(event=event, data=topic)
|
||||
transcript.upsert_topic(topic)
|
||||
|
||||
elif event == PipelineEvent.FINAL_SUMMARY:
|
||||
final_summary = TranscriptFinalSummary(summary=data.summary)
|
||||
resp = transcript.add_event(event=event, data=final_summary)
|
||||
transcript.summary = final_summary
|
||||
|
||||
elif event == PipelineEvent.STATUS:
|
||||
resp = transcript.add_event(event=event, data=data)
|
||||
transcript.status = data.value
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown event: {event}")
|
||||
return
|
||||
|
||||
# transmit to websocket clients
|
||||
await ws_manager.send_json(transcript_id, resp)
|
||||
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
# FIXME try with locked session, RTC should not work
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
from httpx import AsyncClient
|
||||
|
||||
@@ -61,7 +62,7 @@ async def dummy_llm():
|
||||
|
||||
class TestLLM(LLM):
|
||||
async def _generate(self, prompt: str, **kwargs):
|
||||
return {"text": "LLM RESULT"}
|
||||
return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
|
||||
|
||||
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
|
||||
mock_llm.return_value = TestLLM()
|
||||
@@ -132,6 +133,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
|
||||
if timeout < 0:
|
||||
raise TimeoutError("Timeout while waiting for RTC to end")
|
||||
|
||||
# XXX aiortc is long to close the connection
|
||||
# instead of waiting a long time, we just send a STOP
|
||||
client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||
|
||||
# wait the processing to finish
|
||||
await asyncio.sleep(2)
|
||||
|
||||
await client.stop()
|
||||
|
||||
# wait the processing to finish
|
||||
@@ -141,10 +149,42 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
|
||||
websocket_task.cancel()
|
||||
|
||||
# check events
|
||||
print(events)
|
||||
assert len(events) > 0
|
||||
assert events[0]["event"] == "TRANSCRIPT"
|
||||
assert events[0]["data"]["text"] == "Hello world"
|
||||
from pprint import pprint
|
||||
|
||||
pprint(events)
|
||||
|
||||
# get events list
|
||||
eventnames = [e["event"] for e in events]
|
||||
|
||||
# check events
|
||||
assert "TRANSCRIPT" in eventnames
|
||||
ev = events[eventnames.index("TRANSCRIPT")]
|
||||
assert ev["data"]["text"] == "Hello world"
|
||||
|
||||
assert "TOPIC" in eventnames
|
||||
ev = events[eventnames.index("TOPIC")]
|
||||
assert ev["data"]["id"]
|
||||
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||
assert ev["data"]["transcript"].startswith("Hello world")
|
||||
assert ev["data"]["timestamp"] == 0.0
|
||||
|
||||
assert "FINAL_SUMMARY" in eventnames
|
||||
ev = events[eventnames.index("FINAL_SUMMARY")]
|
||||
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
assert statuses == ["recording", "processing", "ended"]
|
||||
|
||||
# ensure the last event received is ended
|
||||
assert events[-1]["event"] == "STATUS"
|
||||
assert events[-1]["data"]["value"] == "ended"
|
||||
|
||||
# stop server
|
||||
# server.stop()
|
||||
|
||||
# check that transcript status in model is updated
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ended"
|
||||
|
||||
@@ -97,7 +97,11 @@ export default function Recorder(props) {
|
||||
document.getElementById("play-btn").disabled = false;
|
||||
} else {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({
|
||||
audio: { deviceId },
|
||||
audio: {
|
||||
deviceId,
|
||||
noiseSuppression: false,
|
||||
echoCancellation: false,
|
||||
},
|
||||
});
|
||||
await record.startRecording(stream);
|
||||
props.setStream(stream);
|
||||
|
||||
Reference in New Issue
Block a user