mirror of https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00

Merge branch 'main' of github.com:Monadical-SAS/reflector into feat-sharing

This commit is contained in:

server/migrations/versions/f819277e5169_audio_location.py · 35 lines · Normal file
@@ -0,0 +1,35 @@
+"""audio_location
+
+Revision ID: f819277e5169
+Revises: 4814901632bc
+Create Date: 2023-11-16 10:29:09.351664
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "f819277e5169"
+down_revision: Union[str, None] = "4814901632bc"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "transcript",
+        sa.Column(
+            "audio_location", sa.String(), server_default="local", nullable=False
+        ),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("transcript", "audio_location")
+    # ### end Alembic commands ###
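For context on the migration above: server_default="local" is what allows a NOT NULL column to be added to a table that already has rows, since the database backfills existing rows with "local". A hedged sketch of driving this revision through the Alembic API (the CLI equivalent is `alembic upgrade f819277e5169`; the config path is an assumption):

    # Sketch only: applying and reverting this revision programmatically.
    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")  # assumed path to the project's Alembic config
    command.upgrade(cfg, "f819277e5169")    # adds transcript.audio_location
    command.downgrade(cfg, "4814901632bc")  # drops it again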
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
 from reflector.db import database, metadata
 from reflector.processors.types import Word as ProcessorWord
 from reflector.settings import settings
+from reflector.storage import Storage
 
 transcripts = sqlalchemy.Table(
     "transcript",
@@ -28,6 +29,12 @@ transcripts = sqlalchemy.Table(
     sqlalchemy.Column("events", sqlalchemy.JSON),
     sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
     sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column(
+        "audio_location",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="local",
+    ),
     # with user attached, optional
     sqlalchemy.Column("user_id", sqlalchemy.String),
     sqlalchemy.Column(
@@ -39,15 +46,22 @@ transcripts = sqlalchemy.Table(
 )
 
 
-def generate_uuid4():
+def generate_uuid4() -> str:
     return str(uuid4())
 
 
-def generate_transcript_name():
+def generate_transcript_name() -> str:
     now = datetime.utcnow()
     return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
 
 
+def get_storage() -> Storage:
+    return Storage.get_instance(
+        name=settings.TRANSCRIPT_STORAGE_BACKEND,
+        settings_prefix="TRANSCRIPT_STORAGE_",
+    )
+
+
 class AudioWaveform(BaseModel):
     data: list[float]
 
@@ -114,6 +128,7 @@ class Transcript(BaseModel):
     source_language: str = "en"
     target_language: str = "en"
     share_mode: Literal["private", "semi-private", "public"] = "private"
+    audio_location: str = "local"
 
     def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
         ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -140,6 +155,10 @@ class Transcript(BaseModel):
     def data_path(self):
         return Path(settings.DATA_DIR) / self.id
 
+    @property
+    def audio_wav_filename(self):
+        return self.data_path / "audio.wav"
+
     @property
     def audio_mp3_filename(self):
         return self.data_path / "audio.mp3"
@@ -148,6 +167,10 @@ class Transcript(BaseModel):
     def audio_waveform_filename(self):
         return self.data_path / "audio.json"
 
+    @property
+    def storage_audio_path(self):
+        return f"{self.id}/audio.mp3"
+
     @property
     def audio_waveform(self):
         try:
@@ -160,6 +183,40 @@ class Transcript(BaseModel):
 
         return AudioWaveform(data=data)
 
+    async def get_audio_url(self) -> str:
+        if self.audio_location == "local":
+            return self._generate_local_audio_link()
+        elif self.audio_location == "storage":
+            return await self._generate_storage_audio_link()
+        raise Exception(f"Unknown audio location {self.audio_location}")
+
+    async def _generate_storage_audio_link(self) -> str:
+        return await get_storage().get_file_url(self.storage_audio_path)
+
+    def _generate_local_audio_link(self) -> str:
+        # we need to create a url to be used for diarization
+        # we can't use the audio_mp3_filename because it's not accessible
+        # from the diarization processor
+        from datetime import timedelta
+
+        from reflector.app import app
+        from reflector.views.transcripts import create_access_token
+
+        path = app.url_path_for(
+            "transcript_get_audio_mp3",
+            transcript_id=self.id,
+        )
+        url = f"{settings.BASE_URL}{path}"
+        if self.user_id:
+            # we pass the token only if user_id is set
+            # otherwise, the audio is public
+            token = create_access_token(
+                {"sub": self.user_id},
+                expires_delta=timedelta(minutes=15),
+            )
+            url += f"?token={token}"
+        return url
+
 
 class TranscriptController:
     async def get_all(
@@ -336,5 +393,22 @@ class TranscriptController:
         transcript.upsert_topic(topic)
         await self.update(transcript, {"topics": transcript.topics_dump()})
 
+    async def move_mp3_to_storage(self, transcript: Transcript):
+        """
+        Move mp3 file to storage
+        """
+
+        # store the audio on external storage
+        await get_storage().put_file(
+            transcript.storage_audio_path,
+            transcript.audio_mp3_filename.read_bytes(),
+        )
+
+        # indicate on the transcript that the audio is now on storage
+        await self.update(transcript, {"audio_location": "storage"})
+
+        # unlink the local file
+        transcript.audio_mp3_filename.unlink(missing_ok=True)
 
 
 transcripts_controller = TranscriptController()
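Taken together, the Transcript changes above give callers a single entry point for audio resolution. A minimal usage sketch (the id is a placeholder; assumes a running event loop):

    transcript = await transcripts_controller.get_by_id(transcript_id="abc123")
    url = await transcript.get_audio_url()
    # audio_location == "local"   -> BASE_URL route, plus a 15-minute token when user-owned
    # audio_location == "storage" -> presigned URL from the configured storage backend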
@@ -12,13 +12,11 @@ It is directly linked to our data model.
 """
 
 import asyncio
 import functools
 from contextlib import asynccontextmanager
-from datetime import timedelta
 from pathlib import Path
 
-from celery import shared_task
+from celery import chord, group, shared_task
 from pydantic import BaseModel
-from reflector.app import app
 from reflector.db.transcripts import (
     Transcript,
     TranscriptDuration,
@@ -55,6 +53,22 @@ from reflector.processors.types import (
 from reflector.processors.types import Transcript as TranscriptProcessorType
 from reflector.settings import settings
 from reflector.ws_manager import WebsocketManager, get_ws_manager
+from structlog import BoundLogger as Logger
+
+
+def asynctask(f):
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        coro = f(*args, **kwargs)
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+        if loop and loop.is_running():
+            return loop.run_until_complete(coro)
+        return asyncio.run(coro)
+
+    return wrapper
 
 
 def broadcast_to_sockets(func):
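Celery task functions are synchronous, so the asynctask decorator above bridges them to coroutines, reusing a running loop when one exists and falling back to asyncio.run otherwise. A hedged usage sketch (the do_work coroutine is hypothetical):

    @shared_task
    @asynctask
    async def task_do_work(*, transcript_id: str):
        await do_work(transcript_id=transcript_id)  # hypothetical coroutine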
@@ -75,6 +89,26 @@ def broadcast_to_sockets(func):
     return wrapper
 
 
+def get_transcript(func):
+    """
+    Decorator to fetch the transcript from the database via the transcript_id keyword argument
+    """
+
+    async def wrapper(**kwargs):
+        transcript_id = kwargs.pop("transcript_id")
+        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
+        if not transcript:
+            raise Exception(f"Transcript {transcript_id} not found")
+        tlogger = logger.bind(transcript_id=transcript.id)
+        try:
+            return await func(transcript=transcript, logger=tlogger, **kwargs)
+        except Exception as exc:
+            tlogger.error("Pipeline error", exc_info=exc)
+            raise
+
+    return wrapper
+
+
 class StrValue(BaseModel):
     value: str
 
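The decorator keeps each pipeline entry point keyword-only. A sketch of the calling convention it imposes, mirroring the pipeline functions later in this diff (the id is a placeholder):

    @get_transcript
    async def pipeline_example(transcript: Transcript, logger: Logger):
        logger.info("Transcript loaded")

    # callers pass only the id; transcript and logger are injected
    await pipeline_example(transcript_id="abc123")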
@@ -99,6 +133,19 @@ class PipelineMainBase(PipelineRunner):
             raise Exception("Transcript not found")
         return result
 
+    def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]:
+        return [
+            TitleSummaryWithIdProcessorType(
+                id=topic.id,
+                title=topic.title,
+                summary=topic.summary,
+                timestamp=topic.timestamp,
+                duration=topic.duration,
+                transcript=TranscriptProcessorType(words=topic.words),
+            )
+            for topic in transcript.topics
+        ]
+
     @asynccontextmanager
     async def transaction(self):
         async with self._lock:
@@ -116,7 +163,7 @@ class PipelineMainBase(PipelineRunner):
                 "flush": "processing",
                 "error": "error",
             }
-        elif isinstance(self, PipelineMainDiarization):
+        elif isinstance(self, PipelineMainFinalSummaries):
             status_mapping = {
                 "push": "processing",
                 "flush": "processing",
@@ -124,7 +171,8 @@ class PipelineMainBase(PipelineRunner):
                 "ended": "ended",
             }
         else:
-            raise Exception(f"Runner {self.__class__} is missing status mapping")
+            # intermediate pipelines don't update status
+            return
 
         # mutate to model status
         status = status_mapping.get(status)
@@ -262,9 +310,10 @@ class PipelineMainBase(PipelineRunner):
 
 
 class PipelineMainLive(PipelineMainBase):
+    audio_filename: Path | None = None
+    source_language: str = "en"
+    target_language: str = "en"
     """
     Main pipeline for live streaming, attach to RTC connection
     Any long post process should be done in the post pipeline
     """
 
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
@@ -274,7 +323,7 @@ class PipelineMainLive(PipelineMainBase):
 
         processors = [
             AudioFileWriterProcessor(
-                path=transcript.audio_mp3_filename,
+                path=transcript.audio_wav_filename,
                 on_duration=self.on_duration,
             ),
             AudioChunkerProcessor(),
@@ -283,26 +332,13 @@ class PipelineMainLive(PipelineMainBase):
             TranscriptLinerProcessor(),
             TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
             TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
-            BroadcastProcessor(
-                processors=[
-                    TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
-                    AudioWaveformProcessor.as_threaded(
-                        audio_path=transcript.audio_mp3_filename,
-                        waveform_path=transcript.audio_waveform_filename,
-                        on_waveform=self.on_waveform,
-                    ),
-                ]
-            ),
         ]
         pipeline = Pipeline(*processors)
         pipeline.options = self
         pipeline.set_pref("audio:source_language", transcript.source_language)
         pipeline.set_pref("audio:target_language", transcript.target_language)
         pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main live created",
-            transcript_id=self.transcript_id,
-        )
+        pipeline.logger.info("Pipeline main live created")
 
         return pipeline
 
@@ -310,26 +346,106 @@ class PipelineMainLive(PipelineMainBase):
         # when the pipeline ends, connect to the post pipeline
         logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
         logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
-        task_pipeline_main_post.delay(transcript_id=self.transcript_id)
+        pipeline_post(transcript_id=self.transcript_id)
 
 
 class PipelineMainDiarization(PipelineMainBase):
     """
-    Diarization is a long time process, so we do it in a separate pipeline
-    When done, adjust the short and final summary
+    Diarize the audio and update topics
     """
 
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
         self.prepare()
-        processors = []
-        if settings.DIARIZATION_ENABLED:
-            processors += [
-                AudioDiarizationAutoProcessor(callback=self.on_topic),
-            ]
+        pipeline = Pipeline(
+            AudioDiarizationAutoProcessor(callback=self.on_topic),
+        )
+        pipeline.options = self
 
-        processors += [
+        # now let's start the pipeline by pushing information to the
+        # first processor, the diarization processor
+        # XXX translation is lost when converting our data model to the processor model
+        transcript = await self.get_transcript()
+
+        # diarization works only if the file is uploaded to an external storage
+        if transcript.audio_location == "local":
+            pipeline.logger.info("Audio is local, skipping diarization")
+            return
+
+        topics = self.get_transcript_topics(transcript)
+        audio_url = await transcript.get_audio_url()
+        audio_diarization_input = AudioDiarizationInput(
+            audio_url=audio_url,
+            topics=topics,
+        )
+
+        # as tempting as pipeline.push is, prefer the runner
+        # so that start does just one job.
+        pipeline.logger.bind(transcript_id=transcript.id)
+        pipeline.logger.info("Diarization pipeline created")
+        self.push(audio_diarization_input)
+        self.flush()
+
+        return pipeline
+
+
+class PipelineMainFromTopics(PipelineMainBase):
+    """
+    Pseudo class for generating a pipeline from topics
+    """
+
+    def get_processors(self) -> list:
+        raise NotImplementedError
+
+    async def create(self) -> Pipeline:
+        self.prepare()
+
+        # get transcript
+        self._transcript = transcript = await self.get_transcript()
+
+        # create pipeline
+        processors = self.get_processors()
+        pipeline = Pipeline(*processors)
+        pipeline.options = self
+        pipeline.logger.bind(transcript_id=transcript.id)
+        pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
+
+        # push topics
+        topics = self.get_transcript_topics(transcript)
+        for topic in topics:
+            self.push(topic)
+
+        self.flush()
+
+        return pipeline
+
+
+class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
+    """
+    Generate the title and short summary from the topics
+    """
+
+    def get_processors(self) -> list:
+        return [
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
+                    TranscriptFinalShortSummaryProcessor.as_threaded(
+                        callback=self.on_short_summary
+                    ),
+                ]
+            )
+        ]
+
+
+class PipelineMainFinalSummaries(PipelineMainFromTopics):
+    """
+    Generate summaries from the topics
+    """
+
+    def get_processors(self) -> list:
+        return [
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -341,65 +457,164 @@ class PipelineMainDiarization(PipelineMainBase):
                 ]
             ),
         ]
-        pipeline = Pipeline(*processors)
-        pipeline.options = self
-
-        # now let's start the pipeline by pushing information to the
-        # first processor, the diarization processor
-        # XXX translation is lost when converting our data model to the processor model
-        transcript = await self.get_transcript()
-        topics = [
-            TitleSummaryWithIdProcessorType(
-                id=topic.id,
-                title=topic.title,
-                summary=topic.summary,
-                timestamp=topic.timestamp,
-                duration=topic.duration,
-                transcript=TranscriptProcessorType(words=topic.words),
-            )
-            for topic in transcript.topics
-        ]
+
+
+class PipelineMainWaveform(PipelineMainFromTopics):
+    """
+    Generate waveform
+    """
+
+    def get_processors(self) -> list:
+        return [
+            AudioWaveformProcessor.as_threaded(
+                audio_path=self._transcript.audio_wav_filename,
+                waveform_path=self._transcript.audio_waveform_filename,
+                on_waveform=self.on_waveform,
+            ),
+        ]
 
-        # we need to create a url to be used for diarization
-        # we can't use the audio_mp3_filename because it's not accessible
-        # from the diarization processor
-        from reflector.views.transcripts import create_access_token
-
-        path = app.url_path_for(
-            "transcript_get_audio_mp3",
-            transcript_id=transcript.id,
-        )
-        url = f"{settings.BASE_URL}{path}"
-        if transcript.user_id:
-            # we pass the token only if user_id is set
-            # otherwise, the audio is public
-            token = create_access_token(
-                {"sub": transcript.user_id},
-                expires_delta=timedelta(minutes=15),
-            )
-            url += f"?token={token}"
-        audio_diarization_input = AudioDiarizationInput(
-            audio_url=url,
-            topics=topics,
-        )
+
+@get_transcript
+async def pipeline_waveform(transcript: Transcript, logger: Logger):
+    logger.info("Starting waveform")
+    runner = PipelineMainWaveform(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Waveform done")
 
-        # as tempting as pipeline.push is, prefer the runner
-        # so that start does just one job.
-        pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main post created", transcript_id=self.transcript_id
-        )
-        self.push(audio_diarization_input)
-        self.flush()
-
-        return pipeline
+
+@get_transcript
+async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
+    logger.info("Starting convert to mp3")
+
+    # If the audio wav is not available, just skip
+    wav_filename = transcript.audio_wav_filename
+    if not wav_filename.exists():
+        logger.warning("Wav file not found, may be already converted")
+        return
+
+    # Convert to mp3
+    mp3_filename = transcript.audio_mp3_filename
+
+    import av
+
+    with av.open(wav_filename.as_posix()) as in_container:
+        in_stream = in_container.streams.audio[0]
+        with av.open(mp3_filename.as_posix(), "w") as out_container:
+            out_stream = out_container.add_stream("mp3")
+            for frame in in_container.decode(in_stream):
+                for packet in out_stream.encode(frame):
+                    out_container.mux(packet)
+
+    # Delete the wav file
+    transcript.audio_wav_filename.unlink(missing_ok=True)
+
+    logger.info("Convert to mp3 done")
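A hedged aside on the PyAV loop in pipeline_convert_to_mp3 above: encoders may buffer frames, and the common pattern is to drain them with a final no-argument encode() before the output container closes. The diff does not include that flush; the sketch below is only an assumption about PyAV's buffering behavior, placed where such a step would go:

    # Sketch: flush any buffered packets after the per-frame loop.
    for packet in out_stream.encode():  # encode() with no frame drains the encoder
        out_container.mux(packet)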
+@get_transcript
+async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+    if not settings.TRANSCRIPT_STORAGE_BACKEND:
+        logger.info("No storage backend configured, skipping mp3 upload")
+        return
+
+    logger.info("Starting upload mp3")
+
+    # If the audio mp3 is not available, just skip
+    mp3_filename = transcript.audio_mp3_filename
+    if not mp3_filename.exists():
+        logger.warning("Mp3 file not found, may be already uploaded")
+        return
+
+    # Upload to external storage and delete the file
+    await transcripts_controller.move_mp3_to_storage(transcript)
+
+    logger.info("Upload mp3 done")
+
+
+@get_transcript
+async def pipeline_diarization(transcript: Transcript, logger: Logger):
+    logger.info("Starting diarization")
+    runner = PipelineMainDiarization(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Diarization done")
+
+
+@get_transcript
+async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
+    logger.info("Starting title and short summary")
+    runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Title and short summary done")
+
+
+@get_transcript
+async def pipeline_summaries(transcript: Transcript, logger: Logger):
+    logger.info("Starting summaries")
+    runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Summaries done")
+
+
+# ===================================================================
+# Celery tasks that can be called from the API
+# ===================================================================
+
+
 @shared_task
-def task_pipeline_main_post(transcript_id: str):
-    logger.info(
-        "Starting main post pipeline",
-        transcript_id=transcript_id,
-    )
+@asynctask
+async def task_pipeline_waveform(*, transcript_id: str):
+    await pipeline_waveform(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_convert_to_mp3(*, transcript_id: str):
+    await pipeline_convert_to_mp3(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_upload_mp3(*, transcript_id: str):
+    await pipeline_upload_mp3(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_diarization(*, transcript_id: str):
+    await pipeline_diarization(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_title_and_short_summary(*, transcript_id: str):
+    await pipeline_title_and_short_summary(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_final_summaries(*, transcript_id: str):
+    await pipeline_summaries(transcript_id=transcript_id)
+
+
+def pipeline_post(*, transcript_id: str):
+    """
+    Run the post pipeline
+    """
+    chain_mp3_and_diarize = (
+        task_pipeline_waveform.si(transcript_id=transcript_id)
+        | task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_diarization.si(transcript_id=transcript_id)
+    )
-    runner = PipelineMainDiarization(transcript_id=transcript_id)
-    runner.start_sync()
+    chain_title_preview = task_pipeline_title_and_short_summary.si(
+        transcript_id=transcript_id
+    )
+    chain_final_summaries = task_pipeline_final_summaries.si(
+        transcript_id=transcript_id
+    )
+
+    chain = chord(
+        group(chain_mp3_and_diarize, chain_title_preview),
+        chain_final_summaries,
+    )
+    chain.delay()
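The pipeline_post canvas above fans out and then joins: a Celery chord runs its header group in parallel and fires the callback once every member finishes, and .si() builds immutable signatures so one task's result is not injected into the next task's arguments. A minimal self-contained sketch of the same shape (the step/join tasks and in-memory broker are hypothetical, for illustration only):

    from celery import Celery, chord, group

    app = Celery(broker="memory://")  # stand-in broker, illustration only

    @app.task
    def step(name):
        return name

    @app.task
    def join(results=None):
        return "done"

    # A chain (|) runs sequentially; a group runs its members in parallel;
    # the chord callback join() fires after every group member completes.
    workflow = chord(
        group(step.si("waveform") | step.si("mp3"), step.si("title")),
        join.si(),
    )
    workflow.delay()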
@@ -106,6 +106,14 @@ class PipelineRunner(BaseModel):
         if not self.pipeline:
             self.pipeline = await self.create()
 
+        if not self.pipeline:
+            # no pipeline created in create, just finish it then.
+            await self._set_status("ended")
+            self._ev_done.set()
+            if self.on_ended:
+                await self.on_ended()
+            return
+
         # start the loop
         await self._set_status("started")
         while not self._ev_done.is_set():
@@ -119,8 +127,7 @@ class PipelineRunner(BaseModel):
             self._logger.exception("Runner error")
             await self._set_status("error")
             self._ev_done.set()
-            if self.on_ended:
-                await self.on_ended()
+            raise
 
     async def cmd_push(self, data):
         if self._is_first_push:
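A hedged reading of the runner change above: create() may now return nothing to signal there is no work (the "audio is local" skip in PipelineMainDiarization relies on this), in which case run() marks the pipeline ended and invokes on_ended instead of entering its loop. A sketch of a subclass leaning on that contract (the class is hypothetical):

    class PipelineMainNoop(PipelineMainBase):  # hypothetical example
        async def create(self) -> Pipeline | None:
            self.prepare()
            return None  # nothing to do: the runner sets status "ended" and returns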
@@ -54,7 +54,7 @@ class Settings(BaseSettings):
     TRANSCRIPT_MODAL_API_KEY: str | None = None
 
     # Audio transcription storage
-    TRANSCRIPT_STORAGE_BACKEND: str = "aws"
+    TRANSCRIPT_STORAGE_BACKEND: str | None = None
 
     # Storage configuration for AWS
     TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
@@ -62,9 +62,6 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
 
-    # Transcript MP3 storage
-    TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
-
     # LLM
     # available backend: openai, modal, oobabooga
     LLM_BACKEND: str = "oobabooga"
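With the default backend now None, uploads to external storage are opt-in. A sketch of how these settings feed the storage factory; the environment values are illustrative, not prescribed:

    # e.g. TRANSCRIPT_STORAGE_BACKEND=aws, TRANSCRIPT_STORAGE_AWS_BUCKET_NAME=...
    # (read from the environment by the pydantic BaseSettings class)
    storage = Storage.get_instance(
        name=settings.TRANSCRIPT_STORAGE_BACKEND,  # "aws" loads reflector.storage.storage_aws
        settings_prefix="TRANSCRIPT_STORAGE_",
    )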
@@ -1,6 +1,7 @@
+import importlib
+
 from pydantic import BaseModel
 from reflector.settings import settings
-import importlib
 
 
 class FileResult(BaseModel):
@@ -17,7 +18,7 @@ class Storage:
         cls._registry[name] = kclass
 
     @classmethod
-    def get_instance(cls, name, settings_prefix=""):
+    def get_instance(cls, name: str, settings_prefix: str = ""):
         if name not in cls._registry:
             module_name = f"reflector.storage.storage_{name}"
             importlib.import_module(module_name)
@@ -45,3 +46,9 @@ class Storage:
 
     async def _delete_file(self, filename: str):
         raise NotImplementedError
+
+    async def get_file_url(self, filename: str) -> str:
+        return await self._get_file_url(filename)
+
+    async def _get_file_url(self, filename: str) -> str:
+        raise NotImplementedError
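The base class keeps the public get_file_url stable while backends override _get_file_url, the same split used for _put_file and _delete_file. A minimal backend sketch under that contract (the class and URL are illustrative; compare the DummyStorage fixture later in this diff):

    from reflector.storage.base import Storage

    class InMemoryStorage(Storage):  # illustrative, not part of the diff
        async def _put_file(self, filename: str, data: bytes):
            pass  # a real backend would persist the bytes here

        async def _delete_file(self, filename: str):
            pass

        async def _get_file_url(self, filename: str) -> str:
            return f"http://example.invalid/{filename}"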
@@ -1,6 +1,6 @@
 import aioboto3
-from reflector.storage.base import Storage, FileResult
 from reflector.logger import logger
+from reflector.storage.base import FileResult, Storage
 
 
 class AwsStorage(Storage):
@@ -44,16 +44,18 @@ class AwsStorage(Storage):
             Body=data,
         )
 
     async def _get_file_url(self, filename: str) -> FileResult:
         bucket = self.aws_bucket_name
         folder = self.aws_folder
         s3filename = f"{folder}/{filename}" if folder else filename
         async with self.session.client("s3") as client:
             presigned_url = await client.generate_presigned_url(
                 "get_object",
                 Params={"Bucket": bucket, "Key": s3filename},
                 ExpiresIn=3600,
             )
 
-        return FileResult(
-            filename=filename,
-            url=presigned_url,
-        )
+        return presigned_url
 
     async def _delete_file(self, filename: str):
         bucket = self.aws_bucket_name
@@ -1,12 +1,14 @@
 from datetime import datetime, timedelta
 from typing import Annotated, Literal, Optional
 
+import httpx
 import reflector.auth as auth
 from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
     Request,
+    Response,
     WebSocket,
     WebSocketDisconnect,
     status,
@@ -245,6 +247,24 @@ async def transcript_get_audio_mp3(
         transcript_id, user_id=user_id
     )
 
+    if transcript.audio_location == "storage":
+        # proxy S3 file, to prevent issue with CORS
+        url = await transcript.get_audio_url()
+        headers = {}
+
+        copy_headers = ["range", "accept-encoding"]
+        for header in copy_headers:
+            if header in request.headers:
+                headers[header] = request.headers[header]
+
+        async with httpx.AsyncClient() as client:
+            resp = await client.request(request.method, url, headers=headers)
+            return Response(
+                content=resp.content,
+                status_code=resp.status_code,
+                headers=resp.headers,
+            )
+
     if not transcript.audio_mp3_filename.exists():
         raise HTTPException(status_code=500, detail="Audio not found")
 
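Because the proxy forwards the client's range and accept-encoding headers, byte-range requests keep working through the CORS-avoiding hop. A client-side sketch (the route shape and host are assumptions):

    # Sketch: a byte-range request the endpoint passes through to storage.
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            "http://localhost:8000/transcripts/abc123/audio/mp3",  # assumed route
            headers={"range": "bytes=0-1023"},
        )
    assert resp.status_code in (200, 206)  # 206 when the range is honored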
@@ -269,8 +307,8 @@ async def transcript_get_audio_waveform(
         transcript_id, user_id=user_id
     )
 
-    if not transcript.audio_mp3_filename.exists():
-        raise HTTPException(status_code=500, detail="Audio not found")
+    if not transcript.audio_waveform_filename.exists():
+        raise HTTPException(status_code=404, detail="Audio not found")
 
     return transcript.audio_waveform
 
@@ -1,4 +1,5 @@
 from unittest.mock import patch
+from tempfile import NamedTemporaryFile
 
 import pytest
 
@@ -7,7 +8,6 @@ import pytest
 @pytest.mark.asyncio
 async def setup_database():
     from reflector.settings import settings
-    from tempfile import NamedTemporaryFile
 
     with NamedTemporaryFile() as f:
         settings.DATABASE_URL = f"sqlite:///{f.name}"
@@ -103,6 +103,25 @@ async def dummy_llm():
         yield
 
 
+@pytest.fixture
+async def dummy_storage():
+    from reflector.storage.base import Storage
+
+    class DummyStorage(Storage):
+        async def _put_file(self, *args, **kwargs):
+            pass
+
+        async def _delete_file(self, *args, **kwargs):
+            pass
+
+        async def _get_file_url(self, *args, **kwargs):
+            return "http://fake_server/audio.mp3"
+
+    with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
+        mock_storage.return_value = DummyStorage()
+        yield
+
+
 @pytest.fixture
 def nltk():
     with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
@@ -133,4 +152,17 @@ def celery_enable_logging():
 
 @pytest.fixture(scope="session")
 def celery_config():
-    return {"broker_url": "memory://", "result_backend": "rpc"}
+    with NamedTemporaryFile() as f:
+        yield {
+            "broker_url": "memory://",
+            "result_backend": f"db+sqlite:///{f.name}",
+        }
+
+
+@pytest.fixture(scope="session")
+def fake_mp3_upload():
+    with patch(
+        "reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
+    ) as mock_move:
+        mock_move.return_value = True
+        yield
@@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket(
     dummy_transcript,
     dummy_processors,
     dummy_diarization,
+    dummy_storage,
+    fake_mp3_upload,
     ensure_casing,
     appserver,
     sentence_tokenize,
@@ -220,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
     dummy_transcript,
     dummy_processors,
     dummy_diarization,
+    dummy_storage,
+    fake_mp3_upload,
     ensure_casing,
     appserver,
     sentence_tokenize,