mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
@@ -1,7 +1,8 @@
|
||||
from reflector.processors.base import Processor
|
||||
import av
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
from reflector.processors.base import Processor
|
||||
|
||||
|
||||
class AudioFileWriterProcessor(Processor):
|
||||
"""
|
||||
@@ -15,6 +16,8 @@ class AudioFileWriterProcessor(Processor):
|
||||
super().__init__()
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if path.suffix not in (".mp3", ".wav"):
|
||||
raise ValueError("Only mp3 and wav files are supported")
|
||||
self.path = path
|
||||
self.out_container = None
|
||||
self.out_stream = None
|
||||
@@ -22,10 +25,19 @@ class AudioFileWriterProcessor(Processor):
|
||||
async def _push(self, data: av.AudioFrame):
|
||||
if not self.out_container:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.out_container = av.open(self.path.as_posix(), "w", format="wav")
|
||||
self.out_stream = self.out_container.add_stream(
|
||||
"pcm_s16le", rate=data.sample_rate
|
||||
)
|
||||
suffix = self.path.suffix
|
||||
if suffix == ".mp3":
|
||||
self.out_container = av.open(self.path.as_posix(), "w", format="mp3")
|
||||
self.out_stream = self.out_container.add_stream(
|
||||
"libmp3lame", rate=data.sample_rate
|
||||
)
|
||||
elif suffix == ".wav":
|
||||
self.out_container = av.open(self.path.as_posix(), "w", format="wav")
|
||||
self.out_stream = self.out_container.add_stream(
|
||||
"pcm_s16le", rate=data.sample_rate
|
||||
)
|
||||
else:
|
||||
raise ValueError("Only mp3 and wav files are supported")
|
||||
for packet in self.out_stream.encode(data):
|
||||
self.out_container.mux(packet)
|
||||
await self.emit(data)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Annotated, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import av
|
||||
import reflector.auth as auth
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
@@ -17,13 +15,11 @@ from fastapi import (
|
||||
)
|
||||
from fastapi_pagination import Page, paginate
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||
@@ -112,35 +108,12 @@ class Transcript(BaseModel):
|
||||
def topics_dump(self, mode="json"):
|
||||
return [topic.model_dump(mode=mode) for topic in self.topics]
|
||||
|
||||
def convert_audio_to_mp3(self):
|
||||
fn = self.audio_mp3_filename
|
||||
if fn.exists():
|
||||
return
|
||||
|
||||
logger.info(f"Converting audio to mp3: {self.audio_filename}")
|
||||
inp = av.open(self.audio_filename.as_posix(), "r")
|
||||
|
||||
# create temporary file for mp3
|
||||
with NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
|
||||
out = av.open(tmp.name, "w")
|
||||
stream = out.add_stream("mp3")
|
||||
for frame in inp.decode(audio=0):
|
||||
frame.pts = None
|
||||
for packet in stream.encode(frame):
|
||||
out.mux(packet)
|
||||
for packet in stream.encode(None):
|
||||
out.mux(packet)
|
||||
out.close()
|
||||
|
||||
# move temporary file to final location
|
||||
shutil.move(tmp.name, fn.as_posix())
|
||||
|
||||
def convert_audio_to_waveform(self, segments_count=256):
|
||||
fn = self.audio_waveform_filename
|
||||
if fn.exists():
|
||||
return
|
||||
waveform = get_audio_waveform(
|
||||
path=self.audio_filename, segments_count=segments_count
|
||||
path=self.audio_mp3_filename, segments_count=segments_count
|
||||
)
|
||||
try:
|
||||
with open(fn, "w") as fd:
|
||||
@@ -158,10 +131,6 @@ class Transcript(BaseModel):
|
||||
def data_path(self):
|
||||
return Path(settings.DATA_DIR) / self.id
|
||||
|
||||
@property
|
||||
def audio_filename(self):
|
||||
return self.data_path / "audio.wav"
|
||||
|
||||
@property
|
||||
def audio_mp3_filename(self):
|
||||
return self.data_path / "audio.mp3"
|
||||
@@ -373,27 +342,6 @@ async def transcript_delete(
|
||||
return DeletionStatus(status="ok")
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio")
|
||||
async def transcript_get_audio(
|
||||
request: Request,
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
if not transcript.audio_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
|
||||
return range_requests_response(
|
||||
request,
|
||||
transcript.audio_filename,
|
||||
content_type="audio/wav",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
||||
async def transcript_get_audio_mp3(
|
||||
request: Request,
|
||||
@@ -405,11 +353,9 @@ async def transcript_get_audio_mp3(
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
if not transcript.audio_filename.exists():
|
||||
if not transcript.audio_mp3_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
|
||||
await run_in_threadpool(transcript.convert_audio_to_mp3)
|
||||
|
||||
return range_requests_response(
|
||||
request,
|
||||
transcript.audio_mp3_filename,
|
||||
@@ -427,7 +373,7 @@ async def transcript_get_audio_waveform(
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
if not transcript.audio_filename.exists():
|
||||
if not transcript.audio_mp3_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
|
||||
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
||||
@@ -640,7 +586,7 @@ async def transcript_record_webrtc(
|
||||
request,
|
||||
event_callback=handle_rtc_event,
|
||||
event_callback_args=transcript_id,
|
||||
audio_filename=transcript.audio_filename,
|
||||
audio_filename=transcript.audio_mp3_filename,
|
||||
source_language=transcript.source_language,
|
||||
target_language=transcript.target_language,
|
||||
)
|
||||
|
||||
@@ -24,8 +24,8 @@ async def fake_transcript(tmpdir):
|
||||
await transcripts_controller.update(transcript, {"status": "finished"})
|
||||
|
||||
# manually copy a file at the expected location
|
||||
audio_filename = transcript.audio_filename
|
||||
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
||||
audio_filename = transcript.audio_mp3_filename
|
||||
path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
||||
audio_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(path, audio_filename)
|
||||
yield transcript
|
||||
@@ -35,7 +35,6 @@ async def fake_transcript(tmpdir):
|
||||
@pytest.mark.parametrize(
|
||||
"url_suffix,content_type",
|
||||
[
|
||||
["", "audio/wav"],
|
||||
["/mp3", "audio/mp3"],
|
||||
],
|
||||
)
|
||||
@@ -52,7 +51,6 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
|
||||
@pytest.mark.parametrize(
|
||||
"url_suffix,content_type",
|
||||
[
|
||||
["", "audio/wav"],
|
||||
["/mp3", "audio/mp3"],
|
||||
],
|
||||
)
|
||||
@@ -76,7 +74,6 @@ async def test_transcript_audio_download_range(
|
||||
@pytest.mark.parametrize(
|
||||
"url_suffix,content_type",
|
||||
[
|
||||
["", "audio/wav"],
|
||||
["/mp3", "audio/mp3"],
|
||||
],
|
||||
)
|
||||
@@ -104,4 +101,4 @@ async def test_transcript_audio_download_waveform(fake_transcript):
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "application/json"
|
||||
assert isinstance(response.json()["data"], list)
|
||||
assert len(response.json()["data"]) == 256
|
||||
assert len(response.json()["data"]) >= 255
|
||||
|
||||
@@ -200,11 +200,6 @@ async def test_transcript_rtc_and_websocket(
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ended"
|
||||
|
||||
# check that audio is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/wav"
|
||||
|
||||
# check that audio/mp3 is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert resp.status_code == 200
|
||||
|
||||
Reference in New Issue
Block a user