fix: file pipeline status reporting and websocket updates (#589)

* feat: use file pipeline for upload and reprocess action

* fix: make file pipeline correctly report status events

* fix: duplication of transcripts_controller

* fix: tests

* test: fix file upload test

* test: fix reprocess

* fix: also patch from main_file_pipeline

(how the patch is applied unfortunately depends on the file import)
commit 9dfd76996f (parent 55cc8637c6)
committed by GitHub on 2025-08-29 00:58:14 -06:00
9 changed files with 170 additions and 50 deletions

File: reflector/db/transcripts.py

@@ -122,6 +122,15 @@ def generate_transcript_name() -> str:
     return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
 
 
+TranscriptStatus = Literal[
+    "idle", "uploaded", "recording", "processing", "error", "ended"
+]
+
+
+class StrValue(BaseModel):
+    value: str
+
+
 class AudioWaveform(BaseModel):
     data: list[float]
@@ -185,7 +194,7 @@ class Transcript(BaseModel):
     id: str = Field(default_factory=generate_uuid4)
     user_id: str | None = None
     name: str = Field(default_factory=generate_transcript_name)
-    status: str = "idle"
+    status: TranscriptStatus = "idle"
     duration: float = 0
     created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
     title: str | None = None
@@ -732,5 +741,27 @@ class TranscriptController:
         transcript.delete_participant(participant_id)
         await self.update(transcript, {"participants": transcript.participants_dump()})
 
+    async def set_status(
+        self, transcript_id: str, status: TranscriptStatus
+    ) -> TranscriptEvent | None:
+        """
+        Update the status of a transcript.
+
+        Appends a STATUS event and updates the status field of the transcript.
+        """
+        async with self.transaction():
+            transcript = await self.get_by_id(transcript_id)
+            if not transcript:
+                raise Exception(f"Transcript {transcript_id} not found")
+            if transcript.status == status:
+                return
+            resp = await self.append_event(
+                transcript=transcript,
+                event="STATUS",
+                data=StrValue(value=status),
+            )
+            await self.update(transcript, {"status": status})
+            return resp
+
 
 transcripts_controller = TranscriptController()
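
A minimal usage sketch of the new controller method (hypothetical caller; assumes append_event returns the created event, as the code above suggests). The first transition records a STATUS event and persists the field; repeating the same status short-circuits:

    from reflector.db.transcripts import transcripts_controller

    async def demo(transcript_id: str) -> None:
        # first transition: appends a STATUS event and updates the status field
        event = await transcripts_controller.set_status(transcript_id, "processing")
        assert event is not None
        # same status again: no-op, returns None
        assert await transcripts_controller.set_status(transcript_id, "processing") is None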

File: reflector/pipelines/main_file_pipeline.py

@@ -15,10 +15,15 @@ from celery import shared_task
 from reflector.db.transcripts import (
     Transcript,
+    TranscriptStatus,
     transcripts_controller,
 )
 from reflector.logger import logger
-from reflector.pipelines.main_live_pipeline import PipelineMainBase, asynctask
+from reflector.pipelines.main_live_pipeline import (
+    PipelineMainBase,
+    asynctask,
+    broadcast_to_sockets,
+)
 from reflector.processors import (
     AudioFileWriterProcessor,
     TranscriptFinalSummaryProcessor,
@@ -83,12 +88,27 @@ class PipelineMainFile(PipelineMainBase):
                 exc_info=result,
             )
 
+    @broadcast_to_sockets
+    async def set_status(self, transcript_id: str, status: TranscriptStatus):
+        async with self.lock_transaction():
+            return await transcripts_controller.set_status(transcript_id, status)
+
     async def process(self, file_path: Path):
         """Main entry point for file processing"""
         self.logger.info(f"Starting file pipeline for {file_path}")
 
         transcript = await self.get_transcript()
 
+        # Clear transcript as we're going to regenerate everything
+        async with self.transaction():
+            await transcripts_controller.update(
+                transcript,
+                {
+                    "events": [],
+                    "topics": [],
+                },
+            )
+
         # Extract audio and write to transcript location
         audio_path = await self.extract_and_write_audio(file_path, transcript)
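
The @broadcast_to_sockets decorator is imported from main_live_pipeline and its body is not part of this diff; the sketch below is only a plausible shape of its contract, with push_event as a made-up stand-in for the real websocket fan-out. Assumed contract: the wrapped coroutine returns a TranscriptEvent (or None), and non-None results are broadcast to connected clients.

    import functools

    async def push_event(transcript_id, event):
        # placeholder for the real websocket fan-out
        print(f"broadcast {event!r} for transcript {transcript_id}")

    def broadcast_to_sockets(func):
        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            event = await func(self, *args, **kwargs)
            if event is not None:
                await push_event(self.transcript_id, event)
            return event
        return wrapper

Under this reading, set_status returning None for an unchanged status also means no redundant websocket traffic.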
@@ -105,6 +125,8 @@
 
         self.logger.info("File pipeline complete")
 
+        await transcripts_controller.set_status(transcript.id, "ended")
+
     async def extract_and_write_audio(
         self, file_path: Path, transcript: Transcript
     ) -> Path:
@@ -362,14 +384,21 @@ async def task_pipeline_file_process(*, transcript_id: str):
     if not transcript:
         raise Exception(f"Transcript {transcript_id} not found")
 
-    # Find the file to process
-    audio_file = next(transcript.data_path.glob("upload.*"), None)
-    if not audio_file:
-        audio_file = next(transcript.data_path.glob("audio.*"), None)
-    if not audio_file:
-        raise Exception("No audio file found to process")
-
-    # Run file pipeline
     pipeline = PipelineMainFile(transcript_id=transcript_id)
-    await pipeline.process(audio_file)
+    try:
+        await pipeline.set_status(transcript_id, "processing")
+
+        # Find the file to process
+        audio_file = next(transcript.data_path.glob("upload.*"), None)
+        if not audio_file:
+            audio_file = next(transcript.data_path.glob("audio.*"), None)
+        if not audio_file:
+            raise Exception("No audio file found to process")
+
+        await pipeline.process(audio_file)
+    except Exception:
+        await pipeline.set_status(transcript_id, "error")
+        raise
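
For orientation, a sketch of how callers kick off this task (mirroring the router changes below; assumes a running Celery worker):

    from reflector.pipelines.main_file_pipeline import task_pipeline_file_process

    def reprocess(transcript_id: str) -> None:
        # .delay() enqueues the task; the worker then drives the status through
        # "processing" to "ended", or to "error" if anything in the try block raises
        task_pipeline_file_process.delay(transcript_id=transcript_id)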

File: reflector/pipelines/main_live_pipeline.py

@@ -32,6 +32,7 @@ from reflector.db.transcripts import (
     TranscriptFinalLongSummary,
     TranscriptFinalShortSummary,
     TranscriptFinalTitle,
+    TranscriptStatus,
     TranscriptText,
     TranscriptTopic,
     TranscriptWaveform,
@@ -188,8 +189,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
         ]
 
     @asynccontextmanager
-    async def transaction(self):
+    async def lock_transaction(self):
+        # This lock prevents multiple processors from appending to the
+        # event array at the same time
         async with self._lock:
+            yield
+
+    @asynccontextmanager
+    async def transaction(self):
+        async with self.lock_transaction():
             async with transcripts_controller.transaction():
                 yield
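
A standalone sketch of the layering introduced here (Demo is a toy stand-in): lock_transaction serializes writers without opening a database transaction, while transaction composes the lock with the controller-level transaction. This split lets callers like the file pipeline's set_status take only the lock, since transcripts_controller.set_status opens its own transaction.

    import asyncio
    from contextlib import asynccontextmanager

    class Demo:
        def __init__(self):
            self._lock = asyncio.Lock()

        @asynccontextmanager
        async def lock_transaction(self):
            # serialize concurrent writers to the event array
            async with self._lock:
                yield

        @asynccontextmanager
        async def transaction(self):
            # take the lock, then the (stubbed) controller transaction
            async with self.lock_transaction():
                yield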
@@ -198,14 +206,14 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
         # if it's the first part, update the status of the transcript
         # but do not set the ended status yet.
         if isinstance(self, PipelineMainLive):
-            status_mapping = {
+            status_mapping: dict[str, TranscriptStatus] = {
                 "started": "recording",
                 "push": "recording",
                 "flush": "processing",
                 "error": "error",
             }
         elif isinstance(self, PipelineMainFinalSummaries):
-            status_mapping = {
+            status_mapping: dict[str, TranscriptStatus] = {
                 "push": "processing",
                 "flush": "processing",
                 "error": "error",
@@ -221,22 +229,8 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
             return
 
         # when the status of the pipeline changes, update the transcript
-        async with self.transaction():
-            transcript = await self.get_transcript()
-            if status == transcript.status:
-                return
-            resp = await transcripts_controller.append_event(
-                transcript=transcript,
-                event="STATUS",
-                data=StrValue(value=status),
-            )
-            await transcripts_controller.update(
-                transcript,
-                {
-                    "status": status,
-                },
-            )
-            return resp
+        async with self._lock:
+            return await transcripts_controller.set_status(self.transcript_id, status)
 
     @broadcast_to_sockets
     async def on_transcript(self, data):

File: transcript process router

@@ -6,7 +6,7 @@ from pydantic import BaseModel
 import reflector.auth as auth
 from reflector.db.transcripts import transcripts_controller
-from reflector.pipelines.main_live_pipeline import task_pipeline_process
+from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
 
 router = APIRouter()
@@ -40,7 +40,7 @@ async def transcript_process(
         return ProcessStatus(status="already running")
 
     # schedule a background task to process the file
-    task_pipeline_process.delay(transcript_id=transcript_id)
+    task_pipeline_file_process.delay(transcript_id=transcript_id)
 
     return ProcessStatus(status="ok")

File: transcript upload router

@@ -6,7 +6,7 @@ from pydantic import BaseModel
 import reflector.auth as auth
 from reflector.db.transcripts import transcripts_controller
-from reflector.pipelines.main_live_pipeline import task_pipeline_process
+from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
 
 router = APIRouter()
@@ -92,6 +92,6 @@ async def transcript_record_upload(
     await transcripts_controller.update(transcript, {"status": "uploaded"})
 
     # launch a background task to process the file
-    task_pipeline_process.delay(transcript_id=transcript_id)
+    task_pipeline_file_process.delay(transcript_id=transcript_id)
 
     return UploadStatus(status="ok")

File: shared test fixtures (conftest)

@@ -178,6 +178,63 @@ async def dummy_diarization():
         yield
 
 
+@pytest.fixture
+async def dummy_file_transcript():
+    from reflector.processors.file_transcript import FileTranscriptProcessor
+    from reflector.processors.types import Transcript, Word
+
+    class TestFileTranscriptProcessor(FileTranscriptProcessor):
+        async def _transcript(self, data):
+            return Transcript(
+                text="Hello world. How are you today?",
+                words=[
+                    Word(start=0.0, end=0.5, text="Hello", speaker=0),
+                    Word(start=0.5, end=0.6, text=" ", speaker=0),
+                    Word(start=0.6, end=1.0, text="world", speaker=0),
+                    Word(start=1.0, end=1.1, text=".", speaker=0),
+                    Word(start=1.1, end=1.2, text=" ", speaker=0),
+                    Word(start=1.2, end=1.5, text="How", speaker=0),
+                    Word(start=1.5, end=1.6, text=" ", speaker=0),
+                    Word(start=1.6, end=1.8, text="are", speaker=0),
+                    Word(start=1.8, end=1.9, text=" ", speaker=0),
+                    Word(start=1.9, end=2.1, text="you", speaker=0),
+                    Word(start=2.1, end=2.2, text=" ", speaker=0),
+                    Word(start=2.2, end=2.5, text="today", speaker=0),
+                    Word(start=2.5, end=2.6, text="?", speaker=0),
+                ],
+            )
+
+    with patch(
+        "reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__"
+    ) as mock_auto:
+        mock_auto.return_value = TestFileTranscriptProcessor()
+        yield
+
+
+@pytest.fixture
+async def dummy_file_diarization():
+    from reflector.processors.file_diarization import (
+        FileDiarizationOutput,
+        FileDiarizationProcessor,
+    )
+    from reflector.processors.types import DiarizationSegment
+
+    class TestFileDiarizationProcessor(FileDiarizationProcessor):
+        async def _diarize(self, data):
+            return FileDiarizationOutput(
+                diarization=[
+                    DiarizationSegment(start=0.0, end=1.1, speaker=0),
+                    DiarizationSegment(start=1.2, end=2.6, speaker=1),
+                ]
+            )
+
+    with patch(
+        "reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__"
+    ) as mock_auto:
+        mock_auto.return_value = TestFileDiarizationProcessor()
+        yield
+
+
 @pytest.fixture
 async def dummy_transcript_translator():
     from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
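
These fixtures work by patching __new__ on the auto-selecting processor class, so any construction of it yields the canned test double instead. A self-contained illustration of the trick (AutoProcessor and FakeProcessor are stand-ins for the real classes):

    from unittest.mock import patch

    class AutoProcessor:
        def run(self):
            return "real backend"

    class FakeProcessor:
        def run(self):
            return "canned result"

    with patch(f"{__name__}.AutoProcessor.__new__") as mock_new:
        mock_new.return_value = FakeProcessor()
        # AutoProcessor(...) now returns the fake; __init__ is skipped because
        # the object returned by __new__ is not an AutoProcessor instance
        assert AutoProcessor().run() == "canned result"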
@@ -238,9 +295,13 @@ async def dummy_storage():
     with (
         patch("reflector.storage.base.Storage.get_instance") as mock_storage,
         patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
+        patch(
+            "reflector.pipelines.main_file_pipeline.get_transcripts_storage"
+        ) as mock_get_transcripts2,
     ):
         mock_storage.return_value = dummy
         mock_get_transcripts.return_value = dummy
+        mock_get_transcripts2.return_value = dummy
         yield
@@ -260,7 +321,10 @@ def celery_config():
 @pytest.fixture(scope="session")
 def celery_includes():
-    return ["reflector.pipelines.main_live_pipeline"]
+    return [
+        "reflector.pipelines.main_live_pipeline",
+        "reflector.pipelines.main_file_pipeline",
+    ]
 
 @pytest.fixture
@@ -302,7 +366,7 @@ async def fake_transcript_with_topics(tmpdir, client):
     transcript = await transcripts_controller.get_by_id(tid)
     assert transcript is not None
-    await transcripts_controller.update(transcript, {"status": "finished"})
+    await transcripts_controller.update(transcript, {"status": "ended"})
 
     # manually copy a file at the expected location
     audio_filename = transcript.audio_mp3_filename

File: upload test fixtures

@@ -19,7 +19,7 @@ async def fake_transcript(tmpdir, client):
     transcript = await transcripts_controller.get_by_id(tid)
     assert transcript is not None
-    await transcripts_controller.update(transcript, {"status": "finished"})
+    await transcripts_controller.update(transcript, {"status": "ended"})
 
     # manually copy a file at the expected location
     audio_filename = transcript.audio_mp3_filename

File: transcript process test

@@ -29,10 +29,10 @@ async def client(app_lifespan):
 @pytest.mark.asyncio
 async def test_transcript_process(
     tmpdir,
-    whisper_transcript,
     dummy_llm,
     dummy_processors,
-    dummy_diarization,
+    dummy_file_transcript,
+    dummy_file_diarization,
     dummy_storage,
     client,
 ):
@@ -56,8 +56,8 @@
     assert response.status_code == 200
     assert response.json()["status"] == "ok"
 
-    # wait for processing to finish (max 10 minutes)
-    timeout_seconds = 600  # 10 minutes
+    # wait for processing to finish (max 1 minute)
+    timeout_seconds = 60
     start_time = time.monotonic()
     while (time.monotonic() - start_time) < timeout_seconds:
         # fetch the transcript and check if it is ended
@@ -75,9 +75,10 @@
     )
     assert response.status_code == 200
     assert response.json()["status"] == "ok"
+    await asyncio.sleep(2)
 
-    # wait for processing to finish (max 10 minutes)
-    timeout_seconds = 600  # 10 minutes
+    # wait for processing to finish (max 1 minute)
+    timeout_seconds = 60
     start_time = time.monotonic()
     while (time.monotonic() - start_time) < timeout_seconds:
         # fetch the transcript and check if it is ended
@@ -99,4 +100,4 @@
     response = await client.get(f"/transcripts/{tid}/topics")
     assert response.status_code == 200
     assert len(response.json()) == 1
-    assert "want to share" in response.json()[0]["transcript"]
+    assert "Hello world. How are you today?" in response.json()[0]["transcript"]

File: transcript upload file test

@@ -12,7 +12,8 @@ async def test_transcript_upload_file(
     tmpdir,
     dummy_llm,
     dummy_processors,
-    dummy_diarization,
+    dummy_file_transcript,
+    dummy_file_diarization,
     dummy_storage,
     client,
 ):
@@ -36,8 +37,8 @@
     assert response.status_code == 200
     assert response.json()["status"] == "ok"
 
-    # wait for the processing to finish (max 10 minutes)
-    timeout_seconds = 600  # 10 minutes
+    # wait for the processing to finish (max 1 minute)
+    timeout_seconds = 60
     start_time = time.monotonic()
     while (time.monotonic() - start_time) < timeout_seconds:
         # fetch the transcript and check if it is ended
@@ -47,7 +48,7 @@
             break
         await asyncio.sleep(1)
     else:
-        pytest.fail(f"Processing timed out after {timeout_seconds} seconds")
+        return pytest.fail(f"Processing timed out after {timeout_seconds} seconds")
 
     # check the transcript is ended
     transcript = resp.json()
@@ -59,4 +60,4 @@
     response = await client.get(f"/transcripts/{tid}/topics")
     assert response.status_code == 200
     assert len(response.json()) == 1
-    assert "want to share" in response.json()[0]["transcript"]
+    assert "Hello world. How are you today?" in response.json()[0]["transcript"]