server: refactor with diarization, logic works

This commit is contained in:
2023-10-27 15:59:27 +02:00
committed by Mathieu Virbel
parent 1c42473da0
commit 07c4d080c2
17 changed files with 387 additions and 169 deletions

View File

@@ -45,17 +45,16 @@ async def dummy_transcript():
from reflector.processors.types import AudioFile, Transcript, Word
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
async def _transcript(self, data: AudioFile):
source_language = self.get_pref("audio:source_language", "en")
print("transcripting", source_language)
print("pipeline", self.pipeline)
print("prefs", self.pipeline.prefs)
_time_idx = 0
async def _transcript(self, data: AudioFile):
i = self._time_idx
self._time_idx += 2
return Transcript(
text="Hello world.",
words=[
Word(start=0.0, end=1.0, text="Hello"),
Word(start=1.0, end=2.0, text=" world."),
Word(start=i, end=i + 1, text="Hello", speaker=0),
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
],
)
@@ -98,7 +97,17 @@ def ensure_casing():
@pytest.fixture
def sentence_tokenize():
with patch(
"reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize"
"reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize"
) as mock_sent_tokenize:
mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
yield
@pytest.fixture(scope="session")
def celery_enable_logging():
return True
@pytest.fixture(scope="session")
def celery_config():
return {"broker_url": "memory://", "result_backend": "rpc"}

View File

@@ -32,7 +32,7 @@ class ThreadedUvicorn:
@pytest.fixture
async def appserver(tmpdir):
async def appserver(tmpdir, celery_session_app, celery_session_worker):
from reflector.settings import settings
from reflector.app import app
@@ -52,6 +52,13 @@ async def appserver(tmpdir):
settings.DATA_DIR = DATA_DIR
@pytest.fixture(scope="session")
def celery_includes():
return ["reflector.pipelines.main_live_pipeline"]
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(
tmpdir,
@@ -121,14 +128,20 @@ async def test_transcript_rtc_and_websocket(
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2)
await client.stop()
# wait the processing to finish
await asyncio.sleep(2)
timeout = 20
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended")
# stop websocket task
websocket_task.cancel()
@@ -152,7 +165,7 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["text"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames
@@ -169,23 +182,21 @@ async def test_transcript_rtc_and_websocket(
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
assert statuses.index("recording") < statuses.index("processing")
assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# check that transcript status in model is updated
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mpeg"
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(
tmpdir,
@@ -265,6 +276,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
await client.stop()
# wait the processing to finish
timeout = 20
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] == "ended":
break
await asyncio.sleep(1)
if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended")
await asyncio.sleep(2)
# stop websocket task
@@ -289,7 +312,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["text"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames
@@ -306,7 +329,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
assert statuses.index("recording") < statuses.index("processing")
assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"