server: use mp3 as default for audio storage

Closes #223
This commit is contained in:
2023-09-13 16:18:12 +02:00
committed by Mathieu Virbel
parent fb93c55993
commit 2b9eef6131
4 changed files with 27 additions and 77 deletions

View File

@@ -1,7 +1,8 @@
from reflector.processors.base import Processor
import av
from pathlib import Path from pathlib import Path
import av
from reflector.processors.base import Processor
class AudioFileWriterProcessor(Processor): class AudioFileWriterProcessor(Processor):
""" """
@@ -15,6 +16,8 @@ class AudioFileWriterProcessor(Processor):
super().__init__() super().__init__()
if isinstance(path, str): if isinstance(path, str):
path = Path(path) path = Path(path)
if path.suffix not in (".mp3", ".wav"):
raise ValueError("Only mp3 and wav files are supported")
self.path = path self.path = path
self.out_container = None self.out_container = None
self.out_stream = None self.out_stream = None
@@ -22,10 +25,19 @@ class AudioFileWriterProcessor(Processor):
async def _push(self, data: av.AudioFrame): async def _push(self, data: av.AudioFrame):
if not self.out_container: if not self.out_container:
self.path.parent.mkdir(parents=True, exist_ok=True) self.path.parent.mkdir(parents=True, exist_ok=True)
self.out_container = av.open(self.path.as_posix(), "w", format="wav") suffix = self.path.suffix
self.out_stream = self.out_container.add_stream( if suffix == ".mp3":
"pcm_s16le", rate=data.sample_rate self.out_container = av.open(self.path.as_posix(), "w", format="mp3")
) self.out_stream = self.out_container.add_stream(
"libmp3lame", rate=data.sample_rate
)
elif suffix == ".wav":
self.out_container = av.open(self.path.as_posix(), "w", format="wav")
self.out_stream = self.out_container.add_stream(
"pcm_s16le", rate=data.sample_rate
)
else:
raise ValueError("Only mp3 and wav files are supported")
for packet in self.out_stream.encode(data): for packet in self.out_stream.encode(data):
self.out_container.mux(packet) self.out_container.mux(packet)
await self.emit(data) await self.emit(data)

View File

@@ -1,12 +1,10 @@
import json import json
import shutil
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Annotated, Optional from typing import Annotated, Optional
from uuid import uuid4 from uuid import uuid4
import av import reflector.auth as auth
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
Depends, Depends,
@@ -17,13 +15,11 @@ from fastapi import (
) )
from fastapi_pagination import Page, paginate from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool
import reflector.auth as auth
from reflector.db import database, transcripts from reflector.db import database, transcripts
from reflector.logger import logger from reflector.logger import logger
from reflector.settings import settings from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -112,35 +108,12 @@ class Transcript(BaseModel):
def topics_dump(self, mode="json"): def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics] return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_mp3(self):
fn = self.audio_mp3_filename
if fn.exists():
return
logger.info(f"Converting audio to mp3: {self.audio_filename}")
inp = av.open(self.audio_filename.as_posix(), "r")
# create temporary file for mp3
with NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
out = av.open(tmp.name, "w")
stream = out.add_stream("mp3")
for frame in inp.decode(audio=0):
frame.pts = None
for packet in stream.encode(frame):
out.mux(packet)
for packet in stream.encode(None):
out.mux(packet)
out.close()
# move temporary file to final location
shutil.move(tmp.name, fn.as_posix())
def convert_audio_to_waveform(self, segments_count=256): def convert_audio_to_waveform(self, segments_count=256):
fn = self.audio_waveform_filename fn = self.audio_waveform_filename
if fn.exists(): if fn.exists():
return return
waveform = get_audio_waveform( waveform = get_audio_waveform(
path=self.audio_filename, segments_count=segments_count path=self.audio_mp3_filename, segments_count=segments_count
) )
try: try:
with open(fn, "w") as fd: with open(fn, "w") as fd:
@@ -158,10 +131,6 @@ class Transcript(BaseModel):
def data_path(self): def data_path(self):
return Path(settings.DATA_DIR) / self.id return Path(settings.DATA_DIR) / self.id
@property
def audio_filename(self):
return self.data_path / "audio.wav"
@property @property
def audio_mp3_filename(self): def audio_mp3_filename(self):
return self.data_path / "audio.mp3" return self.data_path / "audio.mp3"
@@ -373,27 +342,6 @@ async def transcript_delete(
return DeletionStatus(status="ok") return DeletionStatus(status="ok")
@router.get("/transcripts/{transcript_id}/audio")
async def transcript_get_audio(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return range_requests_response(
request,
transcript.audio_filename,
content_type="audio/wav",
)
@router.get("/transcripts/{transcript_id}/audio/mp3") @router.get("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3( async def transcript_get_audio_mp3(
request: Request, request: Request,
@@ -405,11 +353,9 @@ async def transcript_get_audio_mp3(
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_filename.exists(): if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found") raise HTTPException(status_code=404, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_mp3)
return range_requests_response( return range_requests_response(
request, request,
transcript.audio_mp3_filename, transcript.audio_mp3_filename,
@@ -427,7 +373,7 @@ async def transcript_get_audio_waveform(
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_filename.exists(): if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found") raise HTTPException(status_code=404, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_waveform) await run_in_threadpool(transcript.convert_audio_to_waveform)
@@ -640,7 +586,7 @@ async def transcript_record_webrtc(
request, request,
event_callback=handle_rtc_event, event_callback=handle_rtc_event,
event_callback_args=transcript_id, event_callback_args=transcript_id,
audio_filename=transcript.audio_filename, audio_filename=transcript.audio_mp3_filename,
source_language=transcript.source_language, source_language=transcript.source_language,
target_language=transcript.target_language, target_language=transcript.target_language,
) )

View File

@@ -24,8 +24,8 @@ async def fake_transcript(tmpdir):
await transcripts_controller.update(transcript, {"status": "finished"}) await transcripts_controller.update(transcript, {"status": "finished"})
# manually copy a file at the expected location # manually copy a file at the expected location
audio_filename = transcript.audio_filename audio_filename = transcript.audio_mp3_filename
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav" path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
audio_filename.parent.mkdir(parents=True, exist_ok=True) audio_filename.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(path, audio_filename) shutil.copy(path, audio_filename)
yield transcript yield transcript
@@ -35,7 +35,6 @@ async def fake_transcript(tmpdir):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"url_suffix,content_type", "url_suffix,content_type",
[ [
["", "audio/wav"],
["/mp3", "audio/mp3"], ["/mp3", "audio/mp3"],
], ],
) )
@@ -52,7 +51,6 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
@pytest.mark.parametrize( @pytest.mark.parametrize(
"url_suffix,content_type", "url_suffix,content_type",
[ [
["", "audio/wav"],
["/mp3", "audio/mp3"], ["/mp3", "audio/mp3"],
], ],
) )
@@ -76,7 +74,6 @@ async def test_transcript_audio_download_range(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"url_suffix,content_type", "url_suffix,content_type",
[ [
["", "audio/wav"],
["/mp3", "audio/mp3"], ["/mp3", "audio/mp3"],
], ],
) )
@@ -104,4 +101,4 @@ async def test_transcript_audio_download_waveform(fake_transcript):
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "application/json" assert response.headers["content-type"] == "application/json"
assert isinstance(response.json()["data"], list) assert isinstance(response.json()["data"], list)
assert len(response.json()["data"]) == 256 assert len(response.json()["data"]) >= 255

View File

@@ -200,11 +200,6 @@ async def test_transcript_rtc_and_websocket(
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["status"] == "ended" assert resp.json()["status"] == "ended"
# check that audio is available
resp = await ac.get(f"/transcripts/{tid}/audio")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/wav"
# check that audio/mp3 is available # check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3") resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200 assert resp.status_code == 200