Merge branch 'server-api' of https://github.com/Monadical-SAS/reflector into server-api

This commit is contained in:
Koper
2023-08-09 19:12:27 +07:00
9 changed files with 178 additions and 42 deletions

36
server/poetry.lock generated
View File

@@ -2199,6 +2199,26 @@ files = [
{file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"},
]
[[package]]
name = "stamina"
version = "23.1.0"
description = "Production-grade retries made easy."
optional = false
python-versions = ">=3.8"
files = [
{file = "stamina-23.1.0-py3-none-any.whl", hash = "sha256:850de8c2c2469aabf42a4c02e7372eaa12c2eced78f2bfa34162b8676c2846e5"},
{file = "stamina-23.1.0.tar.gz", hash = "sha256:b16ce3d52d658aa75db813fc6a6661b770abfea915f72cda48e325f2a7854786"},
]
[package.dependencies]
tenacity = "*"
[package.extras]
dev = ["nox", "prometheus-client", "stamina[tests,typing]", "structlog", "tomli"]
docs = ["furo", "myst-parser", "prometheus-client", "sphinx", "sphinx-notfound-page", "structlog"]
tests = ["pytest", "pytest-asyncio"]
typing = ["mypy (>=1.4)"]
[[package]]
name = "starlette"
version = "0.27.0"
@@ -2247,6 +2267,20 @@ files = [
[package.dependencies]
mpmath = ">=0.19"
[[package]]
name = "tenacity"
version = "8.2.2"
description = "Retry code until it succeeds"
optional = false
python-versions = ">=3.6"
files = [
{file = "tenacity-8.2.2-py3-none-any.whl", hash = "sha256:2f277afb21b851637e8f52e6a613ff08734c347dc19ade928e519d7d2d8569b0"},
{file = "tenacity-8.2.2.tar.gz", hash = "sha256:43af037822bd0029025877f3b2d97cc4d7bb0c2991000a3d59d71517c5c969e0"},
]
[package.extras]
doc = ["reno", "sphinx", "tornado (>=4.5)"]
[[package]]
name = "tokenizers"
version = "0.13.3"
@@ -2732,4 +2766,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "0c81bfdc623dc7a55ac16a0948bfb5b2d9391abd32bad0e665b0251169c7f7de"
content-hash = "75afc46634677cd9afdf2ae66b320a8eaaa36d360d0ba187e5974b90810df44f"

View File

@@ -27,6 +27,7 @@ fastapi-pagination = "^0.12.6"
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
stamina = "^23.1.0"
[tool.poetry.group.client.dependencies]

View File

@@ -199,6 +199,7 @@ class TranscriptionContext:
sorted_transcripts: dict
data_channel: None # FIXME
logger: None
status: str
def __init__(self, logger):
self.transcription_text = ""
@@ -206,4 +207,5 @@ class TranscriptionContext:
self.incremental_responses = []
self.data_channel = None
self.sorted_transcripts = SortedDict()
self.status = "idle"
self.logger = logger

View File

@@ -14,6 +14,7 @@ class Processor:
if callback:
self.on(callback)
self.uid = uuid4().hex
self.flushed = False
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
def set_pipeline(self, pipeline: "Pipeline"):
@@ -65,6 +66,7 @@ class Processor:
"""
# logger.debug(f"{self.__class__.__name__} push")
try:
self.flushed = False
return await self._push(data)
except Exception:
self.logger.exception("Error in push")
@@ -72,8 +74,12 @@ class Processor:
async def flush(self):
    """
    Flush data to this processor

    Works only one time, until another push is called
    """
    # Idempotence guard: push() resets self.flushed to False, so a flush
    # only runs once per batch of pushed data.
    if self.flushed:
        return
    # logger.debug(f"{self.__class__.__name__} flush")
    # Set the flag before delegating so a re-entrant flush during
    # _flush() is also a no-op.
    self.flushed = True
    return await self._flush()
def describe(self, level=0):

View File

@@ -72,7 +72,7 @@ class StreamClient:
async def on_connectionstatechange():
self.logger.info(f"Connection state is {pc.connectionState}")
if pc.connectionState == "failed":
await pc.close()
await self.stop()
self.pcs.discard(pc)
@pc.on("track")
@@ -87,7 +87,7 @@ class StreamClient:
self.pc.addTrack(audio)
self.track_audio = audio
channel = pc.createDataChannel("data-channel")
self.channel = channel = pc.createDataChannel("data-channel")
self.logger = self.logger.bind(channel=channel.label)
self.logger.info("Created by local party")

View File

@@ -52,10 +52,15 @@ class RtcOffer(BaseModel):
type: str
class StrValue(BaseModel):
    # Minimal payload wrapper for events carrying a single string
    # (eg. STATUS updates send StrValue(value=<status>)).
    value: str
class PipelineEvent(StrEnum):
    # Event kinds emitted by the RTC transcription pipeline to the
    # optional event_callback (eg. forwarded over a websocket).
    TRANSCRIPT = "TRANSCRIPT"
    TOPIC = "TOPIC"
    FINAL_SUMMARY = "FINAL_SUMMARY"
    STATUS = "STATUS"
async def rtc_offer_base(
@@ -70,15 +75,30 @@ async def rtc_offer_base(
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
ctx.topics = []
async def update_status(status: str):
    # Record the pipeline status on the transcription context and notify
    # the optional event callback — but only when the status actually
    # changed, so repeated calls with the same status stay silent.
    # NOTE(review): closes over ctx / event_callback / event_callback_args
    # from the enclosing rtc_offer_base scope.
    changed = ctx.status != status
    if changed:
        ctx.status = status
        if event_callback:
            await event_callback(
                event=PipelineEvent.STATUS,
                args=event_callback_args,
                data=StrValue(value=status),
            )
# build pipeline callback
async def on_transcript(transcript: Transcript):
ctx.logger.info("Transcript", transcript=transcript)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TRANSCRIPT,
@@ -86,9 +106,7 @@ async def rtc_offer_base(
data=transcript,
)
async def on_topic(
summary: TitleSummary, event_callback=None, event_callback_args=None
):
async def on_topic(summary: TitleSummary):
# FIXME: make it incremental with the frontend, not send everything
ctx.logger.info("Summary", summary=summary)
ctx.topics.append(
@@ -99,18 +117,23 @@ async def rtc_offer_base(
"desc": summary.summary,
}
)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
)
async def on_final_summary(
summary: FinalSummary, event_callback=None, event_callback_args=None
):
async def on_final_summary(summary: FinalSummary):
ctx.logger.info("FinalSummary", final_summary=summary)
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_SUMMARY",
"summary": summary.summary,
@@ -118,9 +141,12 @@ async def rtc_offer_base(
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
event=PipelineEvent.FINAL_SUMMARY,
args=event_callback_args,
data=summary,
)
# create a context for the whole rtc transaction
@@ -137,11 +163,13 @@ async def rtc_offer_base(
# handle RTC peer connection
pc = RTCPeerConnection()
async def flush_pipeline_and_quit():
ctx.logger.info("Flushing pipeline")
async def flush_pipeline_and_quit(close=True):
    # Drain the processing pipeline, surfacing the lifecycle to clients:
    # "processing" while flushing, "ended" once done.
    # close=False is used when the peer connection is already closed
    # (see the connectionstatechange handler) to avoid a double close.
    await update_status("processing")
    await ctx.pipeline.flush()
    if close:
        ctx.logger.debug("Closing peer connection")
        await pc.close()
    await update_status("ended")
@pc.on("datachannel")
def on_datachannel(channel):
@@ -164,11 +192,14 @@ async def rtc_offer_base(
ctx.logger.info(f"Connection state: {pc.connectionState}")
if pc.connectionState == "failed":
await pc.close()
elif pc.connectionState == "closed":
await flush_pipeline_and_quit(close=False)
@pc.on("track")
def on_track(track):
ctx.logger.info(f"Track {track.kind} received")
pc.addTrack(AudioStreamTrack(ctx, track))
asyncio.get_event_loop().create_task(update_status("recording"))
await pc.setRemoteDescription(offer)

View File

@@ -21,6 +21,10 @@ def generate_transcript_name():
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class TranscriptText(BaseModel):
    # Payload for TRANSCRIPT events: a fragment of recognized speech text.
    text: str
class TranscriptTopic(BaseModel):
id: UUID = Field(default_factory=uuid4)
title: str
@@ -29,6 +33,10 @@ class TranscriptTopic(BaseModel):
timestamp: float
class TranscriptFinalSummary(BaseModel):
    # Payload for FINAL_SUMMARY events; also stored on Transcript.summary.
    summary: str
class TranscriptEvent(BaseModel):
    # One entry of a transcript's event log: the event name plus its
    # payload serialized to a plain dict (see Transcript.add_event).
    event: str
    data: dict
@@ -45,9 +53,10 @@ class Transcript(BaseModel):
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
def add_event(self, event: str, data):
self.events.append(TranscriptEvent(event=event, data=data))
return {"event": event, "data": data}
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
    # Serialize the pydantic payload to a plain dict so stored events are
    # JSON-friendly, append it to this transcript's event log, and return
    # the stored event so callers can forward it (eg. to websockets).
    ev = TranscriptEvent(event=event, data=data.model_dump())
    self.events.append(ev)
    return ev
def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None)
@@ -219,7 +228,7 @@ async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket)
# on first connection, send all events
for event in transcript.events:
await websocket.send_json(event.model_dump())
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
@@ -254,24 +263,33 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
# FIXME don't do copy
if event == PipelineEvent.TRANSCRIPT:
resp = transcript.add_event(event=event, data={
"text": data.text,
})
resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
elif event == PipelineEvent.TOPIC:
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
transcript=data.transcript,
transcript=data.transcript.text,
timestamp=data.timestamp,
)
resp = transcript.add_event(event=event, data=topic.model_dump())
resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic)
elif event == PipelineEvent.FINAL_SUMMARY:
final_summary = TranscriptFinalSummary(summary=data.summary)
resp = transcript.add_event(event=event, data=final_summary)
transcript.summary = final_summary
elif event == PipelineEvent.STATUS:
resp = transcript.add_event(event=event, data=data)
transcript.status = data.value
else:
logger.warning(f"Unknown event: {event}")
return
# transmit to websocket clients
await ws_manager.send_json(transcript_id, resp)
await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
@router.post("/transcripts/{transcript_id}/record/webrtc")

View File

@@ -4,6 +4,7 @@
# FIXME try with locked session, RTC should not work
import pytest
import json
from unittest.mock import patch
from httpx import AsyncClient
@@ -61,7 +62,7 @@ async def dummy_llm():
class TestLLM(LLM):
async def _generate(self, prompt: str, **kwargs):
return {"text": "LLM RESULT"}
return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
mock_llm.return_value = TestLLM()
@@ -132,6 +133,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
if timeout < 0:
raise TimeoutError("Timeout while waiting for RTC to end")
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2)
await client.stop()
# wait the processing to finish
@@ -141,10 +149,42 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
websocket_task.cancel()
# check events
print(events)
assert len(events) > 0
assert events[0]["event"] == "TRANSCRIPT"
assert events[0]["data"]["text"] == "Hello world"
from pprint import pprint
pprint(events)
# get events list
eventnames = [e["event"] for e in events]
# check events
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world"
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SUMMARY")]
assert ev["data"]["summary"] == "LLM SUMMARY"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# stop server
# server.stop()
# check that transcript status in model is updated
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"

View File

@@ -97,7 +97,11 @@ export default function Recorder(props) {
document.getElementById("play-btn").disabled = false;
} else {
const stream = await navigator.mediaDevices.getUserMedia({
audio: { deviceId },
audio: {
deviceId,
noiseSuppression: false,
echoCancellation: false,
},
});
await record.startRecording(stream);
props.setStream(stream);