|
|
|
|
@@ -12,7 +12,6 @@ It is directly linked to our data model.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
import functools
|
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
from typing import Generic
|
|
|
|
|
|
|
|
|
|
@@ -90,31 +89,6 @@ def broadcast_to_sockets(func):
|
|
|
|
|
return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_transcript(func):
    """
    Decorator that resolves a ``transcript_id`` keyword into a transcript.

    The decorated coroutine is replaced by a keyword-only coroutine which:

    * pops ``transcript_id`` from the keyword arguments,
    * looks the transcript up with ``transcripts_controller.get_by_id``
      inside a ``get_session_context()`` session,
    * awaits ``func`` with ``transcript=`` and ``logger=`` (the module
      logger bound to the transcript id) plus the remaining keywords.

    Raises:
        KeyError: if ``transcript_id`` was not passed.
        Exception: if no transcript matches the id; anything raised by
            ``func`` is logged as "Pipeline error" and re-raised unchanged.
    """

    @functools.wraps(func)
    async def with_transcript(**call_kwargs):
        lookup_id = call_kwargs.pop("transcript_id")

        async with get_session_context() as db:
            found = await transcripts_controller.get_by_id(db, lookup_id)

        if not found:
            raise Exception(f"Transcript {lookup_id} not found")

        bound_logger = logger.bind(transcript_id=found.id)

        try:
            return await func(transcript=found, logger=bound_logger, **call_kwargs)
        except Exception as exc:
            # Log with the wrapped function's name, then let the caller see
            # the original exception.
            bound_logger.error(
                "Pipeline error", function_name=func.__name__, exc_info=exc
            )
            raise

    return with_transcript
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model carrying a single required string under the key ``value``.
# NOTE(review): BaseModel is presumably pydantic's — confirm against the imports.
class StrValue(BaseModel):
    # The wrapped string.
    value: str
|
|
|
|
|
|
|
|
|
|
@@ -165,14 +139,9 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|
|
|
|
yield
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def locked_session(self):
    """
    Yield a database session while holding this pipeline's transaction lock.

    The lock from ``self.lock_transaction()`` is held for the whole time the
    caller uses the session. If the session obtained from
    ``get_session_context()`` is already inside a transaction it is yielded
    as is; otherwise a transaction is opened with ``session.begin()`` around
    the caller's block.

    Yields:
        The session returned by ``get_session_context()``.
    """
    async with self.lock_transaction():
        async with get_session_context() as session:
            if session.in_transaction():
                yield session
            else:
                async with session.begin():
                    yield session


# Former name of this context manager; call sites in this file still use
# ``self.transaction()``, so keep it as an alias.
transaction = locked_session
|
|
|
|
|
|
|
|
|
|
@broadcast_to_sockets
|
|
|
|
|
@@ -211,7 +180,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|
|
|
|
|
|
|
|
|
@broadcast_to_sockets
|
|
|
|
|
async def on_transcript(self, data):
|
|
|
|
|
async with self.transaction() as session:
|
|
|
|
|
async with self.locked_session() as session:
|
|
|
|
|
transcript = await self.get_transcript(session)
|
|
|
|
|
return await transcripts_controller.append_event(
|
|
|
|
|
session,
|
|
|
|
|
@@ -231,7 +200,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|
|
|
|
)
|
|
|
|
|
if isinstance(data, TitleSummaryWithIdProcessorType):
|
|
|
|
|
topic.id = data.id
|
|
|
|
|
async with self.transaction() as session:
|
|
|
|
|
async with self.locked_session() as session:
|
|
|
|
|
transcript = await self.get_transcript(session)
|
|
|
|
|
await transcripts_controller.upsert_topic(session, transcript, topic)
|
|
|
|
|
return await transcripts_controller.append_event(
|
|
|
|
|
@@ -244,7 +213,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|
|
|
|
@broadcast_to_sockets
|
|
|
|
|
async def on_title(self, data):
|
|
|
|
|
final_title = TranscriptFinalTitle(title=data.title)
|
|
|
|
|
async with self.transaction() as session:
|
|
|
|
|
async with self.locked_session() as session:
|
|
|
|
|
transcript = await self.get_transcript(session)
|
|
|
|
|
if not transcript.title:
|
|
|
|
|
await transcripts_controller.update(
|
|
|
|
|
@@ -264,7 +233,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|
|
|
|
@broadcast_to_sockets
|
|
|
|
|
async def on_long_summary(self, data):
|
|
|
|
|
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
|
|
|
|
async with self.transaction() as session:
|
|
|
|
|
async with self.locked_session() as session:
|
|
|
|
|
transcript = await self.get_transcript(session)
|
|
|
|
|
await transcripts_controller.update(
|
|
|
|
|
session,
|
|
|
|
|
@@ -285,7 +254,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|
|
|
|
final_short_summary = TranscriptFinalShortSummary(
|
|
|
|
|
short_summary=data.short_summary
|
|
|
|
|
)
|
|
|
|
|
async with self.transaction() as session:
|
|
|
|
|
async with self.locked_session() as session:
|
|
|
|
|
transcript = await self.get_transcript(session)
|
|
|
|
|
await transcripts_controller.update(
|
|
|
|
|
session,
|
|
|
|
|
@@ -303,7 +272,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|
|
|
|
|
|
|
|
|
@broadcast_to_sockets
|
|
|
|
|
async def on_duration(self, data):
|
|
|
|
|
async with self.transaction() as session:
|
|
|
|
|
async with self.locked_session() as session:
|
|
|
|
|
duration = TranscriptDuration(duration=data)
|
|
|
|
|
|
|
|
|
|
transcript = await self.get_transcript(session)
|
|
|
|
|
@@ -320,7 +289,7 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|
|
|
|
|
|
|
|
|
@broadcast_to_sockets
|
|
|
|
|
async def on_waveform(self, data):
|
|
|
|
|
async with self.transaction() as session:
|
|
|
|
|
async with self.locked_session() as session:
|
|
|
|
|
waveform = TranscriptWaveform(waveform=data)
|
|
|
|
|
|
|
|
|
|
transcript = await self.get_transcript(session)
|
|
|
|
|
@@ -483,8 +452,7 @@ class PipelineMainWaveform(PipelineMainFromTopics):
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@get_transcript
|
|
|
|
|
async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
|
|
|
|
|
async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger):
|
|
|
|
|
# for future changes: note that there's also a consent process happens, beforehand and users may not consent with keeping files. currently, we delete regardless, so it's no need for that
|
|
|
|
|
logger.info("Starting remove upload")
|
|
|
|
|
uploads = transcript.data_path.glob("upload.*")
|
|
|
|
|
@@ -493,16 +461,14 @@ async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
|
|
|
|
|
logger.info("Remove upload done")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def pipeline_waveform(session, transcript: Transcript, logger: Logger):
    """
    Run the waveform pipeline for ``transcript``.

    Args:
        session: Database session supplied by the caller. Not used in this
            body; accepted so the signature matches the other pipeline steps.
        transcript: Transcript whose ``id`` selects the pipeline run.
        logger: Logger used for the start/done messages.
    """
    logger.info("Starting waveform")
    runner = PipelineMainWaveform(transcript_id=transcript.id)
    await runner.run()
    logger.info("Waveform done")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@get_transcript
|
|
|
|
|
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
|
|
|
|
|
async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logger):
|
|
|
|
|
logger.info("Starting convert to mp3")
|
|
|
|
|
|
|
|
|
|
# If the audio wav is not available, just skip
|
|
|
|
|
@@ -551,24 +517,21 @@ async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
|
|
|
|
|
logger.info("Upload mp3 done")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def pipeline_diarization(session, transcript: Transcript, logger: Logger):
    """
    Run the diarization pipeline for ``transcript``.

    Args:
        session: Database session supplied by the caller. Not used in this
            body; accepted so the signature matches the other pipeline steps.
        transcript: Transcript whose ``id`` selects the pipeline run.
        logger: Logger used for the start/done messages.
    """
    logger.info("Starting diarization")
    runner = PipelineMainDiarization(transcript_id=transcript.id)
    await runner.run()
    logger.info("Diarization done")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def pipeline_title(session, transcript: Transcript, logger: Logger):
    """
    Run the title pipeline for ``transcript``.

    Args:
        session: Database session supplied by the caller. Not used in this
            body; accepted so the signature matches the other pipeline steps.
        transcript: Transcript whose ``id`` selects the pipeline run.
        logger: Logger used for the start/done messages.
    """
    logger.info("Starting title")
    runner = PipelineMainTitle(transcript_id=transcript.id)
    await runner.run()
    logger.info("Title done")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@get_transcript
|
|
|
|
|
async def pipeline_summaries(transcript: Transcript, logger: Logger):
|
|
|
|
|
async def pipeline_summaries(session, transcript: Transcript, logger: Logger):
|
|
|
|
|
logger.info("Starting summaries")
|
|
|
|
|
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
|
|
|
|
|
await runner.run()
|
|
|
|
|
@@ -703,18 +666,27 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_remove_upload(
    session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
    """
    Broker task that delegates to ``pipeline_remove_upload``.

    ``transcript_id`` is part of the task signature but is not read here.
    NOTE(review): ``session``, ``transcript`` and ``logger`` are presumably
    injected by ``with_session_and_transcript`` — confirm against its definition.
    """
    await pipeline_remove_upload(session, transcript=transcript, logger=logger)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_waveform(
    session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
    """
    Broker task that delegates to ``pipeline_waveform``.

    ``transcript_id`` is part of the task signature but is not read here.
    NOTE(review): ``session``, ``transcript`` and ``logger`` are presumably
    injected by ``with_session_and_transcript`` — confirm against its definition.
    """
    await pipeline_waveform(session, transcript=transcript, logger=logger)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_convert_to_mp3(
    session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
    """
    Broker task that delegates to ``pipeline_convert_to_mp3``.

    ``transcript_id`` is part of the task signature but is not read here.
    NOTE(review): ``session``, ``transcript`` and ``logger`` are presumably
    injected by ``with_session_and_transcript`` — confirm against its definition.
    """
    await pipeline_convert_to_mp3(session, transcript=transcript, logger=logger)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@taskiq_broker.task
|
|
|
|
|
@@ -726,32 +698,39 @@ async def task_pipeline_upload_mp3(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_diarization(
    session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
    """
    Broker task that delegates to ``pipeline_diarization``.

    ``transcript_id`` is part of the task signature but is not read here.
    NOTE(review): ``session``, ``transcript`` and ``logger`` are presumably
    injected by ``with_session_and_transcript`` — confirm against its definition.
    """
    await pipeline_diarization(session, transcript=transcript, logger=logger)


@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_title(
    session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
    """
    Broker task that delegates to ``pipeline_title``.

    ``transcript_id`` is part of the task signature but is not read here.
    """
    await pipeline_title(session, transcript=transcript, logger=logger)


@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_final_summaries(
    session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
    """
    Broker task that delegates to ``pipeline_summaries``.

    ``transcript_id`` is part of the task signature but is not read here.
    """
    await pipeline_summaries(session, transcript=transcript, logger=logger)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@taskiq_broker.task
@with_session_and_transcript
async def task_cleanup_consent(
    session,
    *,
    transcript: Transcript,
    logger: Logger,
):
    """
    Broker task that delegates to ``cleanup_consent``.

    Forwards the session positionally and ``transcript`` / ``logger`` by
    keyword; returns nothing.
    """
    await cleanup_consent(session, transcript=transcript, logger=logger)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_post_to_zulip(
    session, *, transcript: Transcript, logger: Logger
):
    """
    Broker task that delegates to ``pipeline_post_to_zulip``.

    Forwards the session positionally and ``transcript`` / ``logger`` by
    keyword; returns nothing.
    NOTE(review): the arguments are presumably injected by
    ``with_session_and_transcript`` — confirm against its definition.
    """
    await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
|
|
|
|
|
|
|
|
|
@@ -759,7 +738,7 @@ async def task_pipeline_post_to_zulip(
|
|
|
|
|
@taskiq_broker.task
@with_session_and_transcript
async def task_cleanup_consent_taskiq(
    session, *, transcript: Transcript, logger: Logger
):
    """
    Broker task that delegates to ``cleanup_consent``.

    Forwards the session positionally and ``transcript`` / ``logger`` by
    keyword; returns nothing.
    NOTE(review): the arguments are presumably injected by
    ``with_session_and_transcript`` — confirm against its definition.
    """
    await cleanup_consent(session, transcript=transcript, logger=logger)
|
|
|
|
|
|
|
|
|
|
@@ -767,7 +746,7 @@ async def task_cleanup_consent_taskiq(
|
|
|
|
|
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_post_to_zulip_taskiq(
    session, *, transcript: Transcript, logger: Logger
):
    """
    Broker task that delegates to ``pipeline_post_to_zulip``.

    Forwards the session positionally and ``transcript`` / ``logger`` by
    keyword; returns nothing.
    NOTE(review): the arguments are presumably injected by
    ``with_session_and_transcript`` — confirm against its definition.
    """
    await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
|
|
|
|
|
|
|
|
|
@@ -793,11 +772,9 @@ async def task_pipeline_post_sequential(*, transcript_id: str):
|
|
|
|
|
await task_pipeline_post_to_zulip.kiq(transcript_id=transcript_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@get_transcript
|
|
|
|
|
async def pipeline_process(transcript: Transcript, logger: Logger):
|
|
|
|
|
async def pipeline_process(session, transcript: Transcript, logger: Logger):
|
|
|
|
|
try:
|
|
|
|
|
if transcript.audio_location == "storage":
|
|
|
|
|
async with get_session_context() as session:
|
|
|
|
|
await transcripts_controller.download_mp3_from_storage(transcript)
|
|
|
|
|
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
|
|
|
|
await transcripts_controller.update(
|
|
|
|
|
@@ -838,7 +815,6 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
|
|
|
|
|
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
logger.error("Pipeline error", exc_info=exc)
|
|
|
|
|
async with get_session_context() as session:
|
|
|
|
|
await transcripts_controller.update(
|
|
|
|
|
session,
|
|
|
|
|
transcript,
|
|
|
|
|
@@ -852,5 +828,8 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_process(
    session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
    """
    Broker task that delegates to ``pipeline_process`` and returns its result.

    ``transcript_id`` is part of the task signature but is not read here.
    NOTE(review): ``session``, ``transcript`` and ``logger`` are presumably
    injected by ``with_session_and_transcript`` — confirm against its definition.
    """
    return await pipeline_process(session, transcript=transcript, logger=logger)
|
|
|
|
|
|