mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: implement wav/mp3 audio download
If set, the server will save the recorded audio to disk. MP3 conversion happens on request, and the result is cached to disk only if the conversion is successful. Closes #148
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from .base import Processor, ThreadedProcessor, Pipeline # noqa: F401
|
||||
from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary # noqa: F401
|
||||
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
||||
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||
|
||||
35
server/reflector/processors/audio_file_writer.py
Normal file
35
server/reflector/processors/audio_file_writer.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from reflector.processors.base import Processor
|
||||
import av
|
||||
import wave
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class AudioFileWriterProcessor(Processor):
    """
    Pass-through processor that records incoming audio frames to a WAV file.

    The output file is opened lazily when the first frame arrives; the
    channel count, sample width and sample rate of the WAV header are taken
    from that first frame. Every frame is re-emitted after being written.
    """

    INPUT_TYPE = av.AudioFrame
    OUTPUT_TYPE = av.AudioFrame

    def __init__(self, path: Path | str):
        """Remember the destination *path*; nothing is opened yet."""
        super().__init__()
        self.path = Path(path) if isinstance(path, str) else path
        # wave writer handle, created on the first pushed frame
        self.fd = None

    async def _push(self, data: av.AudioFrame):
        """Append *data* to the WAV file, then forward it with ``emit``."""
        if not self.fd:
            # First frame: make sure the target directory exists and
            # configure the WAV header from this frame's properties.
            self.path.parent.mkdir(parents=True, exist_ok=True)
            self.fd = wave.open(self.path.as_posix(), "wb")
            writer = self.fd
            writer.setnchannels(len(data.layout.channels))
            writer.setsampwidth(data.format.bytes)
            writer.setframerate(data.sample_rate)

        # NOTE(review): the raw ndarray bytes are written as-is; presumably
        # frames arrive in a packed (interleaved) sample format - confirm.
        raw_samples = data.to_ndarray().tobytes()
        self.fd.writeframes(raw_samples)
        await self.emit(data)

    async def _flush(self):
        """Close the WAV file if one was opened and reset the handle."""
        if not self.fd:
            return
        self.fd.close()
        self.fd = None
|
||||
@@ -9,6 +9,9 @@ class Settings(BaseSettings):
|
||||
# Database
|
||||
DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
|
||||
|
||||
# local data directory (audio for now)
|
||||
DATA_DIR: str = "./data"
|
||||
|
||||
# Whisper
|
||||
WHISPER_MODEL_SIZE: str = "tiny"
|
||||
WHISPER_REAL_TIME_MODEL_SIZE: str = "tiny"
|
||||
|
||||
@@ -7,12 +7,14 @@ from reflector.logger import logger
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
|
||||
from json import loads, dumps
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
import av
|
||||
from reflector.processors import (
|
||||
Pipeline,
|
||||
AudioChunkerProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
TranscriptLinerProcessor,
|
||||
TranscriptTopicDetectorProcessor,
|
||||
TranscriptFinalSummaryProcessor,
|
||||
@@ -64,7 +66,11 @@ class PipelineEvent(StrEnum):
|
||||
|
||||
|
||||
async def rtc_offer_base(
|
||||
params: RtcOffer, request: Request, event_callback=None, event_callback_args=None
|
||||
params: RtcOffer,
|
||||
request: Request,
|
||||
event_callback=None,
|
||||
event_callback_args=None,
|
||||
audio_filename: Path | None = None,
|
||||
):
|
||||
# build an rtc session
|
||||
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
|
||||
@@ -151,14 +157,18 @@ async def rtc_offer_base(
|
||||
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
ctx.pipeline = Pipeline(
|
||||
processors = []
|
||||
if audio_filename is not None:
|
||||
processors += [AudioFileWriterProcessor(path=audio_filename)]
|
||||
processors += [
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
|
||||
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
|
||||
)
|
||||
]
|
||||
ctx.pipeline = Pipeline(*processors)
|
||||
# FIXME: warmup is not working well yet
|
||||
# await ctx.pipeline.warmup()
|
||||
|
||||
|
||||
@@ -5,14 +5,20 @@ from fastapi import (
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
)
|
||||
from fastapi.responses import FileResponse
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from pydantic import BaseModel, Field
|
||||
from uuid import uuid4
|
||||
from datetime import datetime
|
||||
from fastapi_pagination import Page, paginate
|
||||
from reflector.logger import logger
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.settings import settings
|
||||
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
import av
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
@@ -81,6 +87,44 @@ class Transcript(BaseModel):
|
||||
def topics_dump(self, mode="json"):
    """Return ``model_dump(mode=mode)`` of every item in ``self.topics``."""
    return [
        entry.model_dump(mode=mode)
        for entry in self.topics
    ]
|
||||
|
||||
def convert_audio_to_mp3(self):
    """
    Convert the recorded WAV audio to MP3 and cache the result on disk.

    Does nothing when ``self.audio_mp3_filename`` already exists. Otherwise
    ``self.audio_filename`` is decoded and re-encoded into a temporary file
    that is moved to ``self.audio_mp3_filename`` only once the conversion
    has completed, so a failed conversion never leaves a partial MP3 at the
    final location.

    Any error raised while opening, decoding, encoding or moving the file
    propagates to the caller; the temporary file is removed in that case.
    """
    fn = self.audio_mp3_filename
    if fn.exists():
        return

    logger.info(f"Converting audio to mp3: {self.audio_filename}")

    # Context managers guarantee both containers are closed, even on error.
    with av.open(self.audio_filename.as_posix(), "r") as inp:
        # Reserve a unique temporary name next to the destination: the final
        # move then stays on one filesystem (a rename from the system temp
        # directory can fail across devices). The handle is closed right
        # away; only the name is needed.
        with NamedTemporaryFile(suffix=".mp3", dir=fn.parent, delete=False) as tmp:
            tmp_path = Path(tmp.name)

        try:
            with av.open(tmp_path.as_posix(), "w") as out:
                stream = out.add_stream("mp3")
                for frame in inp.decode(audio=0):
                    # let the encoder assign timestamps itself
                    frame.pts = None
                    for packet in stream.encode(frame):
                        out.mux(packet)
                # flush whatever the encoder still buffers
                for packet in stream.encode(None):
                    out.mux(packet)

            # move temporary file to final location
            tmp_path.replace(fn)
        finally:
            # No-op after a successful move; removes the leftover on failure.
            tmp_path.unlink(missing_ok=True)
|
||||
|
||||
def unlink(self):
    """
    Remove everything stored on disk for this transcript.

    ``self.data_path`` is the directory holding the transcript's audio
    files, so it is removed recursively (``Path.unlink`` raises on a
    directory). A plain file or symlink at that path is unlinked instead,
    and a missing path is silently ignored.
    """
    import shutil

    target = self.data_path
    if target.is_dir() and not target.is_symlink():
        shutil.rmtree(target)
    else:
        target.unlink(missing_ok=True)
|
||||
|
||||
@property
def data_path(self):
    """Location for this transcript's files: ``settings.DATA_DIR`` joined with ``self.id``."""
    base_dir = Path(settings.DATA_DIR)
    return base_dir / self.id
|
||||
|
||||
@property
def audio_filename(self):
    """Path of the WAV file: ``audio.wav`` under ``self.data_path``."""
    wav_name = "audio.wav"
    return self.data_path / wav_name
|
||||
|
||||
@property
def audio_mp3_filename(self):
    """Path of the MP3 file: ``audio.mp3`` under ``self.data_path``."""
    mp3_name = "audio.mp3"
    return self.data_path / mp3_name
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
async def get_all(self) -> list[Transcript]:
|
||||
@@ -112,6 +156,10 @@ class TranscriptController:
|
||||
setattr(transcript, key, value)
|
||||
|
||||
async def remove_by_id(self, transcript_id: str) -> None:
    """
    Delete the transcript identified by *transcript_id*.

    When the transcript is found, its on-disk data is removed through
    ``transcript.unlink()`` and its row is then deleted from the
    ``transcripts`` table. When it is not found, nothing happens.
    """
    transcript = await self.get_by_id(transcript_id)
    if transcript:
        # files first, then the database row
        transcript.unlink()
        delete_query = transcripts.delete().where(transcripts.c.id == transcript_id)
        await database.execute(delete_query)
|
||||
|
||||
@@ -202,8 +250,24 @@ async def transcript_get_audio(transcript_id: str):
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
# TODO: Implement audio generation
|
||||
return HTTPException(status_code=500, detail="Not implemented")
|
||||
if not transcript.audio_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
|
||||
return FileResponse(transcript.audio_filename, media_type="audio/wav")
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(transcript_id: str):
    """
    Serve the MP3 version of a transcript's audio.

    Responds with 404 when the transcript or its WAV recording is missing.
    Otherwise ``convert_audio_to_mp3`` is run in the thread pool so it does
    not block the event loop, and the resulting MP3 file is returned.
    """
    transcript = await transcripts_controller.get_by_id(transcript_id)
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    wav_available = transcript.audio_filename.exists()
    if not wav_available:
        raise HTTPException(status_code=404, detail="Audio not found")

    await run_in_threadpool(transcript.convert_audio_to_mp3)

    mp3_path = transcript.audio_mp3_filename
    return FileResponse(mp3_path, media_type="audio/mp3")
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
||||
@@ -371,4 +435,5 @@ async def transcript_record_webrtc(
|
||||
request,
|
||||
event_callback=handle_rtc_event,
|
||||
event_callback_args=transcript_id,
|
||||
audio_filename=transcript.audio_filename,
|
||||
)
|
||||
|
||||
@@ -70,11 +70,15 @@ async def dummy_llm():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
|
||||
async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm):
|
||||
# goal: start the server, exchange RTC, receive websocket events
|
||||
# because of that, we need to start the server in a thread
|
||||
# to be able to connect with aiortc
|
||||
|
||||
from reflector.settings import settings
|
||||
|
||||
settings.DATA_DIR = Path(tmpdir)
|
||||
|
||||
# start server
|
||||
host = "127.0.0.1"
|
||||
port = 1255
|
||||
@@ -188,3 +192,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ended"
|
||||
|
||||
# check that audio is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/wav"
|
||||
|
||||
# check that audio/mp3 is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/mp3"
|
||||
|
||||
Reference in New Issue
Block a user