diff --git a/server/reflector/hatchet/workflows/diarization_pipeline.py b/server/reflector/hatchet/workflows/diarization_pipeline.py index 02668497..08f88554 100644 --- a/server/reflector/hatchet/workflows/diarization_pipeline.py +++ b/server/reflector/hatchet/workflows/diarization_pipeline.py @@ -48,6 +48,7 @@ from reflector.pipelines import topic_processing from reflector.processors import AudioFileWriterProcessor from reflector.processors.types import ( TitleSummary, + TitleSummaryWithId, Word, ) from reflector.processors.types import ( @@ -320,7 +321,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes target_language = participants_result.target_language - # Collect results from each track (don't mutate lists while iterating) track_words = [] padded_tracks = [] created_padded_files = set() @@ -372,6 +372,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: track_result = ctx.task_output(process_tracks) padded_tracks = track_result.padded_tracks + # TODO think of NonEmpty type to avoid those checks, e.g. sized.NonEmpty from https://github.com/antonagestam/phantom-types/ if not padded_tracks: raise ValueError("No padded tracks to mixdown") @@ -399,10 +400,10 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: raise ValueError("No decodable audio frames in any track") output_path = tempfile.mktemp(suffix=".mp3") - duration_ms = [0.0] # Mutable container for callback capture + duration_ms_callback_capture_container = [0.0] async def capture_duration(d): - duration_ms[0] = d + duration_ms_callback_capture_container[0] = d writer = AudioFileWriterProcessor(path=output_path, on_duration=capture_duration) @@ -436,7 +437,7 @@ async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult: return MixdownResult( audio_key=storage_path, - duration=duration_ms[0], + duration=duration_ms_callback_capture_container[0], tracks_mixed=len(valid_urls), ) @@ -530,9 +531,6 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: TranscriptTopic, transcripts_controller, ) - from reflector.processors.types import ( # noqa: PLC0415 - TitleSummaryWithId as TitleSummaryWithIdProcessorType, - ) word_objects = [Word(**w) for w in words] transcript_type = TranscriptType(words=word_objects) @@ -550,7 +548,9 @@ async def detect_topics(input: PipelineInput, ctx: Context) -> TopicsResult: transcript=data.transcript.text, words=data.transcript.words, ) - if isinstance(data, TitleSummaryWithIdProcessorType): + if isinstance( + data, TitleSummaryWithId + ): # Celery parity: main_live_pipeline.py topic.id = data.id await transcripts_controller.upsert_topic(transcript, topic) await append_event_and_broadcast( @@ -720,10 +720,6 @@ async def finalize(input: PipelineInput, ctx: Context) -> FinalizeResult: TranscriptText, transcripts_controller, ) - from reflector.processors.types import ( # noqa: PLC0415 - Transcript as TranscriptType, - ) - from reflector.processors.types import Word # noqa: PLC0415 transcript = await transcripts_controller.get_by_id(input.transcript_id) if transcript is None: