mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
@@ -1,7 +1,8 @@
|
|||||||
from reflector.processors.base import Processor
|
|
||||||
import av
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import av
|
||||||
|
from reflector.processors.base import Processor
|
||||||
|
|
||||||
|
|
||||||
class AudioFileWriterProcessor(Processor):
|
class AudioFileWriterProcessor(Processor):
|
||||||
"""
|
"""
|
||||||
@@ -15,6 +16,8 @@ class AudioFileWriterProcessor(Processor):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
if isinstance(path, str):
|
if isinstance(path, str):
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
if path.suffix not in (".mp3", ".wav"):
|
||||||
|
raise ValueError("Only mp3 and wav files are supported")
|
||||||
self.path = path
|
self.path = path
|
||||||
self.out_container = None
|
self.out_container = None
|
||||||
self.out_stream = None
|
self.out_stream = None
|
||||||
@@ -22,10 +25,19 @@ class AudioFileWriterProcessor(Processor):
|
|||||||
async def _push(self, data: av.AudioFrame):
|
async def _push(self, data: av.AudioFrame):
|
||||||
if not self.out_container:
|
if not self.out_container:
|
||||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
self.out_container = av.open(self.path.as_posix(), "w", format="wav")
|
suffix = self.path.suffix
|
||||||
self.out_stream = self.out_container.add_stream(
|
if suffix == ".mp3":
|
||||||
"pcm_s16le", rate=data.sample_rate
|
self.out_container = av.open(self.path.as_posix(), "w", format="mp3")
|
||||||
)
|
self.out_stream = self.out_container.add_stream(
|
||||||
|
"libmp3lame", rate=data.sample_rate
|
||||||
|
)
|
||||||
|
elif suffix == ".wav":
|
||||||
|
self.out_container = av.open(self.path.as_posix(), "w", format="wav")
|
||||||
|
self.out_stream = self.out_container.add_stream(
|
||||||
|
"pcm_s16le", rate=data.sample_rate
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Only mp3 and wav files are supported")
|
||||||
for packet in self.out_stream.encode(data):
|
for packet in self.out_stream.encode(data):
|
||||||
self.out_container.mux(packet)
|
self.out_container.mux(packet)
|
||||||
await self.emit(data)
|
await self.emit(data)
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import shutil
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import av
|
import reflector.auth as auth
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
Depends,
|
Depends,
|
||||||
@@ -17,13 +15,11 @@ from fastapi import (
|
|||||||
)
|
)
|
||||||
from fastapi_pagination import Page, paginate
|
from fastapi_pagination import Page, paginate
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from starlette.concurrency import run_in_threadpool
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
|
||||||
from reflector.db import database, transcripts
|
from reflector.db import database, transcripts
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.utils.audio_waveform import get_audio_waveform
|
from reflector.utils.audio_waveform import get_audio_waveform
|
||||||
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
|
||||||
from ._range_requests_response import range_requests_response
|
from ._range_requests_response import range_requests_response
|
||||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||||
@@ -112,35 +108,12 @@ class Transcript(BaseModel):
|
|||||||
def topics_dump(self, mode="json"):
|
def topics_dump(self, mode="json"):
|
||||||
return [topic.model_dump(mode=mode) for topic in self.topics]
|
return [topic.model_dump(mode=mode) for topic in self.topics]
|
||||||
|
|
||||||
def convert_audio_to_mp3(self):
|
|
||||||
fn = self.audio_mp3_filename
|
|
||||||
if fn.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"Converting audio to mp3: {self.audio_filename}")
|
|
||||||
inp = av.open(self.audio_filename.as_posix(), "r")
|
|
||||||
|
|
||||||
# create temporary file for mp3
|
|
||||||
with NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
|
|
||||||
out = av.open(tmp.name, "w")
|
|
||||||
stream = out.add_stream("mp3")
|
|
||||||
for frame in inp.decode(audio=0):
|
|
||||||
frame.pts = None
|
|
||||||
for packet in stream.encode(frame):
|
|
||||||
out.mux(packet)
|
|
||||||
for packet in stream.encode(None):
|
|
||||||
out.mux(packet)
|
|
||||||
out.close()
|
|
||||||
|
|
||||||
# move temporary file to final location
|
|
||||||
shutil.move(tmp.name, fn.as_posix())
|
|
||||||
|
|
||||||
def convert_audio_to_waveform(self, segments_count=256):
|
def convert_audio_to_waveform(self, segments_count=256):
|
||||||
fn = self.audio_waveform_filename
|
fn = self.audio_waveform_filename
|
||||||
if fn.exists():
|
if fn.exists():
|
||||||
return
|
return
|
||||||
waveform = get_audio_waveform(
|
waveform = get_audio_waveform(
|
||||||
path=self.audio_filename, segments_count=segments_count
|
path=self.audio_mp3_filename, segments_count=segments_count
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with open(fn, "w") as fd:
|
with open(fn, "w") as fd:
|
||||||
@@ -158,10 +131,6 @@ class Transcript(BaseModel):
|
|||||||
def data_path(self):
|
def data_path(self):
|
||||||
return Path(settings.DATA_DIR) / self.id
|
return Path(settings.DATA_DIR) / self.id
|
||||||
|
|
||||||
@property
|
|
||||||
def audio_filename(self):
|
|
||||||
return self.data_path / "audio.wav"
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio_mp3_filename(self):
|
def audio_mp3_filename(self):
|
||||||
return self.data_path / "audio.mp3"
|
return self.data_path / "audio.mp3"
|
||||||
@@ -373,27 +342,6 @@ async def transcript_delete(
|
|||||||
return DeletionStatus(status="ok")
|
return DeletionStatus(status="ok")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}/audio")
|
|
||||||
async def transcript_get_audio(
|
|
||||||
request: Request,
|
|
||||||
transcript_id: str,
|
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
|
||||||
):
|
|
||||||
user_id = user["sub"] if user else None
|
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
|
||||||
if not transcript:
|
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
|
||||||
|
|
||||||
if not transcript.audio_filename.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="Audio not found")
|
|
||||||
|
|
||||||
return range_requests_response(
|
|
||||||
request,
|
|
||||||
transcript.audio_filename,
|
|
||||||
content_type="audio/wav",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
||||||
async def transcript_get_audio_mp3(
|
async def transcript_get_audio_mp3(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -405,11 +353,9 @@ async def transcript_get_audio_mp3(
|
|||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
if not transcript.audio_filename.exists():
|
if not transcript.audio_mp3_filename.exists():
|
||||||
raise HTTPException(status_code=404, detail="Audio not found")
|
raise HTTPException(status_code=404, detail="Audio not found")
|
||||||
|
|
||||||
await run_in_threadpool(transcript.convert_audio_to_mp3)
|
|
||||||
|
|
||||||
return range_requests_response(
|
return range_requests_response(
|
||||||
request,
|
request,
|
||||||
transcript.audio_mp3_filename,
|
transcript.audio_mp3_filename,
|
||||||
@@ -427,7 +373,7 @@ async def transcript_get_audio_waveform(
|
|||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
if not transcript.audio_filename.exists():
|
if not transcript.audio_mp3_filename.exists():
|
||||||
raise HTTPException(status_code=404, detail="Audio not found")
|
raise HTTPException(status_code=404, detail="Audio not found")
|
||||||
|
|
||||||
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
||||||
@@ -640,7 +586,7 @@ async def transcript_record_webrtc(
|
|||||||
request,
|
request,
|
||||||
event_callback=handle_rtc_event,
|
event_callback=handle_rtc_event,
|
||||||
event_callback_args=transcript_id,
|
event_callback_args=transcript_id,
|
||||||
audio_filename=transcript.audio_filename,
|
audio_filename=transcript.audio_mp3_filename,
|
||||||
source_language=transcript.source_language,
|
source_language=transcript.source_language,
|
||||||
target_language=transcript.target_language,
|
target_language=transcript.target_language,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ async def fake_transcript(tmpdir):
|
|||||||
await transcripts_controller.update(transcript, {"status": "finished"})
|
await transcripts_controller.update(transcript, {"status": "finished"})
|
||||||
|
|
||||||
# manually copy a file at the expected location
|
# manually copy a file at the expected location
|
||||||
audio_filename = transcript.audio_filename
|
audio_filename = transcript.audio_mp3_filename
|
||||||
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
|
path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
||||||
audio_filename.parent.mkdir(parents=True, exist_ok=True)
|
audio_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||||
shutil.copy(path, audio_filename)
|
shutil.copy(path, audio_filename)
|
||||||
yield transcript
|
yield transcript
|
||||||
@@ -35,7 +35,6 @@ async def fake_transcript(tmpdir):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"url_suffix,content_type",
|
"url_suffix,content_type",
|
||||||
[
|
[
|
||||||
["", "audio/wav"],
|
|
||||||
["/mp3", "audio/mp3"],
|
["/mp3", "audio/mp3"],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -52,7 +51,6 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"url_suffix,content_type",
|
"url_suffix,content_type",
|
||||||
[
|
[
|
||||||
["", "audio/wav"],
|
|
||||||
["/mp3", "audio/mp3"],
|
["/mp3", "audio/mp3"],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -76,7 +74,6 @@ async def test_transcript_audio_download_range(
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"url_suffix,content_type",
|
"url_suffix,content_type",
|
||||||
[
|
[
|
||||||
["", "audio/wav"],
|
|
||||||
["/mp3", "audio/mp3"],
|
["/mp3", "audio/mp3"],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -104,4 +101,4 @@ async def test_transcript_audio_download_waveform(fake_transcript):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "application/json"
|
assert response.headers["content-type"] == "application/json"
|
||||||
assert isinstance(response.json()["data"], list)
|
assert isinstance(response.json()["data"], list)
|
||||||
assert len(response.json()["data"]) == 256
|
assert len(response.json()["data"]) >= 255
|
||||||
|
|||||||
@@ -200,11 +200,6 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.json()["status"] == "ended"
|
assert resp.json()["status"] == "ended"
|
||||||
|
|
||||||
# check that audio is available
|
|
||||||
resp = await ac.get(f"/transcripts/{tid}/audio")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.headers["Content-Type"] == "audio/wav"
|
|
||||||
|
|
||||||
# check that audio/mp3 is available
|
# check that audio/mp3 is available
|
||||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|||||||
Reference in New Issue
Block a user