server: add audio_location and move to external storage if possible

This commit is contained in:
2023-11-16 14:34:33 +01:00
committed by Mathieu Virbel
parent 88f443e2c2
commit 06b29d9bd4
5 changed files with 238 additions and 114 deletions

View File

@@ -0,0 +1,43 @@
"""audio_location
Revision ID: f819277e5169
Revises: 4814901632bc
Create Date: 2023-11-16 10:29:09.351664
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f819277e5169"
down_revision: Union[str, None] = "4814901632bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the ``audio_location`` column to the ``transcript`` table.

    The column is non-nullable; existing rows are backfilled through the
    server-side default ``"local"`` (audio kept on the local filesystem,
    as opposed to external storage).
    """
    audio_location = sa.Column(
        "audio_location",
        sa.String(),
        server_default="local",
        nullable=False,
    )
    op.add_column("transcript", audio_location)
def downgrade() -> None:
    """Revert :func:`upgrade` by dropping the ``audio_location`` column.

    NOTE(review): the auto-generated body previously re-added an unrelated
    ``share_mode`` column instead of reversing this migration — a stale
    autogenerate artifact. A downgrade must undo exactly what the paired
    upgrade did, i.e. remove ``audio_location`` from ``transcript``.
    """
    op.drop_column("transcript", "audio_location")

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
from reflector.db import database, metadata from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings from reflector.settings import settings
from reflector.storage import Storage
transcripts = sqlalchemy.Table( transcripts = sqlalchemy.Table(
"transcript", "transcript",
@@ -27,20 +28,33 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("events", sqlalchemy.JSON), sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column(
"audio_location",
sqlalchemy.String,
nullable=False,
server_default="local",
),
# with user attached, optional # with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String), sqlalchemy.Column("user_id", sqlalchemy.String),
) )
def generate_uuid4(): def generate_uuid4() -> str:
return str(uuid4()) return str(uuid4())
def generate_transcript_name(): def generate_transcript_name() -> str:
now = datetime.utcnow() now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
def get_storage() -> Storage:
return Storage.get_instance(
name=settings.TRANSCRIPT_STORAGE_BACKEND,
settings_prefix="TRANSCRIPT_STORAGE_",
)
class AudioWaveform(BaseModel): class AudioWaveform(BaseModel):
data: list[float] data: list[float]
@@ -133,6 +147,10 @@ class Transcript(BaseModel):
def data_path(self): def data_path(self):
return Path(settings.DATA_DIR) / self.id return Path(settings.DATA_DIR) / self.id
@property
def audio_wav_filename(self):
return self.data_path / "audio.wav"
@property @property
def audio_mp3_filename(self): def audio_mp3_filename(self):
return self.data_path / "audio.mp3" return self.data_path / "audio.mp3"
@@ -157,6 +175,40 @@ class Transcript(BaseModel):
return AudioWaveform(data=data) return AudioWaveform(data=data)
async def get_audio_url(self) -> str:
if self.audio_location == "local":
return self._generate_local_audio_link()
elif self.audio_location == "storage":
return await self._generate_storage_audio_link()
raise Exception(f"Unknown audio location {self.audio_location}")
async def _generate_storage_audio_link(self) -> str:
return await get_storage().get_file_url(self.storage_audio_path)
def _generate_local_audio_link(self) -> str:
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from datetime import timedelta
from reflector.app import app
from reflector.views.transcripts import create_access_token
path = app.url_path_for(
"transcript_get_audio_mp3",
transcript_id=self.id,
)
url = f"{settings.BASE_URL}{path}"
if self.user_id:
# we pass token only if the user_id is set
# otherwise, the audio is public
token = create_access_token(
{"sub": self.user_id},
expires_delta=timedelta(minutes=15),
)
url += f"?token={token}"
return url
class TranscriptController: class TranscriptController:
async def get_all( async def get_all(
@@ -292,15 +344,18 @@ class TranscriptController:
""" """
Move mp3 file to storage Move mp3 file to storage
""" """
from reflector.storage import Storage
storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE) # store the audio on external storage
await storage.put_file( await get_storage().put_file(
transcript.storage_audio_path, transcript.storage_audio_path,
self.audio_mp3_filename.read_bytes(), transcript.audio_mp3_filename.read_bytes(),
) )
# indicate on the transcript that the audio is now on storage
await self.update(transcript, {"audio_location": "storage"}) await self.update(transcript, {"audio_location": "storage"})
# unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True)
transcripts_controller = TranscriptController() transcripts_controller = TranscriptController()

View File

@@ -14,12 +14,9 @@ It is directly linked to our data model.
import asyncio import asyncio
import functools import functools
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from celery import shared_task from celery import chord, group, shared_task
from pydantic import BaseModel from pydantic import BaseModel
from reflector.app import app
from reflector.db.transcripts import ( from reflector.db.transcripts import (
Transcript, Transcript,
TranscriptDuration, TranscriptDuration,
@@ -56,7 +53,7 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager from reflector.ws_manager import WebsocketManager, get_ws_manager
from structlog import Logger from structlog import BoundLogger as Logger
def asynctask(f): def asynctask(f):
@@ -97,13 +94,17 @@ def get_transcript(func):
Decorator to fetch the transcript from the database from the first argument Decorator to fetch the transcript from the database from the first argument
""" """
async def wrapper(self, **kwargs): async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id") transcript_id = kwargs.pop("transcript_id")
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id) transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
if not transcript: if not transcript:
raise Exception("Transcript {transcript_id} not found") raise Exception("Transcript {transcript_id} not found")
tlogger = logger.bind(transcript_id=transcript.id) tlogger = logger.bind(transcript_id=transcript.id)
return await func(self, transcript=transcript, logger=tlogger, **kwargs) try:
return await func(transcript=transcript, logger=tlogger, **kwargs)
except Exception as exc:
tlogger.error("Pipeline error", exc_info=exc)
raise
return wrapper return wrapper
@@ -162,7 +163,7 @@ class PipelineMainBase(PipelineRunner):
"flush": "processing", "flush": "processing",
"error": "error", "error": "error",
} }
elif isinstance(self, PipelineMainDiarization): elif isinstance(self, PipelineMainFinalSummaries):
status_mapping = { status_mapping = {
"push": "processing", "push": "processing",
"flush": "processing", "flush": "processing",
@@ -170,7 +171,8 @@ class PipelineMainBase(PipelineRunner):
"ended": "ended", "ended": "ended",
} }
else: else:
raise Exception(f"Runner {self.__class__} is missing status mapping") # intermediate pipeline don't update status
return
# mutate to model status # mutate to model status
status = status_mapping.get(status) status = status_mapping.get(status)
@@ -308,9 +310,10 @@ class PipelineMainBase(PipelineRunner):
class PipelineMainLive(PipelineMainBase): class PipelineMainLive(PipelineMainBase):
audio_filename: Path | None = None """
source_language: str = "en" Main pipeline for live streaming, attach to RTC connection
target_language: str = "en" Any long post process should be done in the post pipeline
"""
async def create(self) -> Pipeline: async def create(self) -> Pipeline:
# create a context for the whole rtc transaction # create a context for the whole rtc transaction
@@ -320,7 +323,7 @@ class PipelineMainLive(PipelineMainBase):
processors = [ processors = [
AudioFileWriterProcessor( AudioFileWriterProcessor(
path=transcript.audio_mp3_filename, path=transcript.audio_wav_filename,
on_duration=self.on_duration, on_duration=self.on_duration,
), ),
AudioChunkerProcessor(), AudioChunkerProcessor(),
@@ -329,17 +332,13 @@ class PipelineMainLive(PipelineMainBase):
TranscriptLinerProcessor(), TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
BroadcastProcessor( # XXX move as a task
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
AudioWaveformProcessor.as_threaded( AudioWaveformProcessor.as_threaded(
audio_path=transcript.audio_mp3_filename, audio_path=transcript.audio_mp3_filename,
waveform_path=transcript.audio_waveform_filename, waveform_path=transcript.audio_waveform_filename,
on_waveform=self.on_waveform, on_waveform=self.on_waveform,
), ),
] ]
),
]
pipeline = Pipeline(*processors) pipeline = Pipeline(*processors)
pipeline.options = self pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language) pipeline.set_pref("audio:source_language", transcript.source_language)
@@ -374,28 +373,16 @@ class PipelineMainDiarization(PipelineMainBase):
# first processor diarization processor # first processor diarization processor
# XXX translation is lost when converting our data model to the processor model # XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript() transcript = await self.get_transcript()
# diarization works only if the file is uploaded to an external storage
if transcript.audio_location == "local":
pipeline.logger.info("Audio is local, skipping diarization")
return
topics = self.get_transcript_topics(transcript) topics = self.get_transcript_topics(transcript)
audio_url = await transcript.get_audio_url()
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from reflector.views.transcripts import create_access_token
path = app.url_path_for(
"transcript_get_audio_mp3",
transcript_id=transcript.id,
)
url = f"{settings.BASE_URL}{path}"
if transcript.user_id:
# we pass token only if the user_id is set
# otherwise, the audio is public
token = create_access_token(
{"sub": transcript.user_id},
expires_delta=timedelta(minutes=15),
)
url += f"?token={token}"
audio_diarization_input = AudioDiarizationInput( audio_diarization_input = AudioDiarizationInput(
audio_url=url, audio_url=audio_url,
topics=topics, topics=topics,
) )
@@ -409,14 +396,60 @@ class PipelineMainDiarization(PipelineMainBase):
return pipeline return pipeline
class PipelineMainSummaries(PipelineMainBase): class PipelineMainFromTopics(PipelineMainBase):
"""
Pseudo class for generating a pipeline from topics
"""
def get_processors(self) -> list:
raise NotImplementedError
async def create(self) -> Pipeline:
self.prepare()
processors = self.get_processors()
pipeline = Pipeline(*processors)
pipeline.options = self
# get transcript
transcript = await self.get_transcript()
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
# push topics
topics = self.get_transcript_topics(transcript)
for topic in topics:
self.push(topic)
self.flush()
return pipeline
class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
"""
Generate title from the topics
"""
def get_processors(self) -> list:
return [
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_short_summary
),
]
)
]
class PipelineMainFinalSummaries(PipelineMainFromTopics):
""" """
Generate summaries from the topics Generate summaries from the topics
""" """
async def create(self) -> Pipeline: def get_processors(self) -> list:
self.prepare() return [
pipeline = Pipeline(
BroadcastProcessor( BroadcastProcessor(
processors=[ processors=[
TranscriptFinalLongSummaryProcessor.as_threaded( TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -427,32 +460,7 @@ class PipelineMainSummaries(PipelineMainBase):
), ),
] ]
), ),
) ]
pipeline.options = self
# get transcript
transcript = await self.get_transcript()
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info("Summaries pipeline created")
# push topics
topics = await self.get_transcript_topics(transcript)
for topic in topics:
self.push(topic)
self.flush()
return pipeline
@shared_task
def task_pipeline_main_post(transcript_id: str):
logger.info(
"Starting main post pipeline",
transcript_id=transcript_id,
)
runner = PipelineMainDiarization(transcript_id=transcript_id)
runner.start_sync()
@get_transcript @get_transcript
@@ -470,24 +478,26 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
import av import av
input_container = av.open(wav_filename) with av.open(wav_filename.as_posix()) as in_container:
output_container = av.open(mp3_filename, "w") in_stream = in_container.streams.audio[0]
input_audio_stream = input_container.streams.audio[0] with av.open(mp3_filename.as_posix(), "w") as out_container:
output_audio_stream = output_container.add_stream("mp3") out_stream = out_container.add_stream("mp3")
output_audio_stream.codec_context.set_parameters( for frame in in_container.decode(in_stream):
input_audio_stream.codec_context.parameters for packet in out_stream.encode(frame):
) out_container.mux(packet)
for packet in input_container.demux(input_audio_stream):
for frame in packet.decode(): # Delete the wav file
output_container.mux(frame) transcript.audio_wav_filename.unlink(missing_ok=True)
input_container.close()
output_container.close()
logger.info("Convert to mp3 done") logger.info("Convert to mp3 done")
@get_transcript @get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger): async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
if not settings.TRANSCRIPT_STORAGE_BACKEND:
logger.info("No storage backend configured, skipping mp3 upload")
return
logger.info("Starting upload mp3") logger.info("Starting upload mp3")
# If the audio mp3 is not available, just skip # If the audio mp3 is not available, just skip
@@ -497,27 +507,32 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
return return
# Upload to external storage and delete the file # Upload to external storage and delete the file
await transcripts_controller.move_to_storage(transcript) await transcripts_controller.move_mp3_to_storage(transcript)
await transcripts_controller.unlink_mp3(transcript)
logger.info("Upload mp3 done") logger.info("Upload mp3 done")
@get_transcript @get_transcript
@asynctask
async def pipeline_diarization(transcript: Transcript, logger: Logger): async def pipeline_diarization(transcript: Transcript, logger: Logger):
logger.info("Starting diarization") logger.info("Starting diarization")
runner = PipelineMainDiarization(transcript_id=transcript.id) runner = PipelineMainDiarization(transcript_id=transcript.id)
await runner.start() await runner.run()
logger.info("Diarization done") logger.info("Diarization done")
@get_transcript @get_transcript
@asynctask async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
logger.info("Starting title and short summary")
runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
await runner.run()
logger.info("Title and short summary done")
@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger): async def pipeline_summaries(transcript: Transcript, logger: Logger):
logger.info("Starting summaries") logger.info("Starting summaries")
runner = PipelineMainSummaries(transcript_id=transcript.id) runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
await runner.start() await runner.run()
logger.info("Summaries done") logger.info("Summaries done")
@@ -528,29 +543,35 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
@shared_task @shared_task
@asynctask @asynctask
async def task_pipeline_convert_to_mp3(transcript_id: str): async def task_pipeline_convert_to_mp3(*, transcript_id: str):
await pipeline_convert_to_mp3(transcript_id) await pipeline_convert_to_mp3(transcript_id=transcript_id)
@shared_task @shared_task
@asynctask @asynctask
async def task_pipeline_upload_mp3(transcript_id: str): async def task_pipeline_upload_mp3(*, transcript_id: str):
await pipeline_upload_mp3(transcript_id) await pipeline_upload_mp3(transcript_id=transcript_id)
@shared_task @shared_task
@asynctask @asynctask
async def task_pipeline_diarization(transcript_id: str): async def task_pipeline_diarization(*, transcript_id: str):
await pipeline_diarization(transcript_id) await pipeline_diarization(transcript_id=transcript_id)
@shared_task @shared_task
@asynctask @asynctask
async def task_pipeline_summaries(transcript_id: str): async def task_pipeline_title_and_short_summary(*, transcript_id: str):
await pipeline_summaries(transcript_id) await pipeline_title_and_short_summary(transcript_id=transcript_id)
def pipeline_post(transcript_id: str): @shared_task
@asynctask
async def task_pipeline_final_summaries(*, transcript_id: str):
await pipeline_summaries(transcript_id=transcript_id)
def pipeline_post(*, transcript_id: str):
""" """
Run the post pipeline Run the post pipeline
""" """
@@ -559,6 +580,15 @@ def pipeline_post(transcript_id: str):
| task_pipeline_upload_mp3.si(transcript_id=transcript_id) | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
| task_pipeline_diarization.si(transcript_id=transcript_id) | task_pipeline_diarization.si(transcript_id=transcript_id)
) )
chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id) chain_title_preview = task_pipeline_title_and_short_summary.si(
chain = chain_mp3_and_diarize | chain_summary transcript_id=transcript_id
)
chain_final_summaries = task_pipeline_final_summaries.si(
transcript_id=transcript_id
)
chain = chord(
group(chain_mp3_and_diarize, chain_title_preview),
chain_final_summaries,
)
chain.delay() chain.delay()

View File

@@ -119,8 +119,7 @@ class PipelineRunner(BaseModel):
self._logger.exception("Runner error") self._logger.exception("Runner error")
await self._set_status("error") await self._set_status("error")
self._ev_done.set() self._ev_done.set()
if self.on_ended: raise
await self.on_ended()
async def cmd_push(self, data): async def cmd_push(self, data):
if self._is_first_push: if self._is_first_push:

View File

@@ -54,7 +54,7 @@ class Settings(BaseSettings):
TRANSCRIPT_MODAL_API_KEY: str | None = None TRANSCRIPT_MODAL_API_KEY: str | None = None
# Audio transcription storage # Audio transcription storage
TRANSCRIPT_STORAGE_BACKEND: str = "aws" TRANSCRIPT_STORAGE_BACKEND: str | None = None
# Storage configuration for AWS # Storage configuration for AWS
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket" TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
@@ -62,9 +62,6 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# Transcript MP3 storage
TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
# LLM # LLM
# available backend: openai, modal, oobabooga # available backend: openai, modal, oobabooga
LLM_BACKEND: str = "oobabooga" LLM_BACKEND: str = "oobabooga"