Merge branch 'main' of github.com:Monadical-SAS/reflector into feat-sharing

Sara
2023-11-22 19:28:45 +01:00
10 changed files with 515 additions and 104 deletions

View File

@@ -0,0 +1,35 @@
"""audio_location
Revision ID: f819277e5169
Revises: 4814901632bc
Create Date: 2023-11-16 10:29:09.351664
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f819277e5169"
down_revision: Union[str, None] = "4814901632bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"transcript",
sa.Column(
"audio_location", sa.String(), server_default="local", nullable=False
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transcript", "audio_location")
# ### end Alembic commands ###
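For context, this revision is applied with the usual Alembic workflow; a minimal sketch using the command API is shown below (the alembic.ini path is an assumption, and running `alembic upgrade head` from the CLI is equivalent):

    # Hypothetical: apply this migration programmatically.
    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")  # assumed location of the project's Alembic config
    command.upgrade(cfg, "f819277e5169")  # or "head"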

View File

@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import Storage
transcripts = sqlalchemy.Table(
"transcript",
@@ -28,6 +29,12 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("events", sqlalchemy.JSON), sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column(
"audio_location",
sqlalchemy.String,
nullable=False,
server_default="local",
),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
sqlalchemy.Column(
@@ -39,15 +46,22 @@ transcripts = sqlalchemy.Table(
)
-def generate_uuid4():
def generate_uuid4() -> str:
return str(uuid4())
-def generate_transcript_name():
def generate_transcript_name() -> str:
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
def get_storage() -> Storage:
return Storage.get_instance(
name=settings.TRANSCRIPT_STORAGE_BACKEND,
settings_prefix="TRANSCRIPT_STORAGE_",
)
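A minimal sketch of how get_storage() is meant to be used, assuming the AWS backend and the TRANSCRIPT_STORAGE_* settings touched later in this commit (backend name, bucket and file key are placeholders):

    # Illustrative only; values are placeholders.
    import asyncio
    from reflector.settings import settings

    async def demo():
        settings.TRANSCRIPT_STORAGE_BACKEND = "aws"
        # roughly: imports reflector.storage.storage_aws and returns the registered AwsStorage
        storage = get_storage()
        return await storage.get_file_url("some-transcript-id/audio.mp3")

    print(asyncio.run(demo()))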
class AudioWaveform(BaseModel):
data: list[float]
@@ -114,6 +128,7 @@ class Transcript(BaseModel):
source_language: str = "en"
target_language: str = "en"
share_mode: Literal["private", "semi-private", "public"] = "private"
audio_location: str = "local"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -140,6 +155,10 @@ class Transcript(BaseModel):
def data_path(self):
return Path(settings.DATA_DIR) / self.id
@property
def audio_wav_filename(self):
return self.data_path / "audio.wav"
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@@ -148,6 +167,10 @@ class Transcript(BaseModel):
def audio_waveform_filename(self):
return self.data_path / "audio.json"
@property
def storage_audio_path(self):
return f"{self.id}/audio.mp3"
@property
def audio_waveform(self):
try:
@@ -160,6 +183,40 @@ class Transcript(BaseModel):
return AudioWaveform(data=data)
async def get_audio_url(self) -> str:
if self.audio_location == "local":
return self._generate_local_audio_link()
elif self.audio_location == "storage":
return await self._generate_storage_audio_link()
raise Exception(f"Unknown audio location {self.audio_location}")
async def _generate_storage_audio_link(self) -> str:
return await get_storage().get_file_url(self.storage_audio_path)
def _generate_local_audio_link(self) -> str:
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from datetime import timedelta
from reflector.app import app
from reflector.views.transcripts import create_access_token
path = app.url_path_for(
"transcript_get_audio_mp3",
transcript_id=self.id,
)
url = f"{settings.BASE_URL}{path}"
if self.user_id:
# we pass token only if the user_id is set
# otherwise, the audio is public
token = create_access_token(
{"sub": self.user_id},
expires_delta=timedelta(minutes=15),
)
url += f"?token={token}"
return url
class TranscriptController:
async def get_all(
@@ -336,5 +393,22 @@ class TranscriptController:
transcript.upsert_topic(topic)
await self.update(transcript, {"topics": transcript.topics_dump()})
async def move_mp3_to_storage(self, transcript: Transcript):
"""
Move mp3 file to storage
"""
# store the audio on external storage
await get_storage().put_file(
transcript.storage_audio_path,
transcript.audio_mp3_filename.read_bytes(),
)
# indicate on the transcript that the audio is now on storage
await self.update(transcript, {"audio_location": "storage"})
# unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True)
transcripts_controller = TranscriptController()
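Together, audio_location, get_audio_url() and move_mp3_to_storage() give each transcript a small audio lifecycle: the MP3 starts on local disk ("local") and, once pushed to the storage backend, is served through a pre-signed URL ("storage"). A rough sketch of that flow, assuming an existing transcript (the id is a placeholder, and this driver is not part of the commit):

    # Hypothetical driver showing the intended lifecycle.
    import asyncio

    async def publish_audio(transcript_id: str) -> str:
        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
        if transcript.audio_location == "local":
            # uploads <id>/audio.mp3, flips audio_location to "storage", removes the local file
            await transcripts_controller.move_mp3_to_storage(transcript)
            transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
        # returns a pre-signed URL once the audio is on storage
        return await transcript.get_audio_url()

    # asyncio.run(publish_audio("some-transcript-id"))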

View File

@@ -12,13 +12,11 @@ It is directly linked to our data model.
""" """
import asyncio import asyncio
import functools
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from celery import shared_task from celery import chord, group, shared_task
from pydantic import BaseModel from pydantic import BaseModel
from reflector.app import app
from reflector.db.transcripts import ( from reflector.db.transcripts import (
Transcript, Transcript,
TranscriptDuration, TranscriptDuration,
@@ -55,6 +53,22 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
from structlog import BoundLogger as Logger
def asynctask(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
coro = f(*args, **kwargs)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
return loop.run_until_complete(coro)
return asyncio.run(coro)
return wrapper
def broadcast_to_sockets(func):
@@ -75,6 +89,26 @@ def broadcast_to_sockets(func):
return wrapper
def get_transcript(func):
"""
Decorator to fetch the transcript from the database from the first argument
"""
async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id")
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
if not transcript:
raise Exception("Transcript {transcript_id} not found")
tlogger = logger.bind(transcript_id=transcript.id)
try:
return await func(transcript=transcript, logger=tlogger, **kwargs)
except Exception as exc:
tlogger.error("Pipeline error", exc_info=exc)
raise
return wrapper
class StrValue(BaseModel):
value: str
@@ -99,6 +133,19 @@ class PipelineMainBase(PipelineRunner):
raise Exception("Transcript not found") raise Exception("Transcript not found")
return result return result
def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]:
return [
TitleSummaryWithIdProcessorType(
id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
duration=topic.duration,
transcript=TranscriptProcessorType(words=topic.words),
)
for topic in transcript.topics
]
@asynccontextmanager
async def transaction(self):
async with self._lock:
@@ -116,7 +163,7 @@ class PipelineMainBase(PipelineRunner):
"flush": "processing", "flush": "processing",
"error": "error", "error": "error",
} }
elif isinstance(self, PipelineMainDiarization): elif isinstance(self, PipelineMainFinalSummaries):
status_mapping = { status_mapping = {
"push": "processing", "push": "processing",
"flush": "processing", "flush": "processing",
@@ -124,7 +171,8 @@ class PipelineMainBase(PipelineRunner):
"ended": "ended", "ended": "ended",
} }
else: else:
raise Exception(f"Runner {self.__class__} is missing status mapping") # intermediate pipeline don't update status
return
# mutate to model status # mutate to model status
status = status_mapping.get(status) status = status_mapping.get(status)
@@ -262,9 +310,10 @@ class PipelineMainBase(PipelineRunner):
class PipelineMainLive(PipelineMainBase):
-audio_filename: Path | None = None
-source_language: str = "en"
-target_language: str = "en"
"""
Main pipeline for live streaming, attach to RTC connection
Any long post process should be done in the post pipeline
"""
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
@@ -274,7 +323,7 @@ class PipelineMainLive(PipelineMainBase):
processors = [
AudioFileWriterProcessor(
-path=transcript.audio_mp3_filename,
path=transcript.audio_wav_filename,
on_duration=self.on_duration,
),
AudioChunkerProcessor(),
@@ -283,26 +332,13 @@ class PipelineMainLive(PipelineMainBase):
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
-BroadcastProcessor(
-processors=[
-TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
-AudioWaveformProcessor.as_threaded(
-audio_path=transcript.audio_mp3_filename,
-waveform_path=transcript.audio_waveform_filename,
-on_waveform=self.on_waveform,
-),
-]
-),
]
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
pipeline.logger.bind(transcript_id=transcript.id)
-pipeline.logger.info(
-"Pipeline main live created",
-transcript_id=self.transcript_id,
-)
pipeline.logger.info("Pipeline main live created")
return pipeline
@@ -310,26 +346,106 @@ class PipelineMainLive(PipelineMainBase):
# when the pipeline ends, connect to the post pipeline
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
-task_pipeline_main_post.delay(transcript_id=self.transcript_id)
pipeline_post(transcript_id=self.transcript_id)
class PipelineMainDiarization(PipelineMainBase):
"""
-Diarization is a long time process, so we do it in a separate pipeline
Diarize the audio and update topics
When done, adjust the short and final summary
"""
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
-processors = []
-if settings.DIARIZATION_ENABLED:
-processors += [
pipeline = Pipeline(
AudioDiarizationAutoProcessor(callback=self.on_topic),
)
pipeline.options = self
# now let's start the pipeline by pushing information to the
# first processor diarization processor
# XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript()
# diarization works only if the file is uploaded to an external storage
if transcript.audio_location == "local":
pipeline.logger.info("Audio is local, skipping diarization")
return
topics = self.get_transcript_topics(transcript)
audio_url = await transcript.get_audio_url()
audio_diarization_input = AudioDiarizationInput(
audio_url=audio_url,
topics=topics,
)
# as tempting to use pipeline.push, prefer to use the runner
# to let the start just do one job.
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info("Diarization pipeline created")
self.push(audio_diarization_input)
self.flush()
return pipeline
class PipelineMainFromTopics(PipelineMainBase):
"""
Pseudo class for generating a pipeline from topics
"""
def get_processors(self) -> list:
raise NotImplementedError
async def create(self) -> Pipeline:
self.prepare()
# get transcript
self._transcript = transcript = await self.get_transcript()
# create pipeline
processors = self.get_processors()
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
# push topics
topics = self.get_transcript_topics(transcript)
for topic in topics:
self.push(topic)
self.flush()
return pipeline
class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
"""
Generate title from the topics
"""
def get_processors(self) -> list:
return [
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_short_summary
),
]
)
]
-]
-processors += [
class PipelineMainFinalSummaries(PipelineMainFromTopics):
"""
Generate summaries from the topics
"""
def get_processors(self) -> list:
return [
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -341,65 +457,164 @@ class PipelineMainDiarization(PipelineMainBase):
]
),
]
-pipeline = Pipeline(*processors)
-pipeline.options = self
-# now let's start the pipeline by pushing information to the
-# first processor diarization processor
-# XXX translation is lost when converting our data model to the processor model
-transcript = await self.get_transcript()
-topics = [
-TitleSummaryWithIdProcessorType(
-id=topic.id,
-title=topic.title,
-summary=topic.summary,
-timestamp=topic.timestamp,
-duration=topic.duration,
-transcript=TranscriptProcessorType(words=topic.words),
-)
-for topic in transcript.topics
-]
-# we need to create an url to be used for diarization
-# we can't use the audio_mp3_filename because it's not accessible
-# from the diarization processor
-from reflector.views.transcripts import create_access_token
-path = app.url_path_for(
-"transcript_get_audio_mp3",
-transcript_id=transcript.id,
-)
-url = f"{settings.BASE_URL}{path}"
-if transcript.user_id:
-# we pass token only if the user_id is set
-# otherwise, the audio is public
-token = create_access_token(
-{"sub": transcript.user_id},
-expires_delta=timedelta(minutes=15),
-)
-url += f"?token={token}"
-audio_diarization_input = AudioDiarizationInput(
-audio_url=url,
-topics=topics,
-)
-# as tempting to use pipeline.push, prefer to use the runner
-# to let the start just do one job.
-pipeline.logger.bind(transcript_id=transcript.id)
-pipeline.logger.info(
-"Pipeline main post created", transcript_id=self.transcript_id
-)
-self.push(audio_diarization_input)
-self.flush()
-return pipeline
class PipelineMainWaveform(PipelineMainFromTopics):
"""
Generate waveform
"""
def get_processors(self) -> list:
return [
AudioWaveformProcessor.as_threaded(
audio_path=self._transcript.audio_wav_filename,
waveform_path=self._transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
),
]
@get_transcript
async def pipeline_waveform(transcript: Transcript, logger: Logger):
logger.info("Starting waveform")
runner = PipelineMainWaveform(transcript_id=transcript.id)
await runner.run()
logger.info("Waveform done")
@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
logger.info("Starting convert to mp3")
# If the audio wav is not available, just skip
wav_filename = transcript.audio_wav_filename
if not wav_filename.exists():
logger.warning("Wav file not found, may be already converted")
return
# Convert to mp3
mp3_filename = transcript.audio_mp3_filename
import av
with av.open(wav_filename.as_posix()) as in_container:
in_stream = in_container.streams.audio[0]
with av.open(mp3_filename.as_posix(), "w") as out_container:
out_stream = out_container.add_stream("mp3")
for frame in in_container.decode(in_stream):
for packet in out_stream.encode(frame):
out_container.mux(packet)
# Delete the wav file
transcript.audio_wav_filename.unlink(missing_ok=True)
logger.info("Convert to mp3 done")
@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
if not settings.TRANSCRIPT_STORAGE_BACKEND:
logger.info("No storage backend configured, skipping mp3 upload")
return
logger.info("Starting upload mp3")
# If the audio mp3 is not available, just skip
mp3_filename = transcript.audio_mp3_filename
if not mp3_filename.exists():
logger.warning("Mp3 file not found, may be already uploaded")
return
# Upload to external storage and delete the file
await transcripts_controller.move_mp3_to_storage(transcript)
logger.info("Upload mp3 done")
@get_transcript
async def pipeline_diarization(transcript: Transcript, logger: Logger):
logger.info("Starting diarization")
runner = PipelineMainDiarization(transcript_id=transcript.id)
await runner.run()
logger.info("Diarization done")
@get_transcript
async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
logger.info("Starting title and short summary")
runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
await runner.run()
logger.info("Title and short summary done")
@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger):
logger.info("Starting summaries")
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
await runner.run()
logger.info("Summaries done")
# ===================================================================
# Celery tasks that can be called from the API
# ===================================================================
@shared_task
-def task_pipeline_main_post(transcript_id: str):
-logger.info(
-"Starting main post pipeline",
-transcript_id=transcript_id,
@asynctask
async def task_pipeline_waveform(*, transcript_id: str):
await pipeline_waveform(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
await pipeline_convert_to_mp3(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_upload_mp3(*, transcript_id: str):
await pipeline_upload_mp3(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_diarization(*, transcript_id: str):
await pipeline_diarization(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_title_and_short_summary(*, transcript_id: str):
await pipeline_title_and_short_summary(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_final_summaries(*, transcript_id: str):
await pipeline_summaries(transcript_id=transcript_id)
def pipeline_post(*, transcript_id: str):
"""
Run the post pipeline
"""
chain_mp3_and_diarize = (
task_pipeline_waveform.si(transcript_id=transcript_id)
| task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
| task_pipeline_diarization.si(transcript_id=transcript_id)
)
-runner = PipelineMainDiarization(transcript_id=transcript_id)
-runner.start_sync()
chain_title_preview = task_pipeline_title_and_short_summary.si(
transcript_id=transcript_id
)
chain_final_summaries = task_pipeline_final_summaries.si(
transcript_id=transcript_id
)
chain = chord(
group(chain_mp3_and_diarize, chain_title_preview),
chain_final_summaries,
)
chain.delay()
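The canvas built by pipeline_post is a chord: a group of two parallel branches (the waveform → mp3 → upload → diarization chain and the title/short-summary task) whose completion triggers the final-summaries task. A stripped-down illustration of the same shape with placeholder tasks:

    from celery import chord, group, shared_task

    @shared_task
    def branch_a():  # stands in for waveform -> mp3 -> upload -> diarization
        return "a"

    @shared_task
    def branch_b():  # stands in for title + short summary
        return "b"

    @shared_task
    def finalize(results):  # chord body: runs after both branches, receives their results
        return results

    chord(group(branch_a.si(), branch_b.si()), finalize.s()).delay()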

View File

@@ -106,6 +106,14 @@ class PipelineRunner(BaseModel):
if not self.pipeline:
self.pipeline = await self.create()
if not self.pipeline:
# no pipeline created in create, just finish it then.
await self._set_status("ended")
self._ev_done.set()
if self.on_ended:
await self.on_ended()
return
# start the loop
await self._set_status("started")
while not self._ev_done.is_set():
@@ -119,8 +127,7 @@ class PipelineRunner(BaseModel):
self._logger.exception("Runner error")
await self._set_status("error")
self._ev_done.set()
-if self.on_ended:
-await self.on_ended()
raise
async def cmd_push(self, data):
if self._is_first_push:

View File

@@ -54,7 +54,7 @@ class Settings(BaseSettings):
TRANSCRIPT_MODAL_API_KEY: str | None = None
# Audio transcription storage
-TRANSCRIPT_STORAGE_BACKEND: str = "aws"
TRANSCRIPT_STORAGE_BACKEND: str | None = None
# Storage configuration for AWS
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
@@ -62,9 +62,6 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
-# Transcript MP3 storage
-TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
# LLM
# available backend: openai, modal, oobabooga
LLM_BACKEND: str = "oobabooga"
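Since TRANSCRIPT_STORAGE_BACKEND now defaults to None, uploading the MP3 (and therefore diarization, which needs a storage URL) is opt-in. Assuming no env prefix is configured on Settings, enabling the AWS backend via environment variables would look roughly like this (values are placeholders):

    # Illustrative only; real deployments set these in the environment.
    import os

    os.environ.update({
        "TRANSCRIPT_STORAGE_BACKEND": "aws",
        "TRANSCRIPT_STORAGE_AWS_BUCKET_NAME": "my-reflector-bucket",
        "TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID": "placeholder",
        "TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY": "placeholder",
    })

    from reflector.settings import Settings
    settings = Settings()  # pydantic BaseSettings picks these up by field name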

View File

@@ -1,6 +1,7 @@
import importlib
from pydantic import BaseModel
from reflector.settings import settings
-import importlib
class FileResult(BaseModel):
@@ -17,7 +18,7 @@ class Storage:
cls._registry[name] = kclass
@classmethod
-def get_instance(cls, name, settings_prefix=""):
def get_instance(cls, name: str, settings_prefix: str = ""):
if name not in cls._registry:
module_name = f"reflector.storage.storage_{name}"
importlib.import_module(module_name)
@@ -45,3 +46,9 @@ class Storage:
async def _delete_file(self, filename: str):
raise NotImplementedError
async def get_file_url(self, filename: str) -> str:
return await self._get_file_url(filename)
async def _get_file_url(self, filename: str) -> str:
raise NotImplementedError
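Backends are looked up lazily: get_instance(name, ...) imports reflector.storage.storage_<name> and expects that module to register a class under the same name. As a hedged sketch (the register() call, the constructor signature and any method details beyond _put_file/_delete_file/_get_file_url are assumptions based on the visible base class), a minimal extra backend could look like:

    # Hypothetical reflector/storage/storage_disk.py
    from pathlib import Path
    from reflector.storage.base import FileResult, Storage

    class DiskStorage(Storage):
        def __init__(self, root: str = "/tmp/reflector", **kwargs):
            self.root = Path(root)

        async def _put_file(self, filename: str, data: bytes):
            target = self.root / filename
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_bytes(data)
            return FileResult(filename=filename, url=target.as_uri())

        async def _delete_file(self, filename: str):
            (self.root / filename).unlink(missing_ok=True)

        async def _get_file_url(self, filename: str) -> str:
            return (self.root / filename).as_uri()

    Storage.register("disk", DiskStorage)  # assumed registration hook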

View File

@@ -1,6 +1,6 @@
import aioboto3
-from reflector.storage.base import Storage, FileResult
from reflector.logger import logger
from reflector.storage.base import FileResult, Storage
class AwsStorage(Storage):
@@ -44,16 +44,18 @@ class AwsStorage(Storage):
Body=data,
)
async def _get_file_url(self, filename: str) -> FileResult:
bucket = self.aws_bucket_name
folder = self.aws_folder
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
presigned_url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": s3filename},
ExpiresIn=3600,
)
-return FileResult(
-filename=filename,
-url=presigned_url,
-)
return presigned_url
async def _delete_file(self, filename: str):
bucket = self.aws_bucket_name

View File

@@ -1,12 +1,14 @@
from datetime import datetime, timedelta
from typing import Annotated, Literal, Optional
import httpx
import reflector.auth as auth
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
Response,
WebSocket,
WebSocketDisconnect,
status,
@@ -245,6 +247,42 @@ async def transcript_get_audio_mp3(
transcript_id, user_id=user_id
)
if transcript.audio_location == "storage":
# proxy S3 file, to prevent issue with CORS
url = await transcript.get_audio_url()
headers = {}
copy_headers = ["range", "accept-encoding"]
for header in copy_headers:
if header in request.headers:
headers[header] = request.headers[header]
async with httpx.AsyncClient() as client:
resp = await client.request(request.method, url, headers=headers)
return Response(
content=resp.content,
status_code=resp.status_code,
headers=resp.headers,
)
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=500, detail="Audio not found")
@@ -269,8 +307,8 @@ async def transcript_get_audio_waveform(
transcript_id, user_id=user_id
)
-if not transcript.audio_mp3_filename.exists():
-raise HTTPException(status_code=500, detail="Audio not found")
if not transcript.audio_waveform_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return transcript.audio_waveform
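Proxying the pre-signed S3 URL through the API (rather than redirecting to it) keeps CORS and token checks unchanged, and copying the range / accept-encoding request headers preserves audio seeking. A hedged client-side sketch of what that enables (base URL and route are placeholders, not the project's actual paths):

    import httpx

    # Ask for the first 64 KiB only; the API forwards the Range header upstream.
    resp = httpx.get(
        "http://localhost:8000/v1/transcripts/some-id/audio/mp3",  # placeholder route
        headers={"Range": "bytes=0-65535"},
    )
    print(resp.status_code)  # expect 206 when the backend honours range requests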

View File

@@ -1,4 +1,5 @@
from unittest.mock import patch
from tempfile import NamedTemporaryFile
import pytest
@@ -7,7 +8,6 @@ import pytest
@pytest.mark.asyncio
async def setup_database():
from reflector.settings import settings
-from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as f:
settings.DATABASE_URL = f"sqlite:///{f.name}"
@@ -103,6 +103,25 @@ async def dummy_llm():
yield
@pytest.fixture
async def dummy_storage():
from reflector.storage.base import Storage
class DummyStorage(Storage):
async def _put_file(self, *args, **kwargs):
pass
async def _delete_file(self, *args, **kwargs):
pass
async def _get_file_url(self, *args, **kwargs):
return "http://fake_server/audio.mp3"
with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
mock_storage.return_value = DummyStorage()
yield
@pytest.fixture
def nltk():
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
@@ -133,4 +152,17 @@ def celery_enable_logging():
@pytest.fixture(scope="session")
def celery_config():
-return {"broker_url": "memory://", "result_backend": "rpc"}
with NamedTemporaryFile() as f:
yield {
"broker_url": "memory://",
"result_backend": f"db+sqlite:///{f.name}",
}
@pytest.fixture(scope="session")
def fake_mp3_upload():
with patch(
"reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
) as mock_move:
mock_move.return_value = True
yield

View File

@@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket(
dummy_transcript,
dummy_processors,
dummy_diarization,
dummy_storage,
fake_mp3_upload,
ensure_casing,
appserver,
sentence_tokenize,
@@ -220,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_transcript,
dummy_processors,
dummy_diarization,
dummy_storage,
fake_mp3_upload,
ensure_casing,
appserver,
sentence_tokenize,