mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: refactor with diarization, logic works
This commit is contained in:
16
server/poetry.lock
generated
16
server/poetry.lock
generated
@@ -2676,6 +2676,20 @@ pytest = ">=7.0.0"
|
|||||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||||
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
|
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest-celery"
|
||||||
|
version = "0.0.0"
|
||||||
|
description = "pytest-celery a shim pytest plugin to enable celery.contrib.pytest"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "pytest-celery-0.0.0.tar.gz", hash = "sha256:cfd060fc32676afa1e4f51b2938f903f7f75d952186b8c6cf631628c4088f406"},
|
||||||
|
{file = "pytest_celery-0.0.0-py2.py3-none-any.whl", hash = "sha256:63dec132df3a839226ecb003ffdbb0c2cb88dd328550957e979c942766578060"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
celery = ">=4.4.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest-cov"
|
name = "pytest-cov"
|
||||||
version = "4.1.0"
|
version = "4.1.0"
|
||||||
@@ -4064,4 +4078,4 @@ multidict = ">=4.0"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63"
|
content-hash = "07e42e7512fd5d51b656207a05092c53905c15e6a5ce548e015cdc05bd1baa7d"
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ pytest-asyncio = "^0.21.1"
|
|||||||
pytest = "^7.4.0"
|
pytest = "^7.4.0"
|
||||||
httpx-ws = "^0.4.1"
|
httpx-ws = "^0.4.1"
|
||||||
pytest-httpx = "^0.23.1"
|
pytest-httpx = "^0.23.1"
|
||||||
|
pytest-celery = "^0.0.0"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.aws.dependencies]
|
[tool.poetry.group.aws.dependencies]
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ class TranscriptTopic(BaseModel):
|
|||||||
title: str
|
title: str
|
||||||
summary: str
|
summary: str
|
||||||
timestamp: float
|
timestamp: float
|
||||||
|
duration: float | None = 0
|
||||||
text: str | None = None
|
text: str | None = None
|
||||||
words: list[ProcessorWord] = []
|
words: list[ProcessorWord] = []
|
||||||
|
|
||||||
@@ -264,7 +265,7 @@ class TranscriptController:
|
|||||||
"""
|
"""
|
||||||
A context manager for database transaction
|
A context manager for database transaction
|
||||||
"""
|
"""
|
||||||
async with database.transaction():
|
async with database.transaction(isolation="serializable"):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
async def append_event(
|
async def append_event(
|
||||||
@@ -280,5 +281,16 @@ class TranscriptController:
|
|||||||
await self.update(transcript, {"events": transcript.events_dump()})
|
await self.update(transcript, {"events": transcript.events_dump()})
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
async def upsert_topic(
    self,
    transcript: Transcript,
    topic: TranscriptTopic,
) -> None:
    """
    Upsert a topic on a transcript and persist the topic list

    The merge itself is delegated to `transcript.upsert_topic`; the result
    of `transcript.topics_dump()` is then written through `self.update`
    under the "topics" key. Nothing is returned.
    """
    transcript.upsert_topic(topic)
    await self.update(transcript, {"topics": transcript.topics_dump()})
|
||||||
|
|
||||||
|
|
||||||
transcripts_controller = TranscriptController()
|
transcripts_controller = TranscriptController()
|
||||||
|
|||||||
@@ -11,8 +11,12 @@ It is decoupled to:
|
|||||||
It is directly linked to our data model.
|
It is directly linked to our data model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from celery import shared_task
|
||||||
|
from pydantic import BaseModel
|
||||||
from reflector.db.transcripts import (
|
from reflector.db.transcripts import (
|
||||||
Transcript,
|
Transcript,
|
||||||
TranscriptFinalLongSummary,
|
TranscriptFinalLongSummary,
|
||||||
@@ -25,6 +29,7 @@ from reflector.db.transcripts import (
|
|||||||
from reflector.pipelines.runner import PipelineRunner
|
from reflector.pipelines.runner import PipelineRunner
|
||||||
from reflector.processors import (
|
from reflector.processors import (
|
||||||
AudioChunkerProcessor,
|
AudioChunkerProcessor,
|
||||||
|
AudioDiarizationProcessor,
|
||||||
AudioFileWriterProcessor,
|
AudioFileWriterProcessor,
|
||||||
AudioMergeProcessor,
|
AudioMergeProcessor,
|
||||||
AudioTranscriptAutoProcessor,
|
AudioTranscriptAutoProcessor,
|
||||||
@@ -37,11 +42,13 @@ from reflector.processors import (
|
|||||||
TranscriptTopicDetectorProcessor,
|
TranscriptTopicDetectorProcessor,
|
||||||
TranscriptTranslatorProcessor,
|
TranscriptTranslatorProcessor,
|
||||||
)
|
)
|
||||||
from reflector.tasks.worker import celery
|
from reflector.processors.types import AudioDiarizationInput
|
||||||
|
from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
|
||||||
|
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||||
|
|
||||||
|
|
||||||
def broadcast_to_socket(func):
|
def broadcast_to_sockets(func):
|
||||||
"""
|
"""
|
||||||
Decorator to broadcast transcript event to websockets
|
Decorator to broadcast transcript event to websockets
|
||||||
concerning this transcript
|
concerning this transcript
|
||||||
@@ -59,6 +66,10 @@ def broadcast_to_socket(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class StrValue(BaseModel):
|
||||||
|
value: str
|
||||||
|
|
||||||
|
|
||||||
class PipelineMainBase(PipelineRunner):
|
class PipelineMainBase(PipelineRunner):
|
||||||
transcript_id: str
|
transcript_id: str
|
||||||
ws_room_id: str | None = None
|
ws_room_id: str | None = None
|
||||||
@@ -66,6 +77,7 @@ class PipelineMainBase(PipelineRunner):
|
|||||||
|
|
||||||
def prepare(self):
|
def prepare(self):
|
||||||
# prepare websocket
|
# prepare websocket
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
self.ws_room_id = f"ts:{self.transcript_id}"
|
self.ws_room_id = f"ts:{self.transcript_id}"
|
||||||
self.ws_manager = get_ws_manager()
|
self.ws_manager = get_ws_manager()
|
||||||
|
|
||||||
@@ -78,15 +90,59 @@ class PipelineMainBase(PipelineRunner):
|
|||||||
raise Exception("Transcript not found")
|
raise Exception("Transcript not found")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@asynccontextmanager
async def transaction(self):
    """
    Context manager wrapping a controller transaction in the runner lock

    The runner's asyncio lock is acquired first, then
    `transcripts_controller.transaction()` is entered; both are released
    when the `async with` body of the caller finishes.
    """
    async with self._lock, transcripts_controller.transaction():
        yield
|
||||||
|
|
||||||
class PipelineMainLive(PipelineMainBase):
|
@broadcast_to_sockets
|
||||||
audio_filename: Path | None = None
|
async def on_status(self, status):
|
||||||
source_language: str = "en"
|
# if it's the first part, update the status of the transcript
|
||||||
target_language: str = "en"
|
# but do not set the ended status yet.
|
||||||
|
if isinstance(self, PipelineMainLive):
|
||||||
|
status_mapping = {
|
||||||
|
"started": "recording",
|
||||||
|
"push": "recording",
|
||||||
|
"flush": "processing",
|
||||||
|
"error": "error",
|
||||||
|
}
|
||||||
|
elif isinstance(self, PipelineMainDiarization):
|
||||||
|
status_mapping = {
|
||||||
|
"push": "processing",
|
||||||
|
"flush": "processing",
|
||||||
|
"error": "error",
|
||||||
|
"ended": "ended",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise Exception(f"Runner {self.__class__} is missing status mapping")
|
||||||
|
|
||||||
@broadcast_to_socket
|
# mutate to model status
|
||||||
|
status = status_mapping.get(status)
|
||||||
|
if not status:
|
||||||
|
return
|
||||||
|
|
||||||
|
# when the status of the pipeline changes, update the transcript
|
||||||
|
async with self.transaction():
|
||||||
|
transcript = await self.get_transcript()
|
||||||
|
if status == transcript.status:
|
||||||
|
return
|
||||||
|
resp = await transcripts_controller.append_event(
|
||||||
|
transcript=transcript,
|
||||||
|
event="STATUS",
|
||||||
|
data=StrValue(value=status),
|
||||||
|
)
|
||||||
|
await transcripts_controller.update(
|
||||||
|
transcript,
|
||||||
|
{
|
||||||
|
"status": status,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
@broadcast_to_sockets
|
||||||
async def on_transcript(self, data):
|
async def on_transcript(self, data):
|
||||||
async with transcripts_controller.transaction():
|
async with self.transaction():
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript()
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
@@ -94,7 +150,7 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
data=TranscriptText(text=data.text, translation=data.translation),
|
data=TranscriptText(text=data.text, translation=data.translation),
|
||||||
)
|
)
|
||||||
|
|
||||||
@broadcast_to_socket
|
@broadcast_to_sockets
|
||||||
async def on_topic(self, data):
|
async def on_topic(self, data):
|
||||||
topic = TranscriptTopic(
|
topic = TranscriptTopic(
|
||||||
title=data.title,
|
title=data.title,
|
||||||
@@ -103,14 +159,75 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
text=data.transcript.text,
|
text=data.transcript.text,
|
||||||
words=data.transcript.words,
|
words=data.transcript.words,
|
||||||
)
|
)
|
||||||
async with transcripts_controller.transaction():
|
async with self.transaction():
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript()
|
||||||
|
await transcripts_controller.upsert_topic(transcript, topic)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="TOPIC",
|
event="TOPIC",
|
||||||
data=topic,
|
data=topic,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@broadcast_to_sockets
async def on_title(self, data):
    """
    Store the generated title and record a FINAL_TITLE event

    The transcript title is only written when the transcript does not
    already have one; the event is appended in every case and its result
    is returned.
    """
    final_title = TranscriptFinalTitle(title=data.title)
    async with self.transaction():
        transcript = await self.get_transcript()
        if not transcript.title:
            # update() is a coroutine: without the await it is never run
            # and the title would be silently dropped
            await transcripts_controller.update(
                transcript,
                {
                    "title": final_title.title,
                },
            )
        return await transcripts_controller.append_event(
            transcript=transcript,
            event="FINAL_TITLE",
            data=final_title,
        )
|
||||||
|
|
||||||
|
@broadcast_to_sockets
async def on_long_summary(self, data):
    """
    Persist the final long summary and record a FINAL_LONG_SUMMARY event

    Both writes happen inside a single `self.transaction()`; the result of
    `append_event` is returned.
    """
    payload = TranscriptFinalLongSummary(long_summary=data.long_summary)
    changes = {"long_summary": payload.long_summary}
    async with self.transaction():
        current = await self.get_transcript()
        await transcripts_controller.update(current, changes)
        event = await transcripts_controller.append_event(
            transcript=current,
            event="FINAL_LONG_SUMMARY",
            data=payload,
        )
    return event
|
||||||
|
|
||||||
|
@broadcast_to_sockets
async def on_short_summary(self, data):
    """
    Persist the final short summary and record a FINAL_SHORT_SUMMARY event

    Both writes happen inside a single `self.transaction()`; the result of
    `append_event` is returned.
    """
    payload = TranscriptFinalShortSummary(short_summary=data.short_summary)
    changes = {"short_summary": payload.short_summary}
    async with self.transaction():
        current = await self.get_transcript()
        await transcripts_controller.update(current, changes)
        event = await transcripts_controller.append_event(
            transcript=current,
            event="FINAL_SHORT_SUMMARY",
            data=payload,
        )
    return event
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineMainLive(PipelineMainBase):
|
||||||
|
audio_filename: Path | None = None
|
||||||
|
source_language: str = "en"
|
||||||
|
target_language: str = "en"
|
||||||
|
|
||||||
async def create(self) -> Pipeline:
|
async def create(self) -> Pipeline:
|
||||||
# create a context for the whole rtc transaction
|
# create a context for the whole rtc transaction
|
||||||
# add a customised logger to the context
|
# add a customised logger to the context
|
||||||
@@ -125,96 +242,49 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
TranscriptLinerProcessor(),
|
TranscriptLinerProcessor(),
|
||||||
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
||||||
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
||||||
|
BroadcastProcessor(
|
||||||
|
processors=[
|
||||||
|
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
||||||
|
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||||
|
callback=self.on_long_summary
|
||||||
|
),
|
||||||
|
TranscriptFinalShortSummaryProcessor.as_threaded(
|
||||||
|
callback=self.on_short_summary
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
]
|
]
|
||||||
pipeline = Pipeline(*processors)
|
pipeline = Pipeline(*processors)
|
||||||
pipeline.options = self
|
pipeline.options = self
|
||||||
pipeline.set_pref("audio:source_language", transcript.source_language)
|
pipeline.set_pref("audio:source_language", transcript.source_language)
|
||||||
pipeline.set_pref("audio:target_language", transcript.target_language)
|
pipeline.set_pref("audio:target_language", transcript.target_language)
|
||||||
|
|
||||||
# when the pipeline ends, connect to the post pipeline
|
|
||||||
async def on_ended():
|
|
||||||
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
|
|
||||||
|
|
||||||
pipeline.on_ended = self
|
|
||||||
|
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
|
async def on_ended(self):
    """
    Chain to the post pipeline once this pipeline has ended

    Queues `task_pipeline_main_post` for the current transcript id.
    """
    task_pipeline_main_post.delay(transcript_id=self.transcript_id)
|
||||||
|
|
||||||
class PipelineMainPost(PipelineMainBase):
|
|
||||||
|
class PipelineMainDiarization(PipelineMainBase):
|
||||||
"""
|
"""
|
||||||
Implement the rest of the main pipeline, triggered after PipelineMainLive ended.
|
Diarization is a long time process, so we do it in a separate pipeline
|
||||||
|
When done, adjust the short and final summary
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@broadcast_to_socket
|
|
||||||
async def on_final_title(self, data):
|
|
||||||
final_title = TranscriptFinalTitle(title=data.title)
|
|
||||||
async with transcripts_controller.transaction():
|
|
||||||
transcript = await self.get_transcript()
|
|
||||||
if not transcript.title:
|
|
||||||
transcripts_controller.update(
|
|
||||||
self.transcript,
|
|
||||||
{
|
|
||||||
"title": final_title.title,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return await transcripts_controller.append_event(
|
|
||||||
transcript=transcript,
|
|
||||||
event="FINAL_TITLE",
|
|
||||||
data=final_title,
|
|
||||||
)
|
|
||||||
|
|
||||||
@broadcast_to_socket
|
|
||||||
async def on_final_long_summary(self, data):
|
|
||||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
|
||||||
async with transcripts_controller.transaction():
|
|
||||||
transcript = await self.get_transcript()
|
|
||||||
await transcripts_controller.update(
|
|
||||||
transcript,
|
|
||||||
{
|
|
||||||
"long_summary": final_long_summary.long_summary,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return await transcripts_controller.append_event(
|
|
||||||
transcript=transcript,
|
|
||||||
event="FINAL_LONG_SUMMARY",
|
|
||||||
data=final_long_summary,
|
|
||||||
)
|
|
||||||
|
|
||||||
@broadcast_to_socket
|
|
||||||
async def on_final_short_summary(self, data):
|
|
||||||
final_short_summary = TranscriptFinalShortSummary(
|
|
||||||
short_summary=data.short_summary
|
|
||||||
)
|
|
||||||
async with transcripts_controller.transaction():
|
|
||||||
transcript = await self.get_transcript()
|
|
||||||
await transcripts_controller.update(
|
|
||||||
transcript,
|
|
||||||
{
|
|
||||||
"short_summary": final_short_summary.short_summary,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return await transcripts_controller.append_event(
|
|
||||||
transcript=transcript,
|
|
||||||
event="FINAL_SHORT_SUMMARY",
|
|
||||||
data=final_short_summary,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def create(self) -> Pipeline:
|
async def create(self) -> Pipeline:
|
||||||
# create a context for the whole rtc transaction
|
# create a context for the whole rtc transaction
|
||||||
# add a customised logger to the context
|
# add a customised logger to the context
|
||||||
self.prepare()
|
self.prepare()
|
||||||
processors = [
|
processors = [
|
||||||
# add diarization
|
AudioDiarizationProcessor(),
|
||||||
BroadcastProcessor(
|
BroadcastProcessor(
|
||||||
processors=[
|
processors=[
|
||||||
TranscriptFinalTitleProcessor.as_threaded(
|
|
||||||
callback=self.on_final_title
|
|
||||||
),
|
|
||||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||||
callback=self.on_final_long_summary
|
callback=self.on_long_summary
|
||||||
),
|
),
|
||||||
TranscriptFinalShortSummaryProcessor.as_threaded(
|
TranscriptFinalShortSummaryProcessor.as_threaded(
|
||||||
callback=self.on_final_short_summary
|
callback=self.on_short_summary
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
@@ -222,9 +292,35 @@ class PipelineMainPost(PipelineMainBase):
|
|||||||
pipeline = Pipeline(*processors)
|
pipeline = Pipeline(*processors)
|
||||||
pipeline.options = self
|
pipeline.options = self
|
||||||
|
|
||||||
|
# now let's start the pipeline by pushing information to the
|
||||||
|
# first processor diarization processor
|
||||||
|
# XXX translation is lost when converting our data model to the processor model
|
||||||
|
transcript = await self.get_transcript()
|
||||||
|
topics = [
|
||||||
|
TitleSummaryProcessorType(
|
||||||
|
title=topic.title,
|
||||||
|
summary=topic.summary,
|
||||||
|
timestamp=topic.timestamp,
|
||||||
|
duration=topic.duration,
|
||||||
|
transcript=TranscriptProcessorType(words=topic.words),
|
||||||
|
)
|
||||||
|
for topic in transcript.topics
|
||||||
|
]
|
||||||
|
|
||||||
|
audio_diarization_input = AudioDiarizationInput(
|
||||||
|
audio_filename=transcript.audio_mp3_filename,
|
||||||
|
topics=topics,
|
||||||
|
)
|
||||||
|
|
||||||
|
# as tempting to use pipeline.push, prefer to use the runner
|
||||||
|
# to let the start just do one job.
|
||||||
|
self.push(audio_diarization_input)
|
||||||
|
self.flush()
|
||||||
|
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
@celery.task
|
@shared_task
|
||||||
def task_pipeline_main_post(transcript_id: str):
|
def task_pipeline_main_post(transcript_id: str):
|
||||||
pass
|
runner = PipelineMainDiarization(transcript_id=transcript_id)
|
||||||
|
runner.start_sync()
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ During its lifecycle, it will emit the following status:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
@@ -27,8 +26,6 @@ class PipelineRunner(BaseModel):
|
|||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
status: str = "idle"
|
status: str = "idle"
|
||||||
on_status: Callable | None = None
|
|
||||||
on_ended: Callable | None = None
|
|
||||||
pipeline: Pipeline | None = None
|
pipeline: Pipeline | None = None
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@@ -36,6 +33,10 @@ class PipelineRunner(BaseModel):
|
|||||||
self._q_cmd = asyncio.Queue()
|
self._q_cmd = asyncio.Queue()
|
||||||
self._ev_done = asyncio.Event()
|
self._ev_done = asyncio.Event()
|
||||||
self._is_first_push = True
|
self._is_first_push = True
|
||||||
|
self._logger = logger.bind(
|
||||||
|
runner=id(self),
|
||||||
|
runner_cls=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
def create(self) -> Pipeline:
|
def create(self) -> Pipeline:
|
||||||
"""
|
"""
|
||||||
@@ -50,33 +51,51 @@ class PipelineRunner(BaseModel):
|
|||||||
"""
|
"""
|
||||||
asyncio.get_event_loop().create_task(self.run())
|
asyncio.get_event_loop().create_task(self.run())
|
||||||
|
|
||||||
async def push(self, data):
|
def start_sync(self):
|
||||||
|
"""
|
||||||
|
Start the pipeline synchronously (for non-asyncio apps)
|
||||||
|
"""
|
||||||
|
asyncio.run(self.run())
|
||||||
|
|
||||||
|
def push(self, data):
|
||||||
"""
|
"""
|
||||||
Push data to the pipeline
|
Push data to the pipeline
|
||||||
"""
|
"""
|
||||||
await self._add_cmd("PUSH", data)
|
self._add_cmd("PUSH", data)
|
||||||
|
|
||||||
async def flush(self):
|
def flush(self):
|
||||||
"""
|
"""
|
||||||
Flush the pipeline
|
Flush the pipeline
|
||||||
"""
|
"""
|
||||||
await self._add_cmd("FLUSH", None)
|
self._add_cmd("FLUSH", None)
|
||||||
|
|
||||||
async def _add_cmd(self, cmd: str, data):
|
async def on_status(self, status):
|
||||||
|
"""
|
||||||
|
Called when the status of the pipeline changes
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def on_ended(self):
|
||||||
|
"""
|
||||||
|
Called when the pipeline ends
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _add_cmd(self, cmd: str, data):
|
||||||
"""
|
"""
|
||||||
Enqueue a command to be executed in the runner.
|
Enqueue a command to be executed in the runner.
|
||||||
Currently supported commands: PUSH, FLUSH
|
Currently supported commands: PUSH, FLUSH
|
||||||
"""
|
"""
|
||||||
await self._q_cmd.put([cmd, data])
|
self._q_cmd.put_nowait([cmd, data])
|
||||||
|
|
||||||
async def _set_status(self, status):
|
async def _set_status(self, status):
|
||||||
print("set_status", status)
|
self._logger.debug("Runner status updated", status=status)
|
||||||
self.status = status
|
self.status = status
|
||||||
if self.on_status:
|
if self.on_status:
|
||||||
try:
|
try:
|
||||||
await self.on_status(status)
|
await self.on_status(status)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("PipelineRunner status_callback error", error=e)
|
self._logger.exception("Runer error while setting status")
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
try:
|
try:
|
||||||
@@ -95,8 +114,8 @@ class PipelineRunner(BaseModel):
|
|||||||
await func(data)
|
await func(data)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unknown command {cmd}")
|
raise Exception(f"Unknown command {cmd}")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("PipelineRunner error", error=e)
|
self._logger.exception("Runner error")
|
||||||
await self._set_status("error")
|
await self._set_status("error")
|
||||||
self._ev_done.set()
|
self._ev_done.set()
|
||||||
if self.on_ended:
|
if self.on_ended:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
||||||
|
from .audio_diarization import AudioDiarizationProcessor # noqa: F401
|
||||||
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
||||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||||
|
|||||||
65
server/reflector/processors/audio_diarization.py
Normal file
65
server/reflector/processors/audio_diarization.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from reflector.processors.base import Processor
|
||||||
|
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||||
|
|
||||||
|
|
||||||
|
class AudioDiarizationProcessor(Processor):
    """
    Assign a speaker to the words of every topic it receives

    NOTE: the diarization segments are still a hard-coded fixture (see
    `_diarize`); the `audio_filename` of the input is not read yet.
    """

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    async def _push(self, data: AudioDiarizationInput):
        """
        Tag the words of each topic with a speaker, then emit the topics

        Words are mutated in place. A word whose `start` falls in no
        segment keeps whatever speaker it already had.
        """
        segments = self._diarize(data)

        # now reapply speaker to topics (if any)
        # topics is a list[BaseModel] with an attribute words
        # words is a list[BaseModel] with text, start and speaker attribute
        for topic in data.topics:
            for word in topic.transcript.words:
                # segments may overlap and there is deliberately no early
                # exit: the last matching segment wins
                for segment in segments:
                    if segment["start"] <= word.start <= segment["stop"]:
                        word.speaker = segment["speaker"]

        # emit them
        for topic in data.topics:
            await self.emit(topic)

    def _diarize(self, data: AudioDiarizationInput) -> list[dict]:
        """
        Return the diarization segments for the given input

        TODO: placeholder - returns a fixed list of segments regardless of
        `data`; replace with a call to a real diarization backend.
        """
        return [
            {"start": 0.0, "stop": 4.9, "speaker": 2},
            {"start": 5.6, "stop": 6.7, "speaker": 2},
            {"start": 7.3, "stop": 8.9, "speaker": 2},
            {"start": 7.3, "stop": 7.9, "speaker": 0},
            {"start": 9.4, "stop": 11.2, "speaker": 2},
            {"start": 9.7, "stop": 10.0, "speaker": 0},
            {"start": 10.0, "stop": 10.1, "speaker": 0},
            {"start": 11.7, "stop": 16.1, "speaker": 2},
            {"start": 11.8, "stop": 12.1, "speaker": 1},
            {"start": 16.4, "stop": 21.0, "speaker": 2},
            {"start": 21.1, "stop": 22.6, "speaker": 2},
            {"start": 24.7, "stop": 31.9, "speaker": 2},
            {"start": 32.0, "stop": 32.8, "speaker": 1},
            {"start": 33.4, "stop": 37.8, "speaker": 2},
            {"start": 37.9, "stop": 40.3, "speaker": 0},
            {"start": 39.2, "stop": 40.4, "speaker": 2},
            {"start": 40.7, "stop": 41.4, "speaker": 0},
            {"start": 41.6, "stop": 45.7, "speaker": 2},
            {"start": 46.4, "stop": 53.1, "speaker": 2},
            {"start": 53.6, "stop": 56.5, "speaker": 2},
            {"start": 54.9, "stop": 75.4, "speaker": 1},
            {"start": 57.3, "stop": 58.0, "speaker": 2},
            {"start": 65.7, "stop": 66.0, "speaker": 2},
            {"start": 75.8, "stop": 78.8, "speaker": 1},
            {"start": 79.0, "stop": 82.6, "speaker": 1},
            {"start": 83.2, "stop": 83.3, "speaker": 1},
            {"start": 84.5, "stop": 94.3, "speaker": 1},
            {"start": 95.1, "stop": 100.7, "speaker": 1},
            {"start": 100.7, "stop": 102.0, "speaker": 0},
            {"start": 100.7, "stop": 101.8, "speaker": 1},
            {"start": 102.0, "stop": 103.0, "speaker": 1},
            {"start": 103.0, "stop": 103.7, "speaker": 0},
            {"start": 103.7, "stop": 103.8, "speaker": 1},
            {"start": 103.8, "stop": 113.9, "speaker": 0},
            {"start": 114.7, "stop": 117.0, "speaker": 0},
            {"start": 117.0, "stop": 117.4, "speaker": 1},
        ]
|
||||||
@@ -382,3 +382,8 @@ class TranslationLanguages(BaseModel):
|
|||||||
|
|
||||||
def is_supported(self, lang_id: str) -> bool:
|
def is_supported(self, lang_id: str) -> bool:
|
||||||
return lang_id in self.supported_languages
|
return lang_id in self.supported_languages
|
||||||
|
|
||||||
|
|
||||||
|
class AudioDiarizationInput(BaseModel):
|
||||||
|
audio_filename: Path
|
||||||
|
topics: list[TitleSummary]
|
||||||
|
|||||||
@@ -1,2 +0,0 @@
|
|||||||
import reflector.tasks.post_transcript # noqa
|
|
||||||
import reflector.tasks.worker # noqa
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
from celery import Celery
|
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
celery = Celery(__name__)
|
|
||||||
celery.conf.broker_url = settings.CELERY_BROKER_URL
|
|
||||||
celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from enum import StrEnum
|
|
||||||
from json import loads
|
from json import loads
|
||||||
|
|
||||||
import av
|
import av
|
||||||
@@ -41,7 +40,7 @@ class AudioStreamTrack(MediaStreamTrack):
|
|||||||
ctx = self.ctx
|
ctx = self.ctx
|
||||||
frame = await self.track.recv()
|
frame = await self.track.recv()
|
||||||
try:
|
try:
|
||||||
await ctx.pipeline_runner.push(frame)
|
ctx.pipeline_runner.push(frame)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ctx.logger.error("Pipeline error", error=e)
|
ctx.logger.error("Pipeline error", error=e)
|
||||||
return frame
|
return frame
|
||||||
@@ -52,19 +51,6 @@ class RtcOffer(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
|
|
||||||
|
|
||||||
class StrValue(BaseModel):
|
|
||||||
value: str
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineEvent(StrEnum):
|
|
||||||
TRANSCRIPT = "TRANSCRIPT"
|
|
||||||
TOPIC = "TOPIC"
|
|
||||||
FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
|
|
||||||
STATUS = "STATUS"
|
|
||||||
FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
|
|
||||||
FINAL_TITLE = "FINAL_TITLE"
|
|
||||||
|
|
||||||
|
|
||||||
async def rtc_offer_base(
|
async def rtc_offer_base(
|
||||||
params: RtcOffer,
|
params: RtcOffer,
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -90,7 +76,7 @@ async def rtc_offer_base(
|
|||||||
# - when we receive the close event, we do nothing.
|
# - when we receive the close event, we do nothing.
|
||||||
# 2. or the client close the connection
|
# 2. or the client close the connection
|
||||||
# and there is nothing to do because it is already closed
|
# and there is nothing to do because it is already closed
|
||||||
await ctx.pipeline_runner.flush()
|
ctx.pipeline_runner.flush()
|
||||||
if close:
|
if close:
|
||||||
ctx.logger.debug("Closing peer connection")
|
ctx.logger.debug("Closing peer connection")
|
||||||
await pc.close()
|
await pc.close()
|
||||||
|
|||||||
@@ -23,10 +23,9 @@ from reflector.ws_manager import get_ws_manager
|
|||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
|
||||||
from ._range_requests_response import range_requests_response
|
from ._range_requests_response import range_requests_response
|
||||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
ws_manager = get_ws_manager()
|
|
||||||
|
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
# Transcripts list
|
# Transcripts list
|
||||||
@@ -166,32 +165,17 @@ async def transcript_update(
|
|||||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
values = {"events": []}
|
values = {}
|
||||||
if info.name is not None:
|
if info.name is not None:
|
||||||
values["name"] = info.name
|
values["name"] = info.name
|
||||||
if info.locked is not None:
|
if info.locked is not None:
|
||||||
values["locked"] = info.locked
|
values["locked"] = info.locked
|
||||||
if info.long_summary is not None:
|
if info.long_summary is not None:
|
||||||
values["long_summary"] = info.long_summary
|
values["long_summary"] = info.long_summary
|
||||||
for transcript_event in transcript.events:
|
|
||||||
if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
|
|
||||||
transcript_event["long_summary"] = info.long_summary
|
|
||||||
break
|
|
||||||
values["events"].extend(transcript.events)
|
|
||||||
if info.short_summary is not None:
|
if info.short_summary is not None:
|
||||||
values["short_summary"] = info.short_summary
|
values["short_summary"] = info.short_summary
|
||||||
for transcript_event in transcript.events:
|
|
||||||
if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
|
|
||||||
transcript_event["short_summary"] = info.short_summary
|
|
||||||
break
|
|
||||||
values["events"].extend(transcript.events)
|
|
||||||
if info.title is not None:
|
if info.title is not None:
|
||||||
values["title"] = info.title
|
values["title"] = info.title
|
||||||
for transcript_event in transcript.events:
|
|
||||||
if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
|
|
||||||
transcript_event["title"] = info.title
|
|
||||||
break
|
|
||||||
values["events"].extend(transcript.events)
|
|
||||||
await transcripts_controller.update(transcript, values)
|
await transcripts_controller.update(transcript, values)
|
||||||
return transcript
|
return transcript
|
||||||
|
|
||||||
@@ -295,6 +279,7 @@ async def transcript_events_websocket(
|
|||||||
# connect to websocket manager
|
# connect to websocket manager
|
||||||
# use ts:transcript_id as room id
|
# use ts:transcript_id as room id
|
||||||
room_id = f"ts:{transcript_id}"
|
room_id = f"ts:{transcript_id}"
|
||||||
|
ws_manager = get_ws_manager()
|
||||||
await ws_manager.add_user_to_room(room_id, websocket)
|
await ws_manager.add_user_to_room(room_id, websocket)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -303,9 +288,7 @@ async def transcript_events_websocket(
|
|||||||
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
|
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
|
||||||
# not necessary to be sent to the client; but keep the rest
|
# not necessary to be sent to the client; but keep the rest
|
||||||
name = event.event
|
name = event.event
|
||||||
if name == PipelineEvent.TRANSCRIPT:
|
if name in ("TRANSCRIPT", "STATUS"):
|
||||||
continue
|
|
||||||
if name == PipelineEvent.STATUS:
|
|
||||||
continue
|
continue
|
||||||
await websocket.send_json(event.model_dump(mode="json"))
|
await websocket.send_json(event.model_dump(mode="json"))
|
||||||
|
|
||||||
|
|||||||
11
server/reflector/worker/app.py
Normal file
11
server/reflector/worker/app.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from celery import Celery
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
app = Celery(__name__)
|
||||||
|
app.conf.broker_url = settings.CELERY_BROKER_URL
|
||||||
|
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
||||||
|
app.autodiscover_tasks(
|
||||||
|
[
|
||||||
|
"reflector.pipelines.main_live_pipeline",
|
||||||
|
]
|
||||||
|
)
|
||||||
@@ -11,13 +11,12 @@ broadcast messages to all connected websockets.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import threading
|
||||||
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
ws_manager = None
|
|
||||||
|
|
||||||
|
|
||||||
class RedisPubSubManager:
|
class RedisPubSubManager:
|
||||||
def __init__(self, host="localhost", port=6379):
|
def __init__(self, host="localhost", port=6379):
|
||||||
@@ -114,13 +113,14 @@ def get_ws_manager() -> WebsocketManager:
|
|||||||
ImportError: If the 'reflector.settings' module cannot be imported.
|
ImportError: If the 'reflector.settings' module cannot be imported.
|
||||||
RedisConnectionError: If there is an error connecting to the Redis server.
|
RedisConnectionError: If there is an error connecting to the Redis server.
|
||||||
"""
|
"""
|
||||||
global ws_manager
|
local = threading.local()
|
||||||
if ws_manager:
|
if hasattr(local, "ws_manager"):
|
||||||
return ws_manager
|
return local.ws_manager
|
||||||
|
|
||||||
pubsub_client = RedisPubSubManager(
|
pubsub_client = RedisPubSubManager(
|
||||||
host=settings.REDIS_HOST,
|
host=settings.REDIS_HOST,
|
||||||
port=settings.REDIS_PORT,
|
port=settings.REDIS_PORT,
|
||||||
)
|
)
|
||||||
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
|
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
|
||||||
|
local.ws_manager = ws_manager
|
||||||
return ws_manager
|
return ws_manager
|
||||||
|
|||||||
@@ -45,17 +45,16 @@ async def dummy_transcript():
|
|||||||
from reflector.processors.types import AudioFile, Transcript, Word
|
from reflector.processors.types import AudioFile, Transcript, Word
|
||||||
|
|
||||||
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
|
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
|
||||||
async def _transcript(self, data: AudioFile):
|
_time_idx = 0
|
||||||
source_language = self.get_pref("audio:source_language", "en")
|
|
||||||
print("transcripting", source_language)
|
|
||||||
print("pipeline", self.pipeline)
|
|
||||||
print("prefs", self.pipeline.prefs)
|
|
||||||
|
|
||||||
|
async def _transcript(self, data: AudioFile):
|
||||||
|
i = self._time_idx
|
||||||
|
self._time_idx += 2
|
||||||
return Transcript(
|
return Transcript(
|
||||||
text="Hello world.",
|
text="Hello world.",
|
||||||
words=[
|
words=[
|
||||||
Word(start=0.0, end=1.0, text="Hello"),
|
Word(start=i, end=i + 1, text="Hello", speaker=0),
|
||||||
Word(start=1.0, end=2.0, text=" world."),
|
Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -98,7 +97,17 @@ def ensure_casing():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sentence_tokenize():
|
def sentence_tokenize():
|
||||||
with patch(
|
with patch(
|
||||||
"reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize"
|
"reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize"
|
||||||
) as mock_sent_tokenize:
|
) as mock_sent_tokenize:
|
||||||
mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
|
mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def celery_enable_logging():
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def celery_config():
|
||||||
|
return {"broker_url": "memory://", "result_backend": "rpc"}
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class ThreadedUvicorn:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def appserver(tmpdir):
|
async def appserver(tmpdir, celery_session_app, celery_session_worker):
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.app import app
|
from reflector.app import app
|
||||||
|
|
||||||
@@ -52,6 +52,13 @@ async def appserver(tmpdir):
|
|||||||
settings.DATA_DIR = DATA_DIR
|
settings.DATA_DIR = DATA_DIR
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def celery_includes():
|
||||||
|
return ["reflector.pipelines.main_live_pipeline"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("celery_session_app")
|
||||||
|
@pytest.mark.usefixtures("celery_session_worker")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_rtc_and_websocket(
|
async def test_transcript_rtc_and_websocket(
|
||||||
tmpdir,
|
tmpdir,
|
||||||
@@ -121,14 +128,20 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
# XXX aiortc is long to close the connection
|
# XXX aiortc is long to close the connection
|
||||||
# instead of waiting a long time, we just send a STOP
|
# instead of waiting a long time, we just send a STOP
|
||||||
client.channel.send(json.dumps({"cmd": "STOP"}))
|
client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||||
|
|
||||||
# wait the processing to finish
|
|
||||||
await asyncio.sleep(2)
|
|
||||||
|
|
||||||
await client.stop()
|
await client.stop()
|
||||||
|
|
||||||
# wait the processing to finish
|
# wait the processing to finish
|
||||||
await asyncio.sleep(2)
|
timeout = 20
|
||||||
|
while True:
|
||||||
|
# fetch the transcript and check if it is ended
|
||||||
|
resp = await ac.get(f"/transcripts/{tid}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
if resp.json()["status"] in ("ended", "error"):
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
if resp.json()["status"] != "ended":
|
||||||
|
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
||||||
|
|
||||||
# stop websocket task
|
# stop websocket task
|
||||||
websocket_task.cancel()
|
websocket_task.cancel()
|
||||||
@@ -152,7 +165,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
ev = events[eventnames.index("TOPIC")]
|
ev = events[eventnames.index("TOPIC")]
|
||||||
assert ev["data"]["id"]
|
assert ev["data"]["id"]
|
||||||
assert ev["data"]["summary"] == "LLM SUMMARY"
|
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||||
assert ev["data"]["transcript"].startswith("Hello world.")
|
assert ev["data"]["text"].startswith("Hello world.")
|
||||||
assert ev["data"]["timestamp"] == 0.0
|
assert ev["data"]["timestamp"] == 0.0
|
||||||
|
|
||||||
assert "FINAL_LONG_SUMMARY" in eventnames
|
assert "FINAL_LONG_SUMMARY" in eventnames
|
||||||
@@ -169,23 +182,21 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
|
|
||||||
# check status order
|
# check status order
|
||||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||||
assert statuses == ["recording", "processing", "ended"]
|
assert statuses.index("recording") < statuses.index("processing")
|
||||||
|
assert statuses.index("processing") < statuses.index("ended")
|
||||||
|
|
||||||
# ensure the last event received is ended
|
# ensure the last event received is ended
|
||||||
assert events[-1]["event"] == "STATUS"
|
assert events[-1]["event"] == "STATUS"
|
||||||
assert events[-1]["data"]["value"] == "ended"
|
assert events[-1]["data"]["value"] == "ended"
|
||||||
|
|
||||||
# check that transcript status in model is updated
|
|
||||||
resp = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.json()["status"] == "ended"
|
|
||||||
|
|
||||||
# check that audio/mp3 is available
|
# check that audio/mp3 is available
|
||||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.headers["Content-Type"] == "audio/mpeg"
|
assert resp.headers["Content-Type"] == "audio/mpeg"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("celery_session_app")
|
||||||
|
@pytest.mark.usefixtures("celery_session_worker")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_rtc_and_websocket_and_fr(
|
async def test_transcript_rtc_and_websocket_and_fr(
|
||||||
tmpdir,
|
tmpdir,
|
||||||
@@ -265,6 +276,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
await client.stop()
|
await client.stop()
|
||||||
|
|
||||||
# wait the processing to finish
|
# wait the processing to finish
|
||||||
|
timeout = 20
|
||||||
|
while True:
|
||||||
|
# fetch the transcript and check if it is ended
|
||||||
|
resp = await ac.get(f"/transcripts/{tid}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
if resp.json()["status"] == "ended":
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
if resp.json()["status"] != "ended":
|
||||||
|
raise TimeoutError("Timeout while waiting for transcript to be ended")
|
||||||
|
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
# stop websocket task
|
# stop websocket task
|
||||||
@@ -289,7 +312,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
ev = events[eventnames.index("TOPIC")]
|
ev = events[eventnames.index("TOPIC")]
|
||||||
assert ev["data"]["id"]
|
assert ev["data"]["id"]
|
||||||
assert ev["data"]["summary"] == "LLM SUMMARY"
|
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||||
assert ev["data"]["transcript"].startswith("Hello world.")
|
assert ev["data"]["text"].startswith("Hello world.")
|
||||||
assert ev["data"]["timestamp"] == 0.0
|
assert ev["data"]["timestamp"] == 0.0
|
||||||
|
|
||||||
assert "FINAL_LONG_SUMMARY" in eventnames
|
assert "FINAL_LONG_SUMMARY" in eventnames
|
||||||
@@ -306,7 +329,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
|
|
||||||
# check status order
|
# check status order
|
||||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||||
assert statuses == ["recording", "processing", "ended"]
|
assert statuses.index("recording") < statuses.index("processing")
|
||||||
|
assert statuses.index("processing") < statuses.index("ended")
|
||||||
|
|
||||||
# ensure the last event received is ended
|
# ensure the last event received is ended
|
||||||
assert events[-1]["event"] == "STATUS"
|
assert events[-1]["event"] == "STATUS"
|
||||||
|
|||||||
Reference in New Issue
Block a user