Merge branch 'server-api' of https://github.com/Monadical-SAS/reflector into server-api

This commit is contained in:
Koper
2023-08-09 19:12:27 +07:00
9 changed files with 178 additions and 42 deletions

36
server/poetry.lock generated
View File

@@ -2199,6 +2199,26 @@ files = [
{file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"},
] ]
[[package]]
name = "stamina"
version = "23.1.0"
description = "Production-grade retries made easy."
optional = false
python-versions = ">=3.8"
files = [
{file = "stamina-23.1.0-py3-none-any.whl", hash = "sha256:850de8c2c2469aabf42a4c02e7372eaa12c2eced78f2bfa34162b8676c2846e5"},
{file = "stamina-23.1.0.tar.gz", hash = "sha256:b16ce3d52d658aa75db813fc6a6661b770abfea915f72cda48e325f2a7854786"},
]
[package.dependencies]
tenacity = "*"
[package.extras]
dev = ["nox", "prometheus-client", "stamina[tests,typing]", "structlog", "tomli"]
docs = ["furo", "myst-parser", "prometheus-client", "sphinx", "sphinx-notfound-page", "structlog"]
tests = ["pytest", "pytest-asyncio"]
typing = ["mypy (>=1.4)"]
[[package]] [[package]]
name = "starlette" name = "starlette"
version = "0.27.0" version = "0.27.0"
@@ -2247,6 +2267,20 @@ files = [
[package.dependencies] [package.dependencies]
mpmath = ">=0.19" mpmath = ">=0.19"
[[package]]
name = "tenacity"
version = "8.2.2"
description = "Retry code until it succeeds"
optional = false
python-versions = ">=3.6"
files = [
{file = "tenacity-8.2.2-py3-none-any.whl", hash = "sha256:2f277afb21b851637e8f52e6a613ff08734c347dc19ade928e519d7d2d8569b0"},
{file = "tenacity-8.2.2.tar.gz", hash = "sha256:43af037822bd0029025877f3b2d97cc4d7bb0c2991000a3d59d71517c5c969e0"},
]
[package.extras]
doc = ["reno", "sphinx", "tornado (>=4.5)"]
[[package]] [[package]]
name = "tokenizers" name = "tokenizers"
version = "0.13.3" version = "0.13.3"
@@ -2732,4 +2766,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "0c81bfdc623dc7a55ac16a0948bfb5b2d9391abd32bad0e665b0251169c7f7de" content-hash = "75afc46634677cd9afdf2ae66b320a8eaaa36d360d0ba187e5974b90810df44f"

View File

@@ -27,6 +27,7 @@ fastapi-pagination = "^0.12.6"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
black = "^23.7.0" black = "^23.7.0"
stamina = "^23.1.0"
[tool.poetry.group.client.dependencies] [tool.poetry.group.client.dependencies]

View File

@@ -199,6 +199,7 @@ class TranscriptionContext:
sorted_transcripts: dict sorted_transcripts: dict
data_channel: None # FIXME data_channel: None # FIXME
logger: None logger: None
status: str
def __init__(self, logger): def __init__(self, logger):
self.transcription_text = "" self.transcription_text = ""
@@ -206,4 +207,5 @@ class TranscriptionContext:
self.incremental_responses = [] self.incremental_responses = []
self.data_channel = None self.data_channel = None
self.sorted_transcripts = SortedDict() self.sorted_transcripts = SortedDict()
self.status = "idle"
self.logger = logger self.logger = logger

View File

@@ -14,6 +14,7 @@ class Processor:
if callback: if callback:
self.on(callback) self.on(callback)
self.uid = uuid4().hex self.uid = uuid4().hex
self.flushed = False
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__) self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
def set_pipeline(self, pipeline: "Pipeline"): def set_pipeline(self, pipeline: "Pipeline"):
@@ -65,6 +66,7 @@ class Processor:
""" """
# logger.debug(f"{self.__class__.__name__} push") # logger.debug(f"{self.__class__.__name__} push")
try: try:
self.flushed = False
return await self._push(data) return await self._push(data)
except Exception: except Exception:
self.logger.exception("Error in push") self.logger.exception("Error in push")
@@ -72,8 +74,12 @@ class Processor:
async def flush(self): async def flush(self):
""" """
Flush data to this processor Flush data to this processor
Works only one time, until another push is called
""" """
if self.flushed:
return
# logger.debug(f"{self.__class__.__name__} flush") # logger.debug(f"{self.__class__.__name__} flush")
self.flushed = True
return await self._flush() return await self._flush()
def describe(self, level=0): def describe(self, level=0):

View File

@@ -72,7 +72,7 @@ class StreamClient:
async def on_connectionstatechange(): async def on_connectionstatechange():
self.logger.info(f"Connection state is {pc.connectionState}") self.logger.info(f"Connection state is {pc.connectionState}")
if pc.connectionState == "failed": if pc.connectionState == "failed":
await pc.close() await self.stop()
self.pcs.discard(pc) self.pcs.discard(pc)
@pc.on("track") @pc.on("track")
@@ -87,7 +87,7 @@ class StreamClient:
self.pc.addTrack(audio) self.pc.addTrack(audio)
self.track_audio = audio self.track_audio = audio
channel = pc.createDataChannel("data-channel") self.channel = channel = pc.createDataChannel("data-channel")
self.logger = self.logger.bind(channel=channel.label) self.logger = self.logger.bind(channel=channel.label)
self.logger.info("Created by local party") self.logger.info("Created by local party")

View File

@@ -52,10 +52,15 @@ class RtcOffer(BaseModel):
type: str type: str
class StrValue(BaseModel):
value: str
class PipelineEvent(StrEnum): class PipelineEvent(StrEnum):
TRANSCRIPT = "TRANSCRIPT" TRANSCRIPT = "TRANSCRIPT"
TOPIC = "TOPIC" TOPIC = "TOPIC"
FINAL_SUMMARY = "FINAL_SUMMARY" FINAL_SUMMARY = "FINAL_SUMMARY"
STATUS = "STATUS"
async def rtc_offer_base( async def rtc_offer_base(
@@ -70,15 +75,30 @@ async def rtc_offer_base(
ctx = TranscriptionContext(logger=logger.bind(client=clientid)) ctx = TranscriptionContext(logger=logger.bind(client=clientid))
ctx.topics = [] ctx.topics = []
async def update_status(status: str):
changed = ctx.status != status
if changed:
ctx.status = status
if event_callback:
await event_callback(
event=PipelineEvent.STATUS,
args=event_callback_args,
data=StrValue(value=status),
)
# build pipeline callback # build pipeline callback
async def on_transcript(transcript: Transcript): async def on_transcript(transcript: Transcript):
ctx.logger.info("Transcript", transcript=transcript) ctx.logger.info("Transcript", transcript=transcript)
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback: if event_callback:
await event_callback( await event_callback(
event=PipelineEvent.TRANSCRIPT, event=PipelineEvent.TRANSCRIPT,
@@ -86,9 +106,7 @@ async def rtc_offer_base(
data=transcript, data=transcript,
) )
async def on_topic( async def on_topic(summary: TitleSummary):
summary: TitleSummary, event_callback=None, event_callback_args=None
):
# FIXME: make it incremental with the frontend, not send everything # FIXME: make it incremental with the frontend, not send everything
ctx.logger.info("Summary", summary=summary) ctx.logger.info("Summary", summary=summary)
ctx.topics.append( ctx.topics.append(
@@ -99,28 +117,36 @@ async def rtc_offer_base(
"desc": summary.summary, "desc": summary.summary,
} }
) )
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result))
# send to RTC
if ctx.data_channel.readyState == "open":
result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback: if event_callback:
await event_callback( await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
) )
async def on_final_summary( async def on_final_summary(summary: FinalSummary):
summary: FinalSummary, event_callback=None, event_callback_args=None
):
ctx.logger.info("FinalSummary", final_summary=summary) ctx.logger.info("FinalSummary", final_summary=summary)
result = {
"cmd": "DISPLAY_FINAL_SUMMARY",
"summary": summary.summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to RTC
if ctx.data_channel.readyState == "open":
result = {
"cmd": "DISPLAY_FINAL_SUMMARY",
"summary": summary.summary,
"duration": summary.duration,
}
ctx.data_channel.send(dumps(result))
# send to callback (eg. websocket)
if event_callback: if event_callback:
await event_callback( await event_callback(
event=PipelineEvent.TOPIC, args=event_callback_args, data=summary event=PipelineEvent.FINAL_SUMMARY,
args=event_callback_args,
data=summary,
) )
# create a context for the whole rtc transaction # create a context for the whole rtc transaction
@@ -137,11 +163,13 @@ async def rtc_offer_base(
# handle RTC peer connection # handle RTC peer connection
pc = RTCPeerConnection() pc = RTCPeerConnection()
async def flush_pipeline_and_quit(): async def flush_pipeline_and_quit(close=True):
ctx.logger.info("Flushing pipeline") await update_status("processing")
await ctx.pipeline.flush() await ctx.pipeline.flush()
ctx.logger.debug("Closing peer connection") if close:
await pc.close() ctx.logger.debug("Closing peer connection")
await pc.close()
await update_status("ended")
@pc.on("datachannel") @pc.on("datachannel")
def on_datachannel(channel): def on_datachannel(channel):
@@ -164,11 +192,14 @@ async def rtc_offer_base(
ctx.logger.info(f"Connection state: {pc.connectionState}") ctx.logger.info(f"Connection state: {pc.connectionState}")
if pc.connectionState == "failed": if pc.connectionState == "failed":
await pc.close() await pc.close()
elif pc.connectionState == "closed":
await flush_pipeline_and_quit(close=False)
@pc.on("track") @pc.on("track")
def on_track(track): def on_track(track):
ctx.logger.info(f"Track {track.kind} received") ctx.logger.info(f"Track {track.kind} received")
pc.addTrack(AudioStreamTrack(ctx, track)) pc.addTrack(AudioStreamTrack(ctx, track))
asyncio.get_event_loop().create_task(update_status("recording"))
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)

View File

@@ -21,6 +21,10 @@ def generate_transcript_name():
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class TranscriptText(BaseModel):
text: str
class TranscriptTopic(BaseModel): class TranscriptTopic(BaseModel):
id: UUID = Field(default_factory=uuid4) id: UUID = Field(default_factory=uuid4)
title: str title: str
@@ -29,6 +33,10 @@ class TranscriptTopic(BaseModel):
timestamp: float timestamp: float
class TranscriptFinalSummary(BaseModel):
summary: str
class TranscriptEvent(BaseModel): class TranscriptEvent(BaseModel):
event: str event: str
data: dict data: dict
@@ -45,9 +53,10 @@ class Transcript(BaseModel):
topics: list[TranscriptTopic] = [] topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = [] events: list[TranscriptEvent] = []
def add_event(self, event: str, data): def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
self.events.append(TranscriptEvent(event=event, data=data)) ev = TranscriptEvent(event=event, data=data.model_dump())
return {"event": event, "data": data} self.events.append(ev)
return ev
def upsert_topic(self, topic: TranscriptTopic): def upsert_topic(self, topic: TranscriptTopic):
existing_topic = next((t for t in self.topics if t.id == topic.id), None) existing_topic = next((t for t in self.topics if t.id == topic.id), None)
@@ -219,7 +228,7 @@ async def transcript_events_websocket(transcript_id: UUID, websocket: WebSocket)
# on first connection, send all events # on first connection, send all events
for event in transcript.events: for event in transcript.events:
await websocket.send_json(event.model_dump()) await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended) # XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection # XXX send a final event to the client and close the connection
@@ -254,24 +263,33 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
# FIXME don't do copy # FIXME don't do copy
if event == PipelineEvent.TRANSCRIPT: if event == PipelineEvent.TRANSCRIPT:
resp = transcript.add_event(event=event, data={ resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
"text": data.text,
})
elif event == PipelineEvent.TOPIC: elif event == PipelineEvent.TOPIC:
topic = TranscriptTopic( topic = TranscriptTopic(
title=data.title, title=data.title,
summary=data.summary, summary=data.summary,
transcript=data.transcript, transcript=data.transcript.text,
timestamp=data.timestamp, timestamp=data.timestamp,
) )
resp = transcript.add_event(event=event, data=topic.model_dump()) resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic) transcript.upsert_topic(topic)
elif event == PipelineEvent.FINAL_SUMMARY:
final_summary = TranscriptFinalSummary(summary=data.summary)
resp = transcript.add_event(event=event, data=final_summary)
transcript.summary = final_summary
elif event == PipelineEvent.STATUS:
resp = transcript.add_event(event=event, data=data)
transcript.status = data.value
else: else:
logger.warning(f"Unknown event: {event}") logger.warning(f"Unknown event: {event}")
return return
# transmit to websocket clients # transmit to websocket clients
await ws_manager.send_json(transcript_id, resp) await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
@router.post("/transcripts/{transcript_id}/record/webrtc") @router.post("/transcripts/{transcript_id}/record/webrtc")

View File

@@ -4,6 +4,7 @@
# FIXME try with locked session, RTC should not work # FIXME try with locked session, RTC should not work
import pytest import pytest
import json
from unittest.mock import patch from unittest.mock import patch
from httpx import AsyncClient from httpx import AsyncClient
@@ -61,7 +62,7 @@ async def dummy_llm():
class TestLLM(LLM): class TestLLM(LLM):
async def _generate(self, prompt: str, **kwargs): async def _generate(self, prompt: str, **kwargs):
return {"text": "LLM RESULT"} return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
with patch("reflector.llm.base.LLM.get_instance") as mock_llm: with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
mock_llm.return_value = TestLLM() mock_llm.return_value = TestLLM()
@@ -132,6 +133,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
if timeout < 0: if timeout < 0:
raise TimeoutError("Timeout while waiting for RTC to end") raise TimeoutError("Timeout while waiting for RTC to end")
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2)
await client.stop() await client.stop()
# wait the processing to finish # wait the processing to finish
@@ -141,10 +149,42 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
websocket_task.cancel() websocket_task.cancel()
# check events # check events
print(events)
assert len(events) > 0 assert len(events) > 0
assert events[0]["event"] == "TRANSCRIPT" from pprint import pprint
assert events[0]["data"]["text"] == "Hello world"
pprint(events)
# get events list
eventnames = [e["event"] for e in events]
# check events
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world"
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SUMMARY")]
assert ev["data"]["summary"] == "LLM SUMMARY"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# stop server # stop server
# server.stop() # server.stop()
# check that transcript status in model is updated
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"

View File

@@ -97,7 +97,11 @@ export default function Recorder(props) {
document.getElementById("play-btn").disabled = false; document.getElementById("play-btn").disabled = false;
} else { } else {
const stream = await navigator.mediaDevices.getUserMedia({ const stream = await navigator.mediaDevices.getUserMedia({
audio: { deviceId }, audio: {
deviceId,
noiseSuppression: false,
echoCancellation: false,
},
}); });
await record.startRecording(stream); await record.startRecording(stream);
props.setStream(stream); props.setStream(stream);