From 2b9eef6131fce45bbe329fbcbc4b07ae6a38532c Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 13 Sep 2023 16:18:12 +0200
Subject: [PATCH] server: use mp3 as default for audio storage

Closes #223
---
 .../reflector/processors/audio_file_writer.py | 24 +++++--
 server/reflector/views/transcripts.py         | 66 ++-----------------
 .../tests/test_transcripts_audio_download.py  |  9 +--
 server/tests/test_transcripts_rtc_ws.py       |  5 --
 4 files changed, 27 insertions(+), 77 deletions(-)

diff --git a/server/reflector/processors/audio_file_writer.py b/server/reflector/processors/audio_file_writer.py
index 00ab2529..d34dc3f0 100644
--- a/server/reflector/processors/audio_file_writer.py
+++ b/server/reflector/processors/audio_file_writer.py
@@ -1,7 +1,8 @@
-from reflector.processors.base import Processor
-import av
 from pathlib import Path
 
+import av
+from reflector.processors.base import Processor
+
 
 class AudioFileWriterProcessor(Processor):
     """
@@ -15,6 +16,8 @@ class AudioFileWriterProcessor(Processor):
         super().__init__()
         if isinstance(path, str):
             path = Path(path)
+        if path.suffix not in (".mp3", ".wav"):
+            raise ValueError("Only mp3 and wav files are supported")
         self.path = path
         self.out_container = None
         self.out_stream = None
@@ -22,10 +25,19 @@ class AudioFileWriterProcessor(Processor):
     async def _push(self, data: av.AudioFrame):
         if not self.out_container:
             self.path.parent.mkdir(parents=True, exist_ok=True)
-            self.out_container = av.open(self.path.as_posix(), "w", format="wav")
-            self.out_stream = self.out_container.add_stream(
-                "pcm_s16le", rate=data.sample_rate
-            )
+            suffix = self.path.suffix
+            if suffix == ".mp3":
+                self.out_container = av.open(self.path.as_posix(), "w", format="mp3")
+                self.out_stream = self.out_container.add_stream(
+                    "libmp3lame", rate=data.sample_rate
+                )
+            elif suffix == ".wav":
+                self.out_container = av.open(self.path.as_posix(), "w", format="wav")
+                self.out_stream = self.out_container.add_stream(
+                    "pcm_s16le", rate=data.sample_rate
+                )
+            else:
+                raise ValueError("Only mp3 and wav files are supported")
         for packet in self.out_stream.encode(data):
             self.out_container.mux(packet)
         await self.emit(data)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index f4611817..410839d7 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -1,12 +1,10 @@
 import json
-import shutil
 from datetime import datetime
 from pathlib import Path
-from tempfile import NamedTemporaryFile
 from typing import Annotated, Optional
 from uuid import uuid4
 
-import av
+import reflector.auth as auth
 from fastapi import (
     APIRouter,
     Depends,
@@ -17,13 +15,11 @@ from fastapi import (
 )
 from fastapi_pagination import Page, paginate
 from pydantic import BaseModel, Field
-from starlette.concurrency import run_in_threadpool
-
-import reflector.auth as auth
 from reflector.db import database, transcripts
 from reflector.logger import logger
 from reflector.settings import settings
 from reflector.utils.audio_waveform import get_audio_waveform
+from starlette.concurrency import run_in_threadpool
 
 from ._range_requests_response import range_requests_response
 from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -112,35 +108,12 @@ class Transcript(BaseModel):
     def topics_dump(self, mode="json"):
         return [topic.model_dump(mode=mode) for topic in self.topics]
 
-    def convert_audio_to_mp3(self):
-        fn = self.audio_mp3_filename
-        if fn.exists():
-            return
-
-        logger.info(f"Converting audio to mp3: {self.audio_filename}")
-        inp = av.open(self.audio_filename.as_posix(), "r")
-
-        # create temporary file for mp3
-        with NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
-            out = av.open(tmp.name, "w")
-            stream = out.add_stream("mp3")
-            for frame in inp.decode(audio=0):
-                frame.pts = None
-                for packet in stream.encode(frame):
-                    out.mux(packet)
-            for packet in stream.encode(None):
-                out.mux(packet)
-            out.close()
-
-        # move temporary file to final location
-        shutil.move(tmp.name, fn.as_posix())
-
     def convert_audio_to_waveform(self, segments_count=256):
         fn = self.audio_waveform_filename
         if fn.exists():
             return
         waveform = get_audio_waveform(
-            path=self.audio_filename, segments_count=segments_count
+            path=self.audio_mp3_filename, segments_count=segments_count
         )
         try:
             with open(fn, "w") as fd:
@@ -158,10 +131,6 @@ class Transcript(BaseModel):
     def data_path(self):
         return Path(settings.DATA_DIR) / self.id
 
-    @property
-    def audio_filename(self):
-        return self.data_path / "audio.wav"
-
     @property
     def audio_mp3_filename(self):
         return self.data_path / "audio.mp3"
@@ -373,27 +342,6 @@ async def transcript_delete(
     return DeletionStatus(status="ok")
 
 
-@router.get("/transcripts/{transcript_id}/audio")
-async def transcript_get_audio(
-    request: Request,
-    transcript_id: str,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-):
-    user_id = user["sub"] if user else None
-    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
-    if not transcript:
-        raise HTTPException(status_code=404, detail="Transcript not found")
-
-    if not transcript.audio_filename.exists():
-        raise HTTPException(status_code=404, detail="Audio not found")
-
-    return range_requests_response(
-        request,
-        transcript.audio_filename,
-        content_type="audio/wav",
-    )
-
-
 @router.get("/transcripts/{transcript_id}/audio/mp3")
 async def transcript_get_audio_mp3(
     request: Request,
@@ -405,11 +353,9 @@ async def transcript_get_audio_mp3(
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")
 
-    if not transcript.audio_filename.exists():
+    if not transcript.audio_mp3_filename.exists():
         raise HTTPException(status_code=404, detail="Audio not found")
 
-    await run_in_threadpool(transcript.convert_audio_to_mp3)
-
     return range_requests_response(
         request,
         transcript.audio_mp3_filename,
@@ -427,7 +373,7 @@ async def transcript_get_audio_waveform(
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")
 
-    if not transcript.audio_filename.exists():
+    if not transcript.audio_mp3_filename.exists():
         raise HTTPException(status_code=404, detail="Audio not found")
 
     await run_in_threadpool(transcript.convert_audio_to_waveform)
@@ -640,7 +586,7 @@ async def transcript_record_webrtc(
         request,
         event_callback=handle_rtc_event,
         event_callback_args=transcript_id,
-        audio_filename=transcript.audio_filename,
+        audio_filename=transcript.audio_mp3_filename,
         source_language=transcript.source_language,
         target_language=transcript.target_language,
     )
diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py
index a33ecceb..f37b7a4f 100644
--- a/server/tests/test_transcripts_audio_download.py
+++ b/server/tests/test_transcripts_audio_download.py
@@ -24,8 +24,8 @@ async def fake_transcript(tmpdir):
     await transcripts_controller.update(transcript, {"status": "finished"})
 
     # manually copy a file at the expected location
-    audio_filename = transcript.audio_filename
-    path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
+    audio_filename = transcript.audio_mp3_filename
+    path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
     audio_filename.parent.mkdir(parents=True, exist_ok=True)
     shutil.copy(path, audio_filename)
     yield transcript
@@ -35,7 +35,6 @@ async def fake_transcript(tmpdir):
 @pytest.mark.parametrize(
     "url_suffix,content_type",
     [
-        ["", "audio/wav"],
         ["/mp3", "audio/mp3"],
     ],
 )
@@ -52,7 +51,6 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
 @pytest.mark.parametrize(
     "url_suffix,content_type",
     [
-        ["", "audio/wav"],
         ["/mp3", "audio/mp3"],
     ],
 )
@@ -76,7 +74,6 @@ async def test_transcript_audio_download_range(
 @pytest.mark.parametrize(
     "url_suffix,content_type",
     [
-        ["", "audio/wav"],
         ["/mp3", "audio/mp3"],
     ],
 )
@@ -104,4 +101,4 @@ async def test_transcript_audio_download_waveform(fake_transcript):
     assert response.status_code == 200
     assert response.headers["content-type"] == "application/json"
     assert isinstance(response.json()["data"], list)
-    assert len(response.json()["data"]) == 256
+    assert len(response.json()["data"]) >= 255
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index f298e596..d6816192 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -200,11 +200,6 @@ async def test_transcript_rtc_and_websocket(
     assert resp.status_code == 200
     assert resp.json()["status"] == "ended"
 
-    # check that audio is available
-    resp = await ac.get(f"/transcripts/{tid}/audio")
-    assert resp.status_code == 200
-    assert resp.headers["Content-Type"] == "audio/wav"
-
     # check that audio/mp3 is available
     resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
     assert resp.status_code == 200