server: refactor with diarization, logic works

2023-10-27 15:59:27 +02:00
committed by Mathieu Virbel
parent 1c42473da0
commit 07c4d080c2
17 changed files with 387 additions and 169 deletions

server/poetry.lock generated
View File

@@ -2676,6 +2676,20 @@ pytest = ">=7.0.0"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
+[[package]]
+name = "pytest-celery"
+version = "0.0.0"
+description = "pytest-celery a shim pytest plugin to enable celery.contrib.pytest"
+optional = false
+python-versions = "*"
+files = [
+    {file = "pytest-celery-0.0.0.tar.gz", hash = "sha256:cfd060fc32676afa1e4f51b2938f903f7f75d952186b8c6cf631628c4088f406"},
+    {file = "pytest_celery-0.0.0-py2.py3-none-any.whl", hash = "sha256:63dec132df3a839226ecb003ffdbb0c2cb88dd328550957e979c942766578060"},
+]
+
+[package.dependencies]
+celery = ">=4.4.0"
[[package]]
name = "pytest-cov"
version = "4.1.0"
@@ -4064,4 +4078,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63"
+content-hash = "07e42e7512fd5d51b656207a05092c53905c15e6a5ce548e015cdc05bd1baa7d"

View File

@@ -49,6 +49,7 @@ pytest-asyncio = "^0.21.1"
pytest = "^7.4.0"
httpx-ws = "^0.4.1"
pytest-httpx = "^0.23.1"
+pytest-celery = "^0.0.0"

[tool.poetry.group.aws.dependencies]

View File

@@ -62,6 +62,7 @@ class TranscriptTopic(BaseModel):
    title: str
    summary: str
    timestamp: float
+    duration: float | None = 0
    text: str | None = None
    words: list[ProcessorWord] = []
@@ -264,7 +265,7 @@ class TranscriptController:
""" """
A context manager for database transaction A context manager for database transaction
""" """
async with database.transaction(): async with database.transaction(isolation="serializable"):
yield yield
async def append_event( async def append_event(
@@ -280,5 +281,16 @@ class TranscriptController:
        await self.update(transcript, {"events": transcript.events_dump()})
        return resp
+    async def upsert_topic(
+        self,
+        transcript: Transcript,
+        topic: TranscriptTopic,
+    ) -> None:
+        """
+        Insert or update a topic in a transcript
+        """
+        transcript.upsert_topic(topic)
+        await self.update(transcript, {"topics": transcript.topics_dump()})

transcripts_controller = TranscriptController()
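
Review note: `Transcript.upsert_topic` itself is not shown in this commit; a minimal sketch of the update-or-append semantics the controller appears to rely on (hypothetical implementation, topics matched by id as in the TOPIC events further down):

    # Hypothetical sketch on the Transcript model, not the committed code.
    def upsert_topic(self, topic: TranscriptTopic) -> None:
        for i, existing in enumerate(self.topics):
            if existing.id == topic.id:
                self.topics[i] = topic  # update in place
                return
        self.topics.append(topic)  # first time we see this topic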

View File

@@ -11,8 +11,12 @@ It is decoupled to:
It is directly linked to our data model.
"""

+import asyncio
+from contextlib import asynccontextmanager
from pathlib import Path

+from celery import shared_task
+from pydantic import BaseModel
from reflector.db.transcripts import (
    Transcript,
    TranscriptFinalLongSummary,
@@ -25,6 +29,7 @@ from reflector.db.transcripts import (
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
    AudioChunkerProcessor,
+    AudioDiarizationProcessor,
    AudioFileWriterProcessor,
    AudioMergeProcessor,
    AudioTranscriptAutoProcessor,
@@ -37,11 +42,13 @@ from reflector.processors import (
    TranscriptTopicDetectorProcessor,
    TranscriptTranslatorProcessor,
)
-from reflector.tasks.worker import celery
+from reflector.processors.types import AudioDiarizationInput
+from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
+from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.ws_manager import WebsocketManager, get_ws_manager


-def broadcast_to_socket(func):
+def broadcast_to_sockets(func):
    """
    Decorator to broadcast transcript event to websockets
    concerning this transcript
@@ -59,6 +66,10 @@ def broadcast_to_socket(func):
    return wrapper


+class StrValue(BaseModel):
+    value: str


class PipelineMainBase(PipelineRunner):
    transcript_id: str
    ws_room_id: str | None = None
@@ -66,6 +77,7 @@ class PipelineMainBase(PipelineRunner):
    def prepare(self):
        # prepare websocket
+        self._lock = asyncio.Lock()
        self.ws_room_id = f"ts:{self.transcript_id}"
        self.ws_manager = get_ws_manager()
@@ -78,15 +90,59 @@ class PipelineMainBase(PipelineRunner):
            raise Exception("Transcript not found")
        return result

+    @asynccontextmanager
+    async def transaction(self):
+        async with self._lock:
+            async with transcripts_controller.transaction():
+                yield

-class PipelineMainLive(PipelineMainBase):
-    audio_filename: Path | None = None
-    source_language: str = "en"
-    target_language: str = "en"

-    @broadcast_to_socket
+    @broadcast_to_sockets
+    async def on_status(self, status):
+        # if it's the first part, update the status of the transcript
+        # but do not set the ended status yet.
+        if isinstance(self, PipelineMainLive):
+            status_mapping = {
+                "started": "recording",
+                "push": "recording",
+                "flush": "processing",
+                "error": "error",
+            }
+        elif isinstance(self, PipelineMainDiarization):
+            status_mapping = {
+                "push": "processing",
+                "flush": "processing",
+                "error": "error",
+                "ended": "ended",
+            }
+        else:
+            raise Exception(f"Runner {self.__class__} is missing status mapping")
+        # mutate to model status
+        status = status_mapping.get(status)
+        if not status:
+            return
+        # when the status of the pipeline changes, update the transcript
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            if status == transcript.status:
+                return
+            resp = await transcripts_controller.append_event(
+                transcript=transcript,
+                event="STATUS",
+                data=StrValue(value=status),
+            )
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "status": status,
+                },
+            )
+            return resp

+    @broadcast_to_sockets
    async def on_transcript(self, data):
-        async with transcripts_controller.transaction():
+        async with self.transaction():
            transcript = await self.get_transcript()
            return await transcripts_controller.append_event(
                transcript=transcript,
@@ -94,7 +150,7 @@ class PipelineMainLive(PipelineMainBase):
                data=TranscriptText(text=data.text, translation=data.translation),
            )

-    @broadcast_to_socket
+    @broadcast_to_sockets
    async def on_topic(self, data):
        topic = TranscriptTopic(
            title=data.title,
@@ -103,14 +159,75 @@ class PipelineMainLive(PipelineMainBase):
            text=data.transcript.text,
            words=data.transcript.words,
        )
-        async with transcripts_controller.transaction():
+        async with self.transaction():
            transcript = await self.get_transcript()
+            await transcripts_controller.upsert_topic(transcript, topic)
            return await transcripts_controller.append_event(
                transcript=transcript,
                event="TOPIC",
                data=topic,
            )

+    @broadcast_to_sockets
+    async def on_title(self, data):
+        final_title = TranscriptFinalTitle(title=data.title)
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            if not transcript.title:
+                await transcripts_controller.update(
+                    transcript,
+                    {
+                        "title": final_title.title,
+                    },
+                )
+            return await transcripts_controller.append_event(
+                transcript=transcript,
+                event="FINAL_TITLE",
+                data=final_title,
+            )

+    @broadcast_to_sockets
+    async def on_long_summary(self, data):
+        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "long_summary": final_long_summary.long_summary,
+                },
+            )
+            return await transcripts_controller.append_event(
+                transcript=transcript,
+                event="FINAL_LONG_SUMMARY",
+                data=final_long_summary,
+            )

+    @broadcast_to_sockets
+    async def on_short_summary(self, data):
+        final_short_summary = TranscriptFinalShortSummary(
+            short_summary=data.short_summary
+        )
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "short_summary": final_short_summary.short_summary,
+                },
+            )
+            return await transcripts_controller.append_event(
+                transcript=transcript,
+                event="FINAL_SHORT_SUMMARY",
+                data=final_short_summary,
+            )
+class PipelineMainLive(PipelineMainBase):
+    audio_filename: Path | None = None
+    source_language: str = "en"
+    target_language: str = "en"

    async def create(self) -> Pipeline:
        # create a context for the whole rtc transaction
        # add a customised logger to the context
@@ -125,96 +242,49 @@ class PipelineMainLive(PipelineMainBase):
            TranscriptLinerProcessor(),
            TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
            TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
+                    TranscriptFinalLongSummaryProcessor.as_threaded(
+                        callback=self.on_long_summary
+                    ),
+                    TranscriptFinalShortSummaryProcessor.as_threaded(
+                        callback=self.on_short_summary
+                    ),
+                ]
+            ),
        ]
        pipeline = Pipeline(*processors)
        pipeline.options = self
        pipeline.set_pref("audio:source_language", transcript.source_language)
        pipeline.set_pref("audio:target_language", transcript.target_language)

-        # when the pipeline ends, connect to the post pipeline
-        async def on_ended():
-            task_pipeline_main_post.delay(transcript_id=self.transcript_id)

-        pipeline.on_ended = self
        return pipeline

+    async def on_ended(self):
+        # when the pipeline ends, connect to the post pipeline
+        task_pipeline_main_post.delay(transcript_id=self.transcript_id)
-class PipelineMainPost(PipelineMainBase):
+class PipelineMainDiarization(PipelineMainBase):
    """
-    Implement the rest of the main pipeline, triggered after PipelineMainLive ended.
-    When done, adjust the short and final summary
+    Diarization is a long time process, so we do it in a separate pipeline
    """

-    @broadcast_to_socket
-    async def on_final_title(self, data):
-        final_title = TranscriptFinalTitle(title=data.title)
-        async with transcripts_controller.transaction():
-            transcript = await self.get_transcript()
-            if not transcript.title:
-                transcripts_controller.update(
-                    self.transcript,
-                    {
-                        "title": final_title.title,
-                    },
-                )
-            return await transcripts_controller.append_event(
-                transcript=transcript,
-                event="FINAL_TITLE",
-                data=final_title,
-            )

-    @broadcast_to_socket
-    async def on_final_long_summary(self, data):
-        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
-        async with transcripts_controller.transaction():
-            transcript = await self.get_transcript()
-            await transcripts_controller.update(
-                transcript,
-                {
-                    "long_summary": final_long_summary.long_summary,
-                },
-            )
-            return await transcripts_controller.append_event(
-                transcript=transcript,
-                event="FINAL_LONG_SUMMARY",
-                data=final_long_summary,
-            )

-    @broadcast_to_socket
-    async def on_final_short_summary(self, data):
-        final_short_summary = TranscriptFinalShortSummary(
-            short_summary=data.short_summary
-        )
-        async with transcripts_controller.transaction():
-            transcript = await self.get_transcript()
-            await transcripts_controller.update(
-                transcript,
-                {
-                    "short_summary": final_short_summary.short_summary,
-                },
-            )
-            return await transcripts_controller.append_event(
-                transcript=transcript,
-                event="FINAL_SHORT_SUMMARY",
-                data=final_short_summary,
-            )

    async def create(self) -> Pipeline:
        # create a context for the whole rtc transaction
        # add a customised logger to the context
        self.prepare()

        processors = [
-            # add diarization
+            AudioDiarizationProcessor(),
            BroadcastProcessor(
                processors=[
-                    TranscriptFinalTitleProcessor.as_threaded(
-                        callback=self.on_final_title
-                    ),
                    TranscriptFinalLongSummaryProcessor.as_threaded(
-                        callback=self.on_final_long_summary
+                        callback=self.on_long_summary
                    ),
                    TranscriptFinalShortSummaryProcessor.as_threaded(
-                        callback=self.on_final_short_summary
+                        callback=self.on_short_summary
                    ),
                ]
            ),
@@ -222,9 +292,35 @@ class PipelineMainPost(PipelineMainBase):
        pipeline = Pipeline(*processors)
        pipeline.options = self

+        # now let's start the pipeline by pushing information to the
+        # first processor, the diarization processor
+        # XXX translation is lost when converting our data model to the processor model
+        transcript = await self.get_transcript()
+        topics = [
+            TitleSummaryProcessorType(
+                title=topic.title,
+                summary=topic.summary,
+                timestamp=topic.timestamp,
+                duration=topic.duration,
+                transcript=TranscriptProcessorType(words=topic.words),
+            )
+            for topic in transcript.topics
+        ]
+        audio_diarization_input = AudioDiarizationInput(
+            audio_filename=transcript.audio_mp3_filename,
+            topics=topics,
+        )
+        # as tempting as it is to use pipeline.push here, prefer the runner,
+        # to let start() do just one job.
+        self.push(audio_diarization_input)
+        self.flush()

        return pipeline


-@celery.task
+@shared_task
def task_pipeline_main_post(transcript_id: str):
-    pass
+    runner = PipelineMainDiarization(transcript_id=transcript_id)
+    runner.start_sync()
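
Review note: the handoff is live pipeline → `on_ended` → broker → worker. A minimal sketch of how the chain is expected to be driven, assuming a Celery worker is consuming the default queue:

    # Hypothetical illustration: enqueue the post pipeline for a transcript.
    # A worker picks up the message, calls task_pipeline_main_post("abc123"),
    # and blocks in asyncio.run() until the diarization pipeline has flushed.
    from reflector.pipelines.main_live_pipeline import task_pipeline_main_post

    task_pipeline_main_post.delay(transcript_id="abc123")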

View File

@@ -16,7 +16,6 @@ During its lifecycle, it will emit the following status:
""" """
import asyncio import asyncio
from typing import Callable
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from reflector.logger import logger from reflector.logger import logger
@@ -27,8 +26,6 @@ class PipelineRunner(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    status: str = "idle"
-    on_status: Callable | None = None
-    on_ended: Callable | None = None
    pipeline: Pipeline | None = None

    def __init__(self, **kwargs):
@@ -36,6 +33,10 @@ class PipelineRunner(BaseModel):
        self._q_cmd = asyncio.Queue()
        self._ev_done = asyncio.Event()
        self._is_first_push = True
+        self._logger = logger.bind(
+            runner=id(self),
+            runner_cls=self.__class__.__name__,
+        )

    def create(self) -> Pipeline:
        """
@@ -50,33 +51,51 @@ class PipelineRunner(BaseModel):
""" """
asyncio.get_event_loop().create_task(self.run()) asyncio.get_event_loop().create_task(self.run())
async def push(self, data): def start_sync(self):
"""
Start the pipeline synchronously (for non-asyncio apps)
"""
asyncio.run(self.run())
def push(self, data):
""" """
Push data to the pipeline Push data to the pipeline
""" """
await self._add_cmd("PUSH", data) self._add_cmd("PUSH", data)
async def flush(self): def flush(self):
""" """
Flush the pipeline Flush the pipeline
""" """
await self._add_cmd("FLUSH", None) self._add_cmd("FLUSH", None)
async def _add_cmd(self, cmd: str, data): async def on_status(self, status):
"""
Called when the status of the pipeline changes
"""
pass
async def on_ended(self):
"""
Called when the pipeline ends
"""
pass
def _add_cmd(self, cmd: str, data):
""" """
Enqueue a command to be executed in the runner. Enqueue a command to be executed in the runner.
Currently supported commands: PUSH, FLUSH Currently supported commands: PUSH, FLUSH
""" """
await self._q_cmd.put([cmd, data]) self._q_cmd.put_nowait([cmd, data])
async def _set_status(self, status): async def _set_status(self, status):
print("set_status", status) self._logger.debug("Runner status updated", status=status)
self.status = status self.status = status
if self.on_status: if self.on_status:
try: try:
await self.on_status(status) await self.on_status(status)
except Exception as e: except Exception:
logger.error("PipelineRunner status_callback error", error=e) self._logger.exception("Runer error while setting status")
async def run(self): async def run(self):
try: try:
@@ -95,8 +114,8 @@ class PipelineRunner(BaseModel):
                    await func(data)
                else:
                    raise Exception(f"Unknown command {cmd}")
-        except Exception as e:
-            logger.error("PipelineRunner error", error=e)
+        except Exception:
+            self._logger.exception("Runner error")
            await self._set_status("error")
        self._ev_done.set()
        if self.on_ended:
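
Review note: `push`/`flush` are now synchronous enqueues and the old `Callable` fields became overridable methods. A minimal sketch of driving the reworked runner from non-asyncio code, assuming `create()` is implemented as in the pipelines above:

    # Hypothetical subclass for illustration only.
    class MyRunner(PipelineRunner):
        async def on_status(self, status):
            print("status:", status)  # hook is a method now, not a Callable field

        async def on_ended(self):
            print("pipeline ended")

    runner = MyRunner()
    runner.push(b"frame-1")   # synchronous: put_nowait on the command queue
    runner.flush()            # synchronous: enqueue FLUSH
    runner.start_sync()       # blocks in asyncio.run(self.run()) until done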

View File

@@ -1,4 +1,5 @@
from .audio_chunker import AudioChunkerProcessor  # noqa: F401
+from .audio_diarization import AudioDiarizationProcessor  # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor  # noqa: F401
from .audio_merge import AudioMergeProcessor  # noqa: F401
from .audio_transcript import AudioTranscriptProcessor  # noqa: F401

View File

@@ -0,0 +1,65 @@
+from reflector.processors.base import Processor
+from reflector.processors.types import AudioDiarizationInput, TitleSummary


+class AudioDiarizationProcessor(Processor):
+    INPUT_TYPE = AudioDiarizationInput
+    OUTPUT_TYPE = TitleSummary

+    async def _push(self, data: AudioDiarizationInput):
+        # Gather diarization data
+        diarization = [
+            {"start": 0.0, "stop": 4.9, "speaker": 2},
+            {"start": 5.6, "stop": 6.7, "speaker": 2},
+            {"start": 7.3, "stop": 8.9, "speaker": 2},
+            {"start": 7.3, "stop": 7.9, "speaker": 0},
+            {"start": 9.4, "stop": 11.2, "speaker": 2},
+            {"start": 9.7, "stop": 10.0, "speaker": 0},
+            {"start": 10.0, "stop": 10.1, "speaker": 0},
+            {"start": 11.7, "stop": 16.1, "speaker": 2},
+            {"start": 11.8, "stop": 12.1, "speaker": 1},
+            {"start": 16.4, "stop": 21.0, "speaker": 2},
+            {"start": 21.1, "stop": 22.6, "speaker": 2},
+            {"start": 24.7, "stop": 31.9, "speaker": 2},
+            {"start": 32.0, "stop": 32.8, "speaker": 1},
+            {"start": 33.4, "stop": 37.8, "speaker": 2},
+            {"start": 37.9, "stop": 40.3, "speaker": 0},
+            {"start": 39.2, "stop": 40.4, "speaker": 2},
+            {"start": 40.7, "stop": 41.4, "speaker": 0},
+            {"start": 41.6, "stop": 45.7, "speaker": 2},
+            {"start": 46.4, "stop": 53.1, "speaker": 2},
+            {"start": 53.6, "stop": 56.5, "speaker": 2},
+            {"start": 54.9, "stop": 75.4, "speaker": 1},
+            {"start": 57.3, "stop": 58.0, "speaker": 2},
+            {"start": 65.7, "stop": 66.0, "speaker": 2},
+            {"start": 75.8, "stop": 78.8, "speaker": 1},
+            {"start": 79.0, "stop": 82.6, "speaker": 1},
+            {"start": 83.2, "stop": 83.3, "speaker": 1},
+            {"start": 84.5, "stop": 94.3, "speaker": 1},
+            {"start": 95.1, "stop": 100.7, "speaker": 1},
+            {"start": 100.7, "stop": 102.0, "speaker": 0},
+            {"start": 100.7, "stop": 101.8, "speaker": 1},
+            {"start": 102.0, "stop": 103.0, "speaker": 1},
+            {"start": 103.0, "stop": 103.7, "speaker": 0},
+            {"start": 103.7, "stop": 103.8, "speaker": 1},
+            {"start": 103.8, "stop": 113.9, "speaker": 0},
+            {"start": 114.7, "stop": 117.0, "speaker": 0},
+            {"start": 117.0, "stop": 117.4, "speaker": 1},
+        ]
+        # now reapply speaker to topics (if any)
+        # topics is a list[BaseModel] with an attribute words
+        # words is a list[BaseModel] with text, start and speaker attribute
+        print("IN DIARIZATION PROCESSOR", data)
+        # mutate in place
+        for topic in data.topics:
+            for word in topic.transcript.words:
+                for d in diarization:
+                    if d["start"] <= word.start <= d["stop"]:
+                        word.speaker = d["speaker"]
+        # emit them
+        for topic in data.topics:
+            await self.emit(topic)
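
Review note: the triple loop gives each word the speaker of whichever diarization segment contains its start time; with overlapping segments, the last matching segment wins because later matches overwrite earlier ones. A standalone sketch of the same assignment with hypothetical stand-in types:

    from dataclasses import dataclass

    @dataclass
    class Word:  # stand-in for the processor Word model
        text: str
        start: float
        speaker: int | None = None

    def assign_speakers(words: list[Word], segments: list[dict]) -> None:
        # later segments overwrite earlier matches, as in _push above
        for word in words:
            for seg in segments:
                if seg["start"] <= word.start <= seg["stop"]:
                    word.speaker = seg["speaker"]

    words = [Word("Hello", 0.5), Word("world", 7.5)]
    segments = [
        {"start": 0.0, "stop": 4.9, "speaker": 2},
        {"start": 7.3, "stop": 8.9, "speaker": 2},
        {"start": 7.3, "stop": 7.9, "speaker": 0},
    ]
    assign_speakers(words, segments)
    assert [w.speaker for w in words] == [2, 0]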

View File

@@ -382,3 +382,8 @@ class TranslationLanguages(BaseModel):
    def is_supported(self, lang_id: str) -> bool:
        return lang_id in self.supported_languages


+class AudioDiarizationInput(BaseModel):
+    audio_filename: Path
+    topics: list[TitleSummary]

View File

@@ -1,2 +0,0 @@
-import reflector.tasks.post_transcript  # noqa
-import reflector.tasks.worker  # noqa

View File

@@ -1,6 +0,0 @@
-from celery import Celery
-from reflector.settings import settings

-celery = Celery(__name__)
-celery.conf.broker_url = settings.CELERY_BROKER_URL
-celery.conf.result_backend = settings.CELERY_RESULT_BACKEND

View File

@@ -1,5 +1,4 @@
import asyncio
-from enum import StrEnum
from json import loads

import av
@@ -41,7 +40,7 @@ class AudioStreamTrack(MediaStreamTrack):
        ctx = self.ctx
        frame = await self.track.recv()
        try:
-            await ctx.pipeline_runner.push(frame)
+            ctx.pipeline_runner.push(frame)
        except Exception as e:
            ctx.logger.error("Pipeline error", error=e)
        return frame
@@ -52,19 +51,6 @@ class RtcOffer(BaseModel):
    type: str


-class StrValue(BaseModel):
-    value: str


-class PipelineEvent(StrEnum):
-    TRANSCRIPT = "TRANSCRIPT"
-    TOPIC = "TOPIC"
-    FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
-    STATUS = "STATUS"
-    FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
-    FINAL_TITLE = "FINAL_TITLE"


async def rtc_offer_base(
    params: RtcOffer,
    request: Request,
@@ -90,7 +76,7 @@ async def rtc_offer_base(
    # - when we receive the close event, we do nothing.
    # 2. or the client close the connection
    #    and there is nothing to do because it is already closed
-    await ctx.pipeline_runner.flush()
+    ctx.pipeline_runner.flush()
    if close:
        ctx.logger.debug("Closing peer connection")
        await pc.close()

View File

@@ -23,10 +23,9 @@ from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool

from ._range_requests_response import range_requests_response
-from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
+from .rtc_offer import RtcOffer, rtc_offer_base

router = APIRouter()
-ws_manager = get_ws_manager()

# ==============================================================
# Transcripts list
@@ -166,32 +165,17 @@ async def transcript_update(
    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")
-    values = {"events": []}
+    values = {}
    if info.name is not None:
        values["name"] = info.name
    if info.locked is not None:
        values["locked"] = info.locked
    if info.long_summary is not None:
        values["long_summary"] = info.long_summary
-        for transcript_event in transcript.events:
-            if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
-                transcript_event["long_summary"] = info.long_summary
-                break
-        values["events"].extend(transcript.events)
    if info.short_summary is not None:
        values["short_summary"] = info.short_summary
-        for transcript_event in transcript.events:
-            if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
-                transcript_event["short_summary"] = info.short_summary
-                break
-        values["events"].extend(transcript.events)
    if info.title is not None:
        values["title"] = info.title
-        for transcript_event in transcript.events:
-            if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
-                transcript_event["title"] = info.title
-                break
-        values["events"].extend(transcript.events)
    await transcripts_controller.update(transcript, values)
    return transcript
@@ -295,6 +279,7 @@ async def transcript_events_websocket(
    # connect to websocket manager
    # use ts:transcript_id as room id
    room_id = f"ts:{transcript_id}"
+    ws_manager = get_ws_manager()
    await ws_manager.add_user_to_room(room_id, websocket)

    try:
@@ -303,9 +288,7 @@ async def transcript_events_websocket(
            # for now, do not send TRANSCRIPT or STATUS options - these are live events
            # not necessary to be sent to the client; but keep the rest
            name = event.event
-            if name == PipelineEvent.TRANSCRIPT:
-                continue
-            if name == PipelineEvent.STATUS:
+            if name in ("TRANSCRIPT", "STATUS"):
                continue
            await websocket.send_json(event.model_dump(mode="json"))

View File

@@ -0,0 +1,11 @@
+from celery import Celery
+from reflector.settings import settings

+app = Celery(__name__)
+app.conf.broker_url = settings.CELERY_BROKER_URL
+app.conf.result_backend = settings.CELERY_RESULT_BACKEND
+app.autodiscover_tasks(
+    [
+        "reflector.pipelines.main_live_pipeline",
+    ]
+)
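
Review note: this app replaces the deleted module-level `celery` instance; the worker is launched with the standard `celery -A <module-path> worker` invocation against wherever this file lives (the path is not visible in this view). A quick registration check, module path assumed:

    # Hypothetical smoke check: importing the pipeline module registers the
    # @shared_task with the app; accessing app.tasks finalizes registration.
    from reflector.tasks.worker import app  # assumed module path
    import reflector.pipelines.main_live_pipeline  # noqa: F401

    print([name for name in app.tasks if name.endswith("task_pipeline_main_post")])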

View File

@@ -11,13 +11,12 @@ broadcast messages to all connected websockets.
import asyncio
import json
+import threading

import redis.asyncio as redis
from fastapi import WebSocket

from reflector.settings import settings

-ws_manager = None


class RedisPubSubManager:
    def __init__(self, host="localhost", port=6379):
@@ -114,13 +113,14 @@ def get_ws_manager() -> WebsocketManager:
        ImportError: If the 'reflector.settings' module cannot be imported.
        RedisConnectionError: If there is an error connecting to the Redis server.
    """
-    global ws_manager
-    if ws_manager:
-        return ws_manager
+    local = threading.local()
+    if hasattr(local, "ws_manager"):
+        return local.ws_manager
    pubsub_client = RedisPubSubManager(
        host=settings.REDIS_HOST,
        port=settings.REDIS_PORT,
    )
    ws_manager = WebsocketManager(pubsub_client=pubsub_client)
+    local.ws_manager = ws_manager
    return ws_manager
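
Review note: `threading.local()` only provides per-thread storage when the same instance outlives the calls; a namespace created inside `get_ws_manager` is discarded on return, so the `hasattr` check above can never hit. A minimal sketch of the per-thread caching presumably intended, with the local held at module scope (generic names, not the committed code):

    import threading

    _local = threading.local()  # module scope: one namespace per thread, reused across calls

    def get_cached(factory):
        # first call in a thread builds the object; later calls in the
        # same thread return the cached instance
        if not hasattr(_local, "obj"):
            _local.obj = factory()
        return _local.obj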

View File

@@ -45,17 +45,16 @@ async def dummy_transcript():
    from reflector.processors.types import AudioFile, Transcript, Word

    class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
-        async def _transcript(self, data: AudioFile):
-            source_language = self.get_pref("audio:source_language", "en")
-            print("transcripting", source_language)
-            print("pipeline", self.pipeline)
-            print("prefs", self.pipeline.prefs)
+        _time_idx = 0

+        async def _transcript(self, data: AudioFile):
+            i = self._time_idx
+            self._time_idx += 2
            return Transcript(
                text="Hello world.",
                words=[
-                    Word(start=0.0, end=1.0, text="Hello"),
-                    Word(start=1.0, end=2.0, text=" world."),
+                    Word(start=i, end=i + 1, text="Hello", speaker=0),
+                    Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
                ],
            )
@@ -98,7 +97,17 @@ def ensure_casing():
@pytest.fixture
def sentence_tokenize():
    with patch(
-        "reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize"
+        "reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize"
    ) as mock_sent_tokenize:
        mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
        yield


+@pytest.fixture(scope="session")
+def celery_enable_logging():
+    return True


+@pytest.fixture(scope="session")
+def celery_config():
+    return {"broker_url": "memory://", "result_backend": "rpc"}

View File

@@ -32,7 +32,7 @@ class ThreadedUvicorn:
@pytest.fixture
-async def appserver(tmpdir):
+async def appserver(tmpdir, celery_session_app, celery_session_worker):
    from reflector.settings import settings
    from reflector.app import app
@@ -52,6 +52,13 @@ async def appserver(tmpdir):
    settings.DATA_DIR = DATA_DIR
+@pytest.fixture(scope="session")
+def celery_includes():
+    return ["reflector.pipelines.main_live_pipeline"]


+@pytest.mark.usefixtures("celery_session_app")
+@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(
    tmpdir,
@@ -121,14 +128,20 @@ async def test_transcript_rtc_and_websocket(
    # XXX aiortc is long to close the connection
    # instead of waiting a long time, we just send a STOP
    client.channel.send(json.dumps({"cmd": "STOP"}))

+    # wait the processing to finish
+    await asyncio.sleep(2)
    await client.stop()

    # wait the processing to finish
-    await asyncio.sleep(2)
+    timeout = 20
+    while True:
+        # fetch the transcript and check if it is ended
+        resp = await ac.get(f"/transcripts/{tid}")
+        assert resp.status_code == 200
+        if resp.json()["status"] in ("ended", "error"):
+            break
+        await asyncio.sleep(1)
+    if resp.json()["status"] != "ended":
+        raise TimeoutError("Timeout while waiting for transcript to be ended")

    # stop websocket task
    websocket_task.cancel()
@@ -152,7 +165,7 @@ async def test_transcript_rtc_and_websocket(
    ev = events[eventnames.index("TOPIC")]
    assert ev["data"]["id"]
    assert ev["data"]["summary"] == "LLM SUMMARY"
-    assert ev["data"]["transcript"].startswith("Hello world.")
+    assert ev["data"]["text"].startswith("Hello world.")
    assert ev["data"]["timestamp"] == 0.0

    assert "FINAL_LONG_SUMMARY" in eventnames
@@ -169,23 +182,21 @@ async def test_transcript_rtc_and_websocket(
    # check status order
    statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
-    assert statuses == ["recording", "processing", "ended"]
+    assert statuses.index("recording") < statuses.index("processing")
+    assert statuses.index("processing") < statuses.index("ended")

    # ensure the last event received is ended
    assert events[-1]["event"] == "STATUS"
    assert events[-1]["data"]["value"] == "ended"

+    # check that transcript status in model is updated
+    resp = await ac.get(f"/transcripts/{tid}")
+    assert resp.status_code == 200
+    assert resp.json()["status"] == "ended"

    # check that audio/mp3 is available
    resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
    assert resp.status_code == 200
    assert resp.headers["Content-Type"] == "audio/mpeg"
+@pytest.mark.usefixtures("celery_session_app")
+@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(
    tmpdir,
@@ -265,6 +276,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
    await client.stop()

    # wait the processing to finish
+    timeout = 20
+    while True:
+        # fetch the transcript and check if it is ended
+        resp = await ac.get(f"/transcripts/{tid}")
+        assert resp.status_code == 200
+        if resp.json()["status"] == "ended":
+            break
+        await asyncio.sleep(1)
+    if resp.json()["status"] != "ended":
+        raise TimeoutError("Timeout while waiting for transcript to be ended")
    await asyncio.sleep(2)

    # stop websocket task
@@ -289,7 +312,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
    ev = events[eventnames.index("TOPIC")]
    assert ev["data"]["id"]
    assert ev["data"]["summary"] == "LLM SUMMARY"
-    assert ev["data"]["transcript"].startswith("Hello world.")
+    assert ev["data"]["text"].startswith("Hello world.")
    assert ev["data"]["timestamp"] == 0.0

    assert "FINAL_LONG_SUMMARY" in eventnames
@@ -306,7 +329,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
    # check status order
    statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
-    assert statuses == ["recording", "processing", "ended"]
+    assert statuses.index("recording") < statuses.index("processing")
+    assert statuses.index("processing") < statuses.index("ended")

    # ensure the last event received is ended
    assert events[-1]["event"] == "STATUS"