Merge branch 'main' of github.com:Monadical-SAS/reflector into feat-sharing

This commit is contained in:
Sara
2023-11-22 19:28:45 +01:00
10 changed files with 515 additions and 104 deletions

View File

@@ -0,0 +1,35 @@
"""audio_location
Revision ID: f819277e5169
Revises: 4814901632bc
Create Date: 2023-11-16 10:29:09.351664
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f819277e5169"
down_revision: Union[str, None] = "4814901632bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"transcript",
sa.Column(
"audio_location", sa.String(), server_default="local", nullable=False
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transcript", "audio_location")
# ### end Alembic commands ###
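The revision can also be applied or rolled back from Python rather than the alembic CLI; a minimal sketch, assuming the project's alembic.ini sits at the repository root:

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed config location
command.upgrade(cfg, "f819277e5169")      # adds transcript.audio_location (default "local")
# command.downgrade(cfg, "4814901632bc")  # drops the column again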

View File

@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import Storage
transcripts = sqlalchemy.Table(
"transcript",
@@ -28,6 +29,12 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column(
"audio_location",
sqlalchemy.String,
nullable=False,
server_default="local",
),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
sqlalchemy.Column(
@@ -39,15 +46,22 @@ transcripts = sqlalchemy.Table(
)
def generate_uuid4():
def generate_uuid4() -> str:
return str(uuid4())
def generate_transcript_name():
def generate_transcript_name() -> str:
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
def get_storage() -> Storage:
return Storage.get_instance(
name=settings.TRANSCRIPT_STORAGE_BACKEND,
settings_prefix="TRANSCRIPT_STORAGE_",
)
class AudioWaveform(BaseModel):
data: list[float]
@@ -114,6 +128,7 @@ class Transcript(BaseModel):
source_language: str = "en"
target_language: str = "en"
share_mode: Literal["private", "semi-private", "public"] = "private"
audio_location: str = "local"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -140,6 +155,10 @@ class Transcript(BaseModel):
def data_path(self):
return Path(settings.DATA_DIR) / self.id
@property
def audio_wav_filename(self):
return self.data_path / "audio.wav"
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@@ -148,6 +167,10 @@ class Transcript(BaseModel):
def audio_waveform_filename(self):
return self.data_path / "audio.json"
@property
def storage_audio_path(self):
return f"{self.id}/audio.mp3"
@property
def audio_waveform(self):
try:
@@ -160,6 +183,40 @@ class Transcript(BaseModel):
return AudioWaveform(data=data)
async def get_audio_url(self) -> str:
if self.audio_location == "local":
return self._generate_local_audio_link()
elif self.audio_location == "storage":
return await self._generate_storage_audio_link()
raise Exception(f"Unknown audio location {self.audio_location}")
async def _generate_storage_audio_link(self) -> str:
return await get_storage().get_file_url(self.storage_audio_path)
def _generate_local_audio_link(self) -> str:
# we need to create a URL to be used for diarization;
# we can't use audio_mp3_filename because it is not accessible
# from the diarization processor
from datetime import timedelta
from reflector.app import app
from reflector.views.transcripts import create_access_token
path = app.url_path_for(
"transcript_get_audio_mp3",
transcript_id=self.id,
)
url = f"{settings.BASE_URL}{path}"
if self.user_id:
# we only pass a token if user_id is set;
# otherwise the audio is public
token = create_access_token(
{"sub": self.user_id},
expires_delta=timedelta(minutes=15),
)
url += f"?token={token}"
return url
class TranscriptController:
async def get_all(
@@ -336,5 +393,22 @@ class TranscriptController:
transcript.upsert_topic(topic)
await self.update(transcript, {"topics": transcript.topics_dump()})
async def move_mp3_to_storage(self, transcript: Transcript):
"""
Move mp3 file to storage
"""
# store the audio on external storage
await get_storage().put_file(
transcript.storage_audio_path,
transcript.audio_mp3_filename.read_bytes(),
)
# indicate on the transcript that the audio is now on storage
await self.update(transcript, {"audio_location": "storage"})
# unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True)
transcripts_controller = TranscriptController()
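Taken together, move_mp3_to_storage and Transcript.get_audio_url form the storage flow used by the post-processing tasks: upload the local mp3, flip audio_location to "storage", then resolve a presigned URL. A minimal sketch of that flow, assuming a transcript whose mp3 already exists locally and a configured TRANSCRIPT_STORAGE_BACKEND:

from reflector.db.transcripts import transcripts_controller

async def archive_audio(transcript_id: str) -> str:
    transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
    # uploads <id>/audio.mp3, sets audio_location="storage", unlinks the local file
    await transcripts_controller.move_mp3_to_storage(transcript)
    # re-fetch so audio_location reflects the update, then resolve the URL
    transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
    return await transcript.get_audio_url()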

View File

@@ -12,13 +12,11 @@ It is directly linked to our data model.
"""
import asyncio
import functools
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from celery import shared_task
from celery import chord, group, shared_task
from pydantic import BaseModel
from reflector.app import app
from reflector.db.transcripts import (
Transcript,
TranscriptDuration,
@@ -55,6 +53,22 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
from structlog import BoundLogger as Logger
def asynctask(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
coro = f(*args, **kwargs)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
return loop.run_until_complete(coro)
return asyncio.run(coro)
return wrapper
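asynctask bridges Celery's synchronous task interface with the async pipeline helpers: stacked under @shared_task it lets a coroutine be registered and run by a worker, which is exactly how the tasks near the end of this file are declared. A small sketch using names from this module (the task itself is illustrative only):

@shared_task
@asynctask
async def task_example(*, transcript_id: str):  # hypothetical task
    await pipeline_waveform(transcript_id=transcript_id)

# enqueued from the API or another task:
# task_example.delay(transcript_id="...")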
def broadcast_to_sockets(func):
@@ -75,6 +89,26 @@ def broadcast_to_sockets(func):
return wrapper
def get_transcript(func):
"""
Decorator that fetches the transcript from the database using the
transcript_id keyword argument and injects it into the wrapped function.
"""
@functools.wraps(func)
async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id")
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
if not transcript:
raise Exception("Transcript {transcript_id} not found")
tlogger = logger.bind(transcript_id=transcript.id)
try:
return await func(transcript=transcript, logger=tlogger, **kwargs)
except Exception as exc:
tlogger.error("Pipeline error", exc_info=exc)
raise
return wrapper
class StrValue(BaseModel):
value: str
@@ -99,6 +133,19 @@ class PipelineMainBase(PipelineRunner):
raise Exception("Transcript not found")
return result
def get_transcript_topics(self, transcript: Transcript) -> list[TitleSummaryWithIdProcessorType]:
return [
TitleSummaryWithIdProcessorType(
id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
duration=topic.duration,
transcript=TranscriptProcessorType(words=topic.words),
)
for topic in transcript.topics
]
@asynccontextmanager
async def transaction(self):
async with self._lock:
@@ -116,7 +163,7 @@ class PipelineMainBase(PipelineRunner):
"flush": "processing",
"error": "error",
}
elif isinstance(self, PipelineMainDiarization):
elif isinstance(self, PipelineMainFinalSummaries):
status_mapping = {
"push": "processing",
"flush": "processing",
@@ -124,7 +171,8 @@ class PipelineMainBase(PipelineRunner):
"ended": "ended",
}
else:
raise Exception(f"Runner {self.__class__} is missing status mapping")
# intermediate pipelines don't update the status
return
# mutate to model status
status = status_mapping.get(status)
@@ -262,9 +310,10 @@ class PipelineMainBase(PipelineRunner):
class PipelineMainLive(PipelineMainBase):
audio_filename: Path | None = None
source_language: str = "en"
target_language: str = "en"
"""
Main pipeline for live streaming, attached to the RTC connection.
Any long-running post-processing should be done in the post pipeline.
"""
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
@@ -274,7 +323,7 @@ class PipelineMainLive(PipelineMainBase):
processors = [
AudioFileWriterProcessor(
path=transcript.audio_mp3_filename,
path=transcript.audio_wav_filename,
on_duration=self.on_duration,
),
AudioChunkerProcessor(),
@@ -283,26 +332,13 @@ class PipelineMainLive(PipelineMainBase):
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
AudioWaveformProcessor.as_threaded(
audio_path=transcript.audio_mp3_filename,
waveform_path=transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
),
]
),
]
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(
"Pipeline main live created",
transcript_id=self.transcript_id,
)
pipeline.logger.info("Pipeline main live created")
return pipeline
@@ -310,26 +346,106 @@ class PipelineMainLive(PipelineMainBase):
# when the pipeline ends, connect to the post pipeline
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
pipeline_post(transcript_id=self.transcript_id)
class PipelineMainDiarization(PipelineMainBase):
"""
Diarization is a long time process, so we do it in a separate pipeline
When done, adjust the short and final summary
Diarize the audio and update topics
"""
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
processors = []
if settings.DIARIZATION_ENABLED:
processors += [
AudioDiarizationAutoProcessor(callback=self.on_topic),
]
pipeline = Pipeline(
AudioDiarizationAutoProcessor(callback=self.on_topic),
)
pipeline.options = self
processors += [
# now start the pipeline by pushing information to the first
# processor (the diarization processor)
# XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript()
# diarization works only if the file is uploaded to an external storage
if transcript.audio_location == "local":
pipeline.logger.info("Audio is local, skipping diarization")
return
topics = self.get_transcript_topics(transcript)
audio_url = await transcript.get_audio_url()
audio_diarization_input = AudioDiarizationInput(
audio_url=audio_url,
topics=topics,
)
# tempting as it is to use pipeline.push, prefer the runner
# so that start() does just one job
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info("Diarization pipeline created")
self.push(audio_diarization_input)
self.flush()
return pipeline
class PipelineMainFromTopics(PipelineMainBase):
"""
Base class for generating a pipeline from the transcript topics
"""
def get_processors(self) -> list:
raise NotImplementedError
async def create(self) -> Pipeline:
self.prepare()
# get transcript
self._transcript = transcript = await self.get_transcript()
# create pipeline
processors = self.get_processors()
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
# push topics
topics = self.get_transcript_topics(transcript)
for topic in topics:
self.push(topic)
self.flush()
return pipeline
class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
"""
Generate title from the topics
"""
def get_processors(self) -> list:
return [
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_short_summary
),
]
)
]
class PipelineMainFinalSummaries(PipelineMainFromTopics):
"""
Generate summaries from the topics
"""
def get_processors(self) -> list:
return [
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -341,65 +457,164 @@ class PipelineMainDiarization(PipelineMainBase):
]
),
]
pipeline = Pipeline(*processors)
pipeline.options = self
# now let's start the pipeline by pushing information to the
# first processor diarization processor
# XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript()
topics = [
TitleSummaryWithIdProcessorType(
id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
duration=topic.duration,
transcript=TranscriptProcessorType(words=topic.words),
)
for topic in transcript.topics
class PipelineMainWaveform(PipelineMainFromTopics):
"""
Generate waveform
"""
def get_processors(self) -> list:
return [
AudioWaveformProcessor.as_threaded(
audio_path=self._transcript.audio_wav_filename,
waveform_path=self._transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
),
]
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from reflector.views.transcripts import create_access_token
path = app.url_path_for(
"transcript_get_audio_mp3",
transcript_id=transcript.id,
)
url = f"{settings.BASE_URL}{path}"
if transcript.user_id:
# we pass token only if the user_id is set
# otherwise, the audio is public
token = create_access_token(
{"sub": transcript.user_id},
expires_delta=timedelta(minutes=15),
)
url += f"?token={token}"
audio_diarization_input = AudioDiarizationInput(
audio_url=url,
topics=topics,
)
@get_transcript
async def pipeline_waveform(transcript: Transcript, logger: Logger):
logger.info("Starting waveform")
runner = PipelineMainWaveform(transcript_id=transcript.id)
await runner.run()
logger.info("Waveform done")
# as tempting to use pipeline.push, prefer to use the runner
# to let the start just do one job.
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(
"Pipeline main post created", transcript_id=self.transcript_id
)
self.push(audio_diarization_input)
self.flush()
return pipeline
@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
logger.info("Starting convert to mp3")
# If the audio wav is not available, just skip
wav_filename = transcript.audio_wav_filename
if not wav_filename.exists():
logger.warning("Wav file not found, may be already converted")
return
# Convert to mp3
mp3_filename = transcript.audio_mp3_filename
import av
with av.open(wav_filename.as_posix()) as in_container:
in_stream = in_container.streams.audio[0]
with av.open(mp3_filename.as_posix(), "w") as out_container:
out_stream = out_container.add_stream("mp3")
for frame in in_container.decode(in_stream):
for packet in out_stream.encode(frame):
out_container.mux(packet)
# flush frames still buffered in the encoder so the end of the mp3 is not lost
for packet in out_stream.encode(None):
out_container.mux(packet)
# Delete the wav file
transcript.audio_wav_filename.unlink(missing_ok=True)
logger.info("Convert to mp3 done")
@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
if not settings.TRANSCRIPT_STORAGE_BACKEND:
logger.info("No storage backend configured, skipping mp3 upload")
return
logger.info("Starting upload mp3")
# If the audio mp3 is not available, just skip
mp3_filename = transcript.audio_mp3_filename
if not mp3_filename.exists():
logger.warning("Mp3 file not found, may be already uploaded")
return
# Upload to external storage and delete the file
await transcripts_controller.move_mp3_to_storage(transcript)
logger.info("Upload mp3 done")
@get_transcript
async def pipeline_diarization(transcript: Transcript, logger: Logger):
logger.info("Starting diarization")
runner = PipelineMainDiarization(transcript_id=transcript.id)
await runner.run()
logger.info("Diarization done")
@get_transcript
async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
logger.info("Starting title and short summary")
runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
await runner.run()
logger.info("Title and short summary done")
@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger):
logger.info("Starting summaries")
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
await runner.run()
logger.info("Summaries done")
# ===================================================================
# Celery tasks that can be called from the API
# ===================================================================
@shared_task
def task_pipeline_main_post(transcript_id: str):
logger.info(
"Starting main post pipeline",
transcript_id=transcript_id,
@asynctask
async def task_pipeline_waveform(*, transcript_id: str):
await pipeline_waveform(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
await pipeline_convert_to_mp3(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_upload_mp3(*, transcript_id: str):
await pipeline_upload_mp3(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_diarization(*, transcript_id: str):
await pipeline_diarization(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_title_and_short_summary(*, transcript_id: str):
await pipeline_title_and_short_summary(transcript_id=transcript_id)
@shared_task
@asynctask
async def task_pipeline_final_summaries(*, transcript_id: str):
await pipeline_summaries(transcript_id=transcript_id)
def pipeline_post(*, transcript_id: str):
"""
Run the post pipeline
"""
chain_mp3_and_diarize = (
task_pipeline_waveform.si(transcript_id=transcript_id)
| task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
| task_pipeline_diarization.si(transcript_id=transcript_id)
)
runner = PipelineMainDiarization(transcript_id=transcript_id)
runner.start_sync()
chain_title_preview = task_pipeline_title_and_short_summary.si(
transcript_id=transcript_id
)
chain_final_summaries = task_pipeline_final_summaries.si(
transcript_id=transcript_id
)
chain = chord(
group(chain_mp3_and_diarize, chain_title_preview),
chain_final_summaries,
)
chain.delay()
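The resulting canvas runs the waveform/mp3/upload/diarization chain and the title/short-summary task in parallel, then triggers the final summaries once both branches complete; it is kicked off with a single call from the live pipeline's on_ended shown above. Rough shape of the graph:

# waveform -> convert_to_mp3 -> upload_mp3 -> diarization \
#                                                          >-> final_summaries
# title_and_short_summary --------------------------------/
pipeline_post(transcript_id="...")  # placeholder id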

View File

@@ -106,6 +106,14 @@ class PipelineRunner(BaseModel):
if not self.pipeline:
self.pipeline = await self.create()
if not self.pipeline:
# create() returned no pipeline, so finish immediately
await self._set_status("ended")
self._ev_done.set()
if self.on_ended:
await self.on_ended()
return
# start the loop
await self._set_status("started")
while not self._ev_done.is_set():
@@ -119,8 +127,7 @@ class PipelineRunner(BaseModel):
self._logger.exception("Runner error")
await self._set_status("error")
self._ev_done.set()
if self.on_ended:
await self.on_ended()
raise
async def cmd_push(self, data):
if self._is_first_push:

View File

@@ -54,7 +54,7 @@ class Settings(BaseSettings):
TRANSCRIPT_MODAL_API_KEY: str | None = None
# Audio transcription storage
TRANSCRIPT_STORAGE_BACKEND: str = "aws"
TRANSCRIPT_STORAGE_BACKEND: str | None = None
# Storage configuration for AWS
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
@@ -62,9 +62,6 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# Transcript MP3 storage
TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
# LLM
# available backend: openai, modal, oobabooga
LLM_BACKEND: str = "oobabooga"
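With TRANSCRIPT_STORAGE_BACKEND now defaulting to None, mp3 upload is opt-in (pipeline_upload_mp3 skips otherwise). A hedged example of enabling the AWS backend, using the same in-process settings override style the test suite uses for DATABASE_URL (values are placeholders; in deployment these would come from the environment via BaseSettings):

from reflector.settings import settings

settings.TRANSCRIPT_STORAGE_BACKEND = "aws"                  # enables the upload step
settings.TRANSCRIPT_STORAGE_AWS_BUCKET_NAME = "my-bucket"    # placeholder
settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID = "..."        # placeholder
settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY = "..."    # placeholder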

View File

@@ -1,6 +1,7 @@
import importlib
from pydantic import BaseModel
from reflector.settings import settings
import importlib
class FileResult(BaseModel):
@@ -17,7 +18,7 @@ class Storage:
cls._registry[name] = kclass
@classmethod
def get_instance(cls, name, settings_prefix=""):
def get_instance(cls, name: str, settings_prefix: str = ""):
if name not in cls._registry:
module_name = f"reflector.storage.storage_{name}"
importlib.import_module(module_name)
@@ -45,3 +46,9 @@ class Storage:
async def _delete_file(self, filename: str):
raise NotImplementedError
async def get_file_url(self, filename: str) -> str:
return await self._get_file_url(filename)
async def _get_file_url(self, filename: str) -> str:
raise NotImplementedError
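get_file_url completes the minimal backend contract (_put_file, _delete_file, _get_file_url). A sketch of a custom backend, modeled on the DummyStorage test fixture further down; the registration hook is not shown in this diff, so the module naming convention implied by get_instance (reflector.storage.storage_<name>) is an assumption here:

from reflector.storage.base import Storage

class StaticUrlStorage(Storage):  # hypothetical backend
    async def _put_file(self, filename: str, data: bytes):
        pass  # persist `data` under `filename` in the backing store

    async def _delete_file(self, filename: str):
        pass

    async def _get_file_url(self, filename: str) -> str:
        return f"https://files.example.com/{filename}"  # placeholder URL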

View File

@@ -1,6 +1,6 @@
import aioboto3
from reflector.storage.base import Storage, FileResult
from reflector.logger import logger
from reflector.storage.base import FileResult, Storage
class AwsStorage(Storage):
@@ -44,16 +44,18 @@ class AwsStorage(Storage):
Body=data,
)
async def _get_file_url(self, filename: str) -> str:
bucket = self.aws_bucket_name
folder = self.aws_folder
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
presigned_url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": s3filename},
ExpiresIn=3600,
)
return FileResult(
filename=filename,
url=presigned_url,
)
return presigned_url
async def _delete_file(self, filename: str):
bucket = self.aws_bucket_name

View File

@@ -1,12 +1,14 @@
from datetime import datetime, timedelta
from typing import Annotated, Literal, Optional
import httpx
import reflector.auth as auth
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
Response,
WebSocket,
WebSocketDisconnect,
status,
@@ -245,6 +247,42 @@ async def transcript_get_audio_mp3(
transcript_id, user_id=user_id
)
if transcript.audio_location == "storage":
# proxy the S3 file to prevent CORS issues
url = await transcript.get_audio_url()
headers = {}
copy_headers = ["range", "accept-encoding"]
for header in copy_headers:
if header in request.headers:
headers[header] = request.headers[header]
async with httpx.AsyncClient() as client:
resp = await client.request(request.method, url, headers=headers)
return Response(
content=resp.content,
status_code=resp.status_code,
headers=resp.headers,
)
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=500, detail="Audio not found")
@@ -269,8 +307,8 @@ async def transcript_get_audio_waveform(
transcript_id, user_id=user_id
)
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=500, detail="Audio not found")
if not transcript.audio_waveform_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return transcript.audio_waveform

View File

@@ -1,4 +1,5 @@
from unittest.mock import patch
from tempfile import NamedTemporaryFile
import pytest
@@ -7,7 +8,6 @@ import pytest
@pytest.mark.asyncio
async def setup_database():
from reflector.settings import settings
from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as f:
settings.DATABASE_URL = f"sqlite:///{f.name}"
@@ -103,6 +103,25 @@ async def dummy_llm():
yield
@pytest.fixture
async def dummy_storage():
from reflector.storage.base import Storage
class DummyStorage(Storage):
async def _put_file(self, *args, **kwargs):
pass
async def _delete_file(self, *args, **kwargs):
pass
async def _get_file_url(self, *args, **kwargs):
return "http://fake_server/audio.mp3"
with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
mock_storage.return_value = DummyStorage()
yield
@pytest.fixture
def nltk():
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
@@ -133,4 +152,17 @@ def celery_enable_logging():
@pytest.fixture(scope="session")
def celery_config():
return {"broker_url": "memory://", "result_backend": "rpc"}
with NamedTemporaryFile() as f:
yield {
"broker_url": "memory://",
"result_backend": f"db+sqlite:///{f.name}",
}
@pytest.fixture(scope="session")
def fake_mp3_upload():
with patch(
"reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
) as mock_move:
mock_move.return_value = True
yield

View File

@@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket(
dummy_transcript,
dummy_processors,
dummy_diarization,
dummy_storage,
fake_mp3_upload,
ensure_casing,
appserver,
sentence_tokenize,
@@ -220,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_transcript,
dummy_processors,
dummy_diarization,
dummy_storage,
fake_mp3_upload,
ensure_casing,
appserver,
sentence_tokenize,