mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
Merge branch 'main' of github.com:Monadical-SAS/reflector into feat-sharing
This commit is contained in:
35
server/migrations/versions/f819277e5169_audio_location.py
Normal file
35
server/migrations/versions/f819277e5169_audio_location.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""audio_location
|
||||||
|
|
||||||
|
Revision ID: f819277e5169
|
||||||
|
Revises: 4814901632bc
|
||||||
|
Create Date: 2023-11-16 10:29:09.351664
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "f819277e5169"
|
||||||
|
down_revision: Union[str, None] = "4814901632bc"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add the ``audio_location`` column to the ``transcript`` table.

    The column is a non-nullable string created with the server default
    ``"local"``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    audio_location = sa.Column(
        "audio_location",
        sa.String(),
        nullable=False,
        server_default="local",
    )
    op.add_column("transcript", audio_location)
    # ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Revert the migration by dropping ``transcript.audio_location``."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name, column_name = "transcript", "audio_location"
    op.drop_column(table_name, column_name)
    # ### end Alembic commands ###
|
||||||
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
|||||||
from reflector.db import database, metadata
|
from reflector.db import database, metadata
|
||||||
from reflector.processors.types import Word as ProcessorWord
|
from reflector.processors.types import Word as ProcessorWord
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
from reflector.storage import Storage
|
||||||
|
|
||||||
transcripts = sqlalchemy.Table(
|
transcripts = sqlalchemy.Table(
|
||||||
"transcript",
|
"transcript",
|
||||||
@@ -28,6 +29,12 @@ transcripts = sqlalchemy.Table(
|
|||||||
sqlalchemy.Column("events", sqlalchemy.JSON),
|
sqlalchemy.Column("events", sqlalchemy.JSON),
|
||||||
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
|
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
|
||||||
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
|
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"audio_location",
|
||||||
|
sqlalchemy.String,
|
||||||
|
nullable=False,
|
||||||
|
server_default="local",
|
||||||
|
),
|
||||||
# with user attached, optional
|
# with user attached, optional
|
||||||
sqlalchemy.Column("user_id", sqlalchemy.String),
|
sqlalchemy.Column("user_id", sqlalchemy.String),
|
||||||
sqlalchemy.Column(
|
sqlalchemy.Column(
|
||||||
@@ -39,15 +46,22 @@ transcripts = sqlalchemy.Table(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_uuid4():
|
def generate_uuid4() -> str:
|
||||||
return str(uuid4())
|
return str(uuid4())
|
||||||
|
|
||||||
|
|
||||||
def generate_transcript_name() -> str:
    """Return a default transcript name stamped with the current UTC time.

    Returns:
        A string of the form ``Transcript YYYY-MM-DD HH:MM:SS``.
    """
    # datetime.utcnow() is deprecated since Python 3.12; a timezone-aware
    # UTC "now" produces exactly the same formatted string.
    from datetime import timezone

    now = datetime.now(timezone.utc)
    return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_storage() -> Storage:
    """Return the storage backend used for transcript files.

    The backend name is read from ``settings.TRANSCRIPT_STORAGE_BACKEND`` and
    passed to ``Storage.get_instance`` together with the
    ``TRANSCRIPT_STORAGE_`` settings prefix.
    """
    backend_name = settings.TRANSCRIPT_STORAGE_BACKEND
    return Storage.get_instance(
        name=backend_name,
        settings_prefix="TRANSCRIPT_STORAGE_",
    )
|
||||||
|
|
||||||
|
|
||||||
class AudioWaveform(BaseModel):
|
class AudioWaveform(BaseModel):
|
||||||
data: list[float]
|
data: list[float]
|
||||||
|
|
||||||
@@ -114,6 +128,7 @@ class Transcript(BaseModel):
|
|||||||
source_language: str = "en"
|
source_language: str = "en"
|
||||||
target_language: str = "en"
|
target_language: str = "en"
|
||||||
share_mode: Literal["private", "semi-private", "public"] = "private"
|
share_mode: Literal["private", "semi-private", "public"] = "private"
|
||||||
|
audio_location: str = "local"
|
||||||
|
|
||||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||||
@@ -140,6 +155,10 @@ class Transcript(BaseModel):
|
|||||||
def data_path(self):
|
def data_path(self):
|
||||||
return Path(settings.DATA_DIR) / self.id
|
return Path(settings.DATA_DIR) / self.id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_wav_filename(self):
|
||||||
|
return self.data_path / "audio.wav"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio_mp3_filename(self):
|
def audio_mp3_filename(self):
|
||||||
return self.data_path / "audio.mp3"
|
return self.data_path / "audio.mp3"
|
||||||
@@ -148,6 +167,10 @@ class Transcript(BaseModel):
|
|||||||
def audio_waveform_filename(self):
|
def audio_waveform_filename(self):
|
||||||
return self.data_path / "audio.json"
|
return self.data_path / "audio.json"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def storage_audio_path(self):
|
||||||
|
return f"{self.id}/audio.mp3"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio_waveform(self):
|
def audio_waveform(self):
|
||||||
try:
|
try:
|
||||||
@@ -160,6 +183,40 @@ class Transcript(BaseModel):
|
|||||||
|
|
||||||
return AudioWaveform(data=data)
|
return AudioWaveform(data=data)
|
||||||
|
|
||||||
|
async def get_audio_url(self) -> str:
    """Return a URL for the transcript audio, based on ``audio_location``.

    ``"storage"`` delegates to ``_generate_storage_audio_link`` and
    ``"local"`` to ``_generate_local_audio_link``.

    Raises:
        Exception: if ``audio_location`` is neither ``"local"`` nor
            ``"storage"``.
    """
    location = self.audio_location
    if location == "storage":
        return await self._generate_storage_audio_link()
    if location == "local":
        return self._generate_local_audio_link()
    raise Exception(f"Unknown audio location {self.audio_location}")
|
||||||
|
|
||||||
|
async def _generate_storage_audio_link(self) -> str:
|
||||||
|
return await get_storage().get_file_url(self.storage_audio_path)
|
||||||
|
|
||||||
|
def _generate_local_audio_link(self) -> str:
    """Build an HTTP URL to this transcript's mp3 served by the application.

    The URL is ``settings.BASE_URL`` plus the ``transcript_get_audio_mp3``
    route. When the transcript has a ``user_id``, an access token valid for
    15 minutes is appended as the ``token`` query parameter.
    """
    # The diarization processor cannot access audio_mp3_filename, so it is
    # given a URL instead. Imports are local to this method.
    from datetime import timedelta

    from reflector.app import app
    from reflector.views.transcripts import create_access_token

    route = app.url_path_for("transcript_get_audio_mp3", transcript_id=self.id)
    public_url = f"{settings.BASE_URL}{route}"
    if not self.user_id:
        # no owner: the audio is public, no token needed
        return public_url

    access_token = create_access_token(
        {"sub": self.user_id},
        expires_delta=timedelta(minutes=15),
    )
    return f"{public_url}?token={access_token}"
|
||||||
|
|
||||||
|
|
||||||
class TranscriptController:
|
class TranscriptController:
|
||||||
async def get_all(
|
async def get_all(
|
||||||
@@ -336,5 +393,22 @@ class TranscriptController:
|
|||||||
transcript.upsert_topic(topic)
|
transcript.upsert_topic(topic)
|
||||||
await self.update(transcript, {"topics": transcript.topics_dump()})
|
await self.update(transcript, {"topics": transcript.topics_dump()})
|
||||||
|
|
||||||
|
async def move_mp3_to_storage(self, transcript: Transcript):
    """
    Upload the transcript mp3 to the storage backend, then remove it locally.

    Steps, in order: put the file bytes at ``transcript.storage_audio_path``,
    set ``audio_location`` to ``"storage"`` on the transcript, and unlink the
    local mp3 (a missing file is ignored).
    """
    local_mp3 = transcript.audio_mp3_filename

    # 1. push the bytes to the external storage
    storage = get_storage()
    await storage.put_file(transcript.storage_audio_path, local_mp3.read_bytes())

    # 2. record that the audio now lives on storage
    changes = {"audio_location": "storage"}
    await self.update(transcript, changes)

    # 3. drop the local copy
    transcript.audio_mp3_filename.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
transcripts_controller = TranscriptController()
|
transcripts_controller = TranscriptController()
|
||||||
|
|||||||
@@ -12,13 +12,11 @@ It is directly linked to our data model.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import functools
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import timedelta
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from celery import shared_task
|
from celery import chord, group, shared_task
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from reflector.app import app
|
|
||||||
from reflector.db.transcripts import (
|
from reflector.db.transcripts import (
|
||||||
Transcript,
|
Transcript,
|
||||||
TranscriptDuration,
|
TranscriptDuration,
|
||||||
@@ -55,6 +53,22 @@ from reflector.processors.types import (
|
|||||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||||
|
from structlog import BoundLogger as Logger
|
||||||
|
|
||||||
|
|
||||||
|
def asynctask(f):
    """Wrap coroutine function *f* so that it can be called synchronously.

    The returned wrapper runs ``f(*args, **kwargs)`` to completion and returns
    its result (exceptions propagate to the caller).

    - With no event loop running in the calling thread, the coroutine is
      executed with ``asyncio.run``.
    - With a loop already running in the calling thread, the coroutine is
      executed with ``asyncio.run`` in a helper thread and the caller blocks
      until it finishes. ``loop.run_until_complete`` cannot be used there: it
      raises ``RuntimeError`` on a loop that is already running.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        def run_in_new_loop():
            # the coroutine is created in the thread that will run it
            return asyncio.run(f(*args, **kwargs))

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # no loop running in this thread: run directly
            return run_in_new_loop()

        # a loop is running here; use a dedicated thread with its own loop
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(run_in_new_loop).result()

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
def broadcast_to_sockets(func):
|
def broadcast_to_sockets(func):
|
||||||
@@ -75,6 +89,26 @@ def broadcast_to_sockets(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def get_transcript(func):
    """
    Decorator that loads a transcript and passes it to the wrapped coroutine.

    The wrapper accepts keyword arguments only. It pops ``transcript_id``,
    fetches the transcript through ``transcripts_controller.get_by_id`` and
    awaits ``func(transcript=..., logger=..., **kwargs)``, where ``logger`` is
    the module logger bound with the transcript id.

    Raises:
        Exception: if no transcript is found for ``transcript_id``. Any
            exception raised by ``func`` is logged as "Pipeline error" and
            re-raised.
    """

    @functools.wraps(func)
    async def wrapper(**kwargs):
        transcript_id = kwargs.pop("transcript_id")
        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
        if not transcript:
            # f-string so the message carries the actual id
            raise Exception(f"Transcript {transcript_id} not found")
        tlogger = logger.bind(transcript_id=transcript.id)
        try:
            return await func(transcript=transcript, logger=tlogger, **kwargs)
        except Exception as exc:
            tlogger.error("Pipeline error", exc_info=exc)
            raise

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
class StrValue(BaseModel):
|
class StrValue(BaseModel):
|
||||||
value: str
|
value: str
|
||||||
|
|
||||||
@@ -99,6 +133,19 @@ class PipelineMainBase(PipelineRunner):
|
|||||||
raise Exception("Transcript not found")
|
raise Exception("Transcript not found")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def get_transcript_topics(
    self, transcript: Transcript
) -> list[TitleSummaryWithIdProcessorType]:
    """Convert the transcript's stored topics to processor-model topics.

    Each entry of ``transcript.topics`` becomes a
    ``TitleSummaryWithIdProcessorType`` carrying the same id, title, summary,
    timestamp and duration, with the topic words wrapped in a
    ``TranscriptProcessorType``.

    Returns:
        The converted topics, in the order of ``transcript.topics``.
    """
    return [
        TitleSummaryWithIdProcessorType(
            id=topic.id,
            title=topic.title,
            summary=topic.summary,
            timestamp=topic.timestamp,
            duration=topic.duration,
            transcript=TranscriptProcessorType(words=topic.words),
        )
        for topic in transcript.topics
    ]
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def transaction(self):
|
async def transaction(self):
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
@@ -116,7 +163,7 @@ class PipelineMainBase(PipelineRunner):
|
|||||||
"flush": "processing",
|
"flush": "processing",
|
||||||
"error": "error",
|
"error": "error",
|
||||||
}
|
}
|
||||||
elif isinstance(self, PipelineMainDiarization):
|
elif isinstance(self, PipelineMainFinalSummaries):
|
||||||
status_mapping = {
|
status_mapping = {
|
||||||
"push": "processing",
|
"push": "processing",
|
||||||
"flush": "processing",
|
"flush": "processing",
|
||||||
@@ -124,7 +171,8 @@ class PipelineMainBase(PipelineRunner):
|
|||||||
"ended": "ended",
|
"ended": "ended",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Runner {self.__class__} is missing status mapping")
|
# intermediate pipeline don't update status
|
||||||
|
return
|
||||||
|
|
||||||
# mutate to model status
|
# mutate to model status
|
||||||
status = status_mapping.get(status)
|
status = status_mapping.get(status)
|
||||||
@@ -262,9 +310,10 @@ class PipelineMainBase(PipelineRunner):
|
|||||||
|
|
||||||
|
|
||||||
class PipelineMainLive(PipelineMainBase):
|
class PipelineMainLive(PipelineMainBase):
|
||||||
audio_filename: Path | None = None
|
"""
|
||||||
source_language: str = "en"
|
Main pipeline for live streaming, attach to RTC connection
|
||||||
target_language: str = "en"
|
Any long post process should be done in the post pipeline
|
||||||
|
"""
|
||||||
|
|
||||||
async def create(self) -> Pipeline:
|
async def create(self) -> Pipeline:
|
||||||
# create a context for the whole rtc transaction
|
# create a context for the whole rtc transaction
|
||||||
@@ -274,7 +323,7 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
|
|
||||||
processors = [
|
processors = [
|
||||||
AudioFileWriterProcessor(
|
AudioFileWriterProcessor(
|
||||||
path=transcript.audio_mp3_filename,
|
path=transcript.audio_wav_filename,
|
||||||
on_duration=self.on_duration,
|
on_duration=self.on_duration,
|
||||||
),
|
),
|
||||||
AudioChunkerProcessor(),
|
AudioChunkerProcessor(),
|
||||||
@@ -283,26 +332,13 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
TranscriptLinerProcessor(),
|
TranscriptLinerProcessor(),
|
||||||
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
||||||
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
||||||
BroadcastProcessor(
|
|
||||||
processors=[
|
|
||||||
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
|
||||||
AudioWaveformProcessor.as_threaded(
|
|
||||||
audio_path=transcript.audio_mp3_filename,
|
|
||||||
waveform_path=transcript.audio_waveform_filename,
|
|
||||||
on_waveform=self.on_waveform,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
pipeline = Pipeline(*processors)
|
pipeline = Pipeline(*processors)
|
||||||
pipeline.options = self
|
pipeline.options = self
|
||||||
pipeline.set_pref("audio:source_language", transcript.source_language)
|
pipeline.set_pref("audio:source_language", transcript.source_language)
|
||||||
pipeline.set_pref("audio:target_language", transcript.target_language)
|
pipeline.set_pref("audio:target_language", transcript.target_language)
|
||||||
pipeline.logger.bind(transcript_id=transcript.id)
|
pipeline.logger.bind(transcript_id=transcript.id)
|
||||||
pipeline.logger.info(
|
pipeline.logger.info("Pipeline main live created")
|
||||||
"Pipeline main live created",
|
|
||||||
transcript_id=self.transcript_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
@@ -310,26 +346,106 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
# when the pipeline ends, connect to the post pipeline
|
# when the pipeline ends, connect to the post pipeline
|
||||||
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
|
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
|
||||||
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
|
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
|
||||||
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
|
pipeline_post(transcript_id=self.transcript_id)
|
||||||
|
|
||||||
|
|
||||||
class PipelineMainDiarization(PipelineMainBase):
    """
    Diarize the audio and update topics
    """

    async def create(self) -> Pipeline | None:
        """Build the diarization pipeline and feed it the transcript topics.

        Returns:
            The created pipeline, or ``None`` when the transcript audio is
            still ``"local"``, in which case diarization is skipped.
        """
        # create a context for the whole rtc transaction
        # add a customised logger to the context
        self.prepare()
        pipeline = Pipeline(
            AudioDiarizationAutoProcessor(callback=self.on_topic),
        )
        pipeline.options = self

        # now let's start the pipeline by pushing information to the
        # first processor diarization processor
        # XXX translation is lost when converting our data model to the processor model
        transcript = await self.get_transcript()

        # diarization works only if the file is uploaded to an external storage
        if transcript.audio_location == "local":
            pipeline.logger.info("Audio is local, skipping diarization")
            return None

        topics = self.get_transcript_topics(transcript)
        audio_url = await transcript.get_audio_url()
        audio_diarization_input = AudioDiarizationInput(
            audio_url=audio_url,
            topics=topics,
        )

        # as tempting to use pipeline.push, prefer to use the runner
        # to let the start just do one job.
        # NOTE(review): the value returned by bind() is discarded here; if the
        # logger is a structlog logger this call has no effect - confirm.
        pipeline.logger.bind(transcript_id=transcript.id)
        pipeline.logger.info("Diarization pipeline created")
        self.push(audio_diarization_input)
        self.flush()

        return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineMainFromTopics(PipelineMainBase):
    """
    Base runner for pipelines whose input is the transcript's topics.

    Subclasses supply the processors through ``get_processors``.
    """

    def get_processors(self) -> list:
        """Return the processors of the pipeline; subclasses must override."""
        raise NotImplementedError

    async def create(self) -> Pipeline:
        """Create the pipeline, then push every transcript topic and flush."""
        self.prepare()

        # load the transcript and keep it on the runner for get_processors()
        loaded = await self.get_transcript()
        self._transcript = loaded
        transcript = loaded

        # assemble the pipeline from the subclass-provided processors
        pipeline = Pipeline(*self.get_processors())
        pipeline.options = self
        pipeline.logger.bind(transcript_id=transcript.id)
        pipeline.logger.info(f"{self.__class__.__name__} pipeline created")

        # feed the topics through the runner, then signal the end of input
        for converted_topic in self.get_transcript_topics(transcript):
            self.push(converted_topic)
        self.flush()

        return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
    """
    Produce the final title and the short summary from the topics
    """

    def get_processors(self) -> list:
        """Return one broadcast processor fanning out to title and short summary."""
        title_processor = TranscriptFinalTitleProcessor.as_threaded(
            callback=self.on_title
        )
        short_summary_processor = TranscriptFinalShortSummaryProcessor.as_threaded(
            callback=self.on_short_summary
        )
        fan_out = BroadcastProcessor(
            processors=[title_processor, short_summary_processor]
        )
        return [fan_out]
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineMainFinalSummaries(PipelineMainFromTopics):
|
||||||
|
"""
|
||||||
|
Generate summaries from the topics
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_processors(self) -> list:
|
||||||
|
return [
|
||||||
BroadcastProcessor(
|
BroadcastProcessor(
|
||||||
processors=[
|
processors=[
|
||||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||||
@@ -341,65 +457,164 @@ class PipelineMainDiarization(PipelineMainBase):
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
pipeline = Pipeline(*processors)
|
|
||||||
pipeline.options = self
|
|
||||||
|
|
||||||
# now let's start the pipeline by pushing information to the
|
|
||||||
# first processor diarization processor
|
class PipelineMainWaveform(PipelineMainFromTopics):
|
||||||
# XXX translation is lost when converting our data model to the processor model
|
"""
|
||||||
transcript = await self.get_transcript()
|
Generate waveform
|
||||||
topics = [
|
"""
|
||||||
TitleSummaryWithIdProcessorType(
|
|
||||||
id=topic.id,
|
def get_processors(self) -> list:
|
||||||
title=topic.title,
|
return [
|
||||||
summary=topic.summary,
|
AudioWaveformProcessor.as_threaded(
|
||||||
timestamp=topic.timestamp,
|
audio_path=self._transcript.audio_wav_filename,
|
||||||
duration=topic.duration,
|
waveform_path=self._transcript.audio_waveform_filename,
|
||||||
transcript=TranscriptProcessorType(words=topic.words),
|
on_waveform=self.on_waveform,
|
||||||
)
|
),
|
||||||
for topic in transcript.topics
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# we need to create an url to be used for diarization
|
|
||||||
# we can't use the audio_mp3_filename because it's not accessible
|
|
||||||
# from the diarization processor
|
|
||||||
from reflector.views.transcripts import create_access_token
|
|
||||||
|
|
||||||
path = app.url_path_for(
|
@get_transcript
|
||||||
"transcript_get_audio_mp3",
|
async def pipeline_waveform(transcript: Transcript, logger: Logger):
|
||||||
transcript_id=transcript.id,
|
logger.info("Starting waveform")
|
||||||
)
|
runner = PipelineMainWaveform(transcript_id=transcript.id)
|
||||||
url = f"{settings.BASE_URL}{path}"
|
await runner.run()
|
||||||
if transcript.user_id:
|
logger.info("Waveform done")
|
||||||
# we pass token only if the user_id is set
|
|
||||||
# otherwise, the audio is public
|
|
||||||
token = create_access_token(
|
|
||||||
{"sub": transcript.user_id},
|
|
||||||
expires_delta=timedelta(minutes=15),
|
|
||||||
)
|
|
||||||
url += f"?token={token}"
|
|
||||||
audio_diarization_input = AudioDiarizationInput(
|
|
||||||
audio_url=url,
|
|
||||||
topics=topics,
|
|
||||||
)
|
|
||||||
|
|
||||||
# as tempting to use pipeline.push, prefer to use the runner
|
|
||||||
# to let the start just do one job.
|
|
||||||
pipeline.logger.bind(transcript_id=transcript.id)
|
|
||||||
pipeline.logger.info(
|
|
||||||
"Pipeline main post created", transcript_id=self.transcript_id
|
|
||||||
)
|
|
||||||
self.push(audio_diarization_input)
|
|
||||||
self.flush()
|
|
||||||
|
|
||||||
return pipeline
|
@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
    """Convert the transcript's wav recording to mp3, then delete the wav.

    Does nothing (besides logging a warning) when the wav file does not exist.

    Args:
        transcript: transcript whose ``audio_wav_filename`` is converted into
            ``audio_mp3_filename``.
        logger: logger used for progress messages.
    """
    logger.info("Starting convert to mp3")

    # If the audio wav is not available, just skip
    wav_filename = transcript.audio_wav_filename
    if not wav_filename.exists():
        logger.warning("Wav file not found, may be already converted")
        return

    # Convert to mp3
    mp3_filename = transcript.audio_mp3_filename

    import av

    with av.open(wav_filename.as_posix()) as in_container:
        in_stream = in_container.streams.audio[0]
        with av.open(mp3_filename.as_posix(), "w") as out_container:
            out_stream = out_container.add_stream("mp3")
            for frame in in_container.decode(in_stream):
                for packet in out_stream.encode(frame):
                    out_container.mux(packet)

            # Flush the encoder: packets still buffered inside it would
            # otherwise never be written and the end of the audio is lost.
            for packet in out_stream.encode(None):
                out_container.mux(packet)

    # Delete the wav file
    transcript.audio_wav_filename.unlink(missing_ok=True)

    logger.info("Convert to mp3 done")
|
||||||
|
|
||||||
|
|
||||||
|
@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
    """Move the transcript mp3 to the configured storage backend.

    Skips when ``settings.TRANSCRIPT_STORAGE_BACKEND`` is not set, or when the
    local mp3 file does not exist.
    """
    if settings.TRANSCRIPT_STORAGE_BACKEND:
        logger.info("Starting upload mp3")

        if transcript.audio_mp3_filename.exists():
            # upload to external storage and delete the local file
            await transcripts_controller.move_mp3_to_storage(transcript)
            logger.info("Upload mp3 done")
        else:
            # nothing to upload
            logger.warning("Mp3 file not found, may be already uploaded")
    else:
        logger.info("No storage backend configured, skipping mp3 upload")
|
||||||
|
|
||||||
|
|
||||||
|
@get_transcript
|
||||||
|
async def pipeline_diarization(transcript: Transcript, logger: Logger):
|
||||||
|
logger.info("Starting diarization")
|
||||||
|
runner = PipelineMainDiarization(transcript_id=transcript.id)
|
||||||
|
await runner.run()
|
||||||
|
logger.info("Diarization done")
|
||||||
|
|
||||||
|
|
||||||
|
@get_transcript
|
||||||
|
async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
|
||||||
|
logger.info("Starting title and short summary")
|
||||||
|
runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
|
||||||
|
await runner.run()
|
||||||
|
logger.info("Title and short summary done")
|
||||||
|
|
||||||
|
|
||||||
|
@get_transcript
|
||||||
|
async def pipeline_summaries(transcript: Transcript, logger: Logger):
|
||||||
|
logger.info("Starting summaries")
|
||||||
|
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
|
||||||
|
await runner.run()
|
||||||
|
logger.info("Summaries done")
|
||||||
|
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# Celery tasks that can be called from the API
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
def task_pipeline_main_post(transcript_id: str):
|
@asynctask
|
||||||
logger.info(
|
async def task_pipeline_waveform(*, transcript_id: str):
|
||||||
"Starting main post pipeline",
|
await pipeline_waveform(transcript_id=transcript_id)
|
||||||
transcript_id=transcript_id,
|
|
||||||
|
|
||||||
|
@shared_task
|
||||||
|
@asynctask
|
||||||
|
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
|
||||||
|
await pipeline_convert_to_mp3(transcript_id=transcript_id)
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task
|
||||||
|
@asynctask
|
||||||
|
async def task_pipeline_upload_mp3(*, transcript_id: str):
|
||||||
|
await pipeline_upload_mp3(transcript_id=transcript_id)
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task
|
||||||
|
@asynctask
|
||||||
|
async def task_pipeline_diarization(*, transcript_id: str):
|
||||||
|
await pipeline_diarization(transcript_id=transcript_id)
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task
|
||||||
|
@asynctask
|
||||||
|
async def task_pipeline_title_and_short_summary(*, transcript_id: str):
|
||||||
|
await pipeline_title_and_short_summary(transcript_id=transcript_id)
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task
|
||||||
|
@asynctask
|
||||||
|
async def task_pipeline_final_summaries(*, transcript_id: str):
|
||||||
|
await pipeline_summaries(transcript_id=transcript_id)
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_post(*, transcript_id: str):
    """
    Schedule the post-processing tasks for a transcript.

    A chord is submitted whose header is a group of two branches (the chain
    waveform -> convert to mp3 -> upload mp3 -> diarization, and the title and
    short summary task) and whose body is the final summaries task.
    """
    task_kwargs = {"transcript_id": transcript_id}

    # sequential audio branch, chained with "|" in declaration order
    audio_steps = (
        task_pipeline_waveform,
        task_pipeline_convert_to_mp3,
        task_pipeline_upload_mp3,
        task_pipeline_diarization,
    )
    audio_branch = None
    for step in audio_steps:
        signature = step.si(**task_kwargs)
        audio_branch = signature if audio_branch is None else audio_branch | signature

    title_branch = task_pipeline_title_and_short_summary.si(**task_kwargs)
    final_summaries = task_pipeline_final_summaries.si(**task_kwargs)

    workflow = chord(group(audio_branch, title_branch), final_summaries)
    workflow.delay()
|
||||||
|
|||||||
@@ -106,6 +106,14 @@ class PipelineRunner(BaseModel):
|
|||||||
if not self.pipeline:
|
if not self.pipeline:
|
||||||
self.pipeline = await self.create()
|
self.pipeline = await self.create()
|
||||||
|
|
||||||
|
if not self.pipeline:
|
||||||
|
# no pipeline created in create, just finish it then.
|
||||||
|
await self._set_status("ended")
|
||||||
|
self._ev_done.set()
|
||||||
|
if self.on_ended:
|
||||||
|
await self.on_ended()
|
||||||
|
return
|
||||||
|
|
||||||
# start the loop
|
# start the loop
|
||||||
await self._set_status("started")
|
await self._set_status("started")
|
||||||
while not self._ev_done.is_set():
|
while not self._ev_done.is_set():
|
||||||
@@ -119,8 +127,7 @@ class PipelineRunner(BaseModel):
|
|||||||
self._logger.exception("Runner error")
|
self._logger.exception("Runner error")
|
||||||
await self._set_status("error")
|
await self._set_status("error")
|
||||||
self._ev_done.set()
|
self._ev_done.set()
|
||||||
if self.on_ended:
|
raise
|
||||||
await self.on_ended()
|
|
||||||
|
|
||||||
async def cmd_push(self, data):
|
async def cmd_push(self, data):
|
||||||
if self._is_first_push:
|
if self._is_first_push:
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class Settings(BaseSettings):
|
|||||||
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
TRANSCRIPT_MODAL_API_KEY: str | None = None
|
||||||
|
|
||||||
# Audio transcription storage
|
# Audio transcription storage
|
||||||
TRANSCRIPT_STORAGE_BACKEND: str = "aws"
|
TRANSCRIPT_STORAGE_BACKEND: str | None = None
|
||||||
|
|
||||||
# Storage configuration for AWS
|
# Storage configuration for AWS
|
||||||
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
|
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
|
||||||
@@ -62,9 +62,6 @@ class Settings(BaseSettings):
|
|||||||
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
|
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
|
||||||
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
|
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
|
||||||
|
|
||||||
# Transcript MP3 storage
|
|
||||||
TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
|
|
||||||
|
|
||||||
# LLM
|
# LLM
|
||||||
# available backend: openai, modal, oobabooga
|
# available backend: openai, modal, oobabooga
|
||||||
LLM_BACKEND: str = "oobabooga"
|
LLM_BACKEND: str = "oobabooga"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
import importlib
|
|
||||||
|
|
||||||
|
|
||||||
class FileResult(BaseModel):
|
class FileResult(BaseModel):
|
||||||
@@ -17,7 +18,7 @@ class Storage:
|
|||||||
cls._registry[name] = kclass
|
cls._registry[name] = kclass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls, name, settings_prefix=""):
|
def get_instance(cls, name: str, settings_prefix: str = ""):
|
||||||
if name not in cls._registry:
|
if name not in cls._registry:
|
||||||
module_name = f"reflector.storage.storage_{name}"
|
module_name = f"reflector.storage.storage_{name}"
|
||||||
importlib.import_module(module_name)
|
importlib.import_module(module_name)
|
||||||
@@ -45,3 +46,9 @@ class Storage:
|
|||||||
|
|
||||||
async def _delete_file(self, filename: str):
|
async def _delete_file(self, filename: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_file_url(self, filename: str) -> str:
|
||||||
|
return await self._get_file_url(filename)
|
||||||
|
|
||||||
|
async def _get_file_url(self, filename: str) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import aioboto3
|
import aioboto3
|
||||||
from reflector.storage.base import Storage, FileResult
|
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
|
from reflector.storage.base import FileResult, Storage
|
||||||
|
|
||||||
|
|
||||||
class AwsStorage(Storage):
|
class AwsStorage(Storage):
|
||||||
@@ -44,16 +44,18 @@ class AwsStorage(Storage):
|
|||||||
Body=data,
|
Body=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _get_file_url(self, filename: str) -> FileResult:
|
||||||
|
bucket = self.aws_bucket_name
|
||||||
|
folder = self.aws_folder
|
||||||
|
s3filename = f"{folder}/{filename}" if folder else filename
|
||||||
|
async with self.session.client("s3") as client:
|
||||||
presigned_url = await client.generate_presigned_url(
|
presigned_url = await client.generate_presigned_url(
|
||||||
"get_object",
|
"get_object",
|
||||||
Params={"Bucket": bucket, "Key": s3filename},
|
Params={"Bucket": bucket, "Key": s3filename},
|
||||||
ExpiresIn=3600,
|
ExpiresIn=3600,
|
||||||
)
|
)
|
||||||
|
|
||||||
return FileResult(
|
return presigned_url
|
||||||
filename=filename,
|
|
||||||
url=presigned_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _delete_file(self, filename: str):
|
async def _delete_file(self, filename: str):
|
||||||
bucket = self.aws_bucket_name
|
bucket = self.aws_bucket_name
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Annotated, Literal, Optional
|
from typing import Annotated, Literal, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
Depends,
|
Depends,
|
||||||
HTTPException,
|
HTTPException,
|
||||||
Request,
|
Request,
|
||||||
|
Response,
|
||||||
WebSocket,
|
WebSocket,
|
||||||
WebSocketDisconnect,
|
WebSocketDisconnect,
|
||||||
status,
|
status,
|
||||||
@@ -245,6 +247,42 @@ async def transcript_get_audio_mp3(
|
|||||||
transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if transcript.audio_location == "storage":
|
||||||
|
# proxy S3 file, to prevent issue with CORS
|
||||||
|
url = await transcript.get_audio_url()
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
copy_headers = ["range", "accept-encoding"]
|
||||||
|
for header in copy_headers:
|
||||||
|
if header in request.headers:
|
||||||
|
headers[header] = request.headers[header]
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.request(request.method, url, headers=headers)
|
||||||
|
return Response(
|
||||||
|
content=resp.content,
|
||||||
|
status_code=resp.status_code,
|
||||||
|
headers=resp.headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if transcript.audio_location == "storage":
|
||||||
|
# proxy S3 file, to prevent issue with CORS
|
||||||
|
url = await transcript.get_audio_url()
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
copy_headers = ["range", "accept-encoding"]
|
||||||
|
for header in copy_headers:
|
||||||
|
if header in request.headers:
|
||||||
|
headers[header] = request.headers[header]
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.request(request.method, url, headers=headers)
|
||||||
|
return Response(
|
||||||
|
content=resp.content,
|
||||||
|
status_code=resp.status_code,
|
||||||
|
headers=resp.headers,
|
||||||
|
)
|
||||||
|
|
||||||
if not transcript.audio_mp3_filename.exists():
|
if not transcript.audio_mp3_filename.exists():
|
||||||
raise HTTPException(status_code=500, detail="Audio not found")
|
raise HTTPException(status_code=500, detail="Audio not found")
|
||||||
|
|
||||||
@@ -269,8 +307,8 @@ async def transcript_get_audio_waveform(
|
|||||||
transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not transcript.audio_mp3_filename.exists():
|
if not transcript.audio_waveform_filename.exists():
|
||||||
raise HTTPException(status_code=500, detail="Audio not found")
|
raise HTTPException(status_code=404, detail="Audio not found")
|
||||||
|
|
||||||
return transcript.audio_waveform
|
return transcript.audio_waveform
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -7,7 +8,6 @@ import pytest
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def setup_database():
|
async def setup_database():
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
|
|
||||||
with NamedTemporaryFile() as f:
|
with NamedTemporaryFile() as f:
|
||||||
settings.DATABASE_URL = f"sqlite:///{f.name}"
|
settings.DATABASE_URL = f"sqlite:///{f.name}"
|
||||||
@@ -103,6 +103,25 @@ async def dummy_llm():
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def dummy_storage():
|
||||||
|
from reflector.storage.base import Storage
|
||||||
|
|
||||||
|
class DummyStorage(Storage):
|
||||||
|
async def _put_file(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _delete_file(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _get_file_url(self, *args, **kwargs):
|
||||||
|
return "http://fake_server/audio.mp3"
|
||||||
|
|
||||||
|
with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
|
||||||
|
mock_storage.return_value = DummyStorage()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def nltk():
|
def nltk():
|
||||||
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
|
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
|
||||||
@@ -133,4 +152,17 @@ def celery_enable_logging():
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def celery_config():
|
def celery_config():
|
||||||
return {"broker_url": "memory://", "result_backend": "rpc"}
|
with NamedTemporaryFile() as f:
|
||||||
|
yield {
|
||||||
|
"broker_url": "memory://",
|
||||||
|
"result_backend": f"db+sqlite:///{f.name}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def fake_mp3_upload():
|
||||||
|
with patch(
|
||||||
|
"reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
|
||||||
|
) as mock_move:
|
||||||
|
mock_move.return_value = True
|
||||||
|
yield
|
||||||
|
|||||||
@@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
dummy_transcript,
|
dummy_transcript,
|
||||||
dummy_processors,
|
dummy_processors,
|
||||||
dummy_diarization,
|
dummy_diarization,
|
||||||
|
dummy_storage,
|
||||||
|
fake_mp3_upload,
|
||||||
ensure_casing,
|
ensure_casing,
|
||||||
appserver,
|
appserver,
|
||||||
sentence_tokenize,
|
sentence_tokenize,
|
||||||
@@ -220,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
dummy_transcript,
|
dummy_transcript,
|
||||||
dummy_processors,
|
dummy_processors,
|
||||||
dummy_diarization,
|
dummy_diarization,
|
||||||
|
dummy_storage,
|
||||||
|
fake_mp3_upload,
|
||||||
ensure_casing,
|
ensure_casing,
|
||||||
appserver,
|
appserver,
|
||||||
sentence_tokenize,
|
sentence_tokenize,
|
||||||
|
|||||||
Reference in New Issue
Block a user