server: implement wav/mp3 audio download

If configured, the recorded audio is saved to disk as wav.
MP3 conversion happens on request, and the result is cached to disk only if the conversion is successful.

Closes #148
This commit is contained in:
2023-08-15 19:01:39 +02:00
committed by Mathieu Virbel
parent 290b552479
commit a809e5e734
6 changed files with 134 additions and 6 deletions

View File

@@ -1,5 +1,6 @@
from .base import Processor, ThreadedProcessor, Pipeline # noqa: F401
from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_chunker import AudioChunkerProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401

View File

@@ -0,0 +1,35 @@
from reflector.processors.base import Processor
import av
import wave
from pathlib import Path
class AudioFileWriterProcessor(Processor):
    """
    Persist incoming audio frames to a wav file while passing them through.

    The wav file is opened lazily on the first frame, using that frame's
    channel layout, sample width and sample rate.
    """

    INPUT_TYPE = av.AudioFrame
    OUTPUT_TYPE = av.AudioFrame

    def __init__(self, path: Path | str):
        super().__init__()
        self.path = Path(path) if isinstance(path, str) else path
        self.fd = None

    async def _push(self, data: av.AudioFrame):
        if self.fd is None:
            # first frame: create the target directory and open the wav
            # writer with this frame's audio parameters
            self.path.parent.mkdir(parents=True, exist_ok=True)
            fd = wave.open(self.path.as_posix(), "wb")
            fd.setnchannels(len(data.layout.channels))
            fd.setsampwidth(data.format.bytes)
            fd.setframerate(data.sample_rate)
            self.fd = fd
        self.fd.writeframes(data.to_ndarray().tobytes())
        await self.emit(data)

    async def _flush(self):
        # close the writer so the wav header is finalized
        if self.fd is not None:
            self.fd.close()
            self.fd = None

View File

@@ -9,6 +9,9 @@ class Settings(BaseSettings):
# Database
DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
# local data directory (audio for now)
DATA_DIR: str = "./data"
# Whisper
WHISPER_MODEL_SIZE: str = "tiny"
WHISPER_REAL_TIME_MODEL_SIZE: str = "tiny"

View File

@@ -7,12 +7,14 @@ from reflector.logger import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
from json import loads, dumps
from enum import StrEnum
from pathlib import Path
import av
from reflector.processors import (
Pipeline,
AudioChunkerProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
AudioFileWriterProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptFinalSummaryProcessor,
@@ -64,7 +66,11 @@ class PipelineEvent(StrEnum):
async def rtc_offer_base(
params: RtcOffer, request: Request, event_callback=None, event_callback_args=None
params: RtcOffer,
request: Request,
event_callback=None,
event_callback_args=None,
audio_filename: Path | None = None,
):
# build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -151,14 +157,18 @@ async def rtc_offer_base(
# create a context for the whole rtc transaction
# add a customised logger to the context
ctx.pipeline = Pipeline(
processors = []
if audio_filename is not None:
processors += [AudioFileWriterProcessor(path=audio_filename)]
processors += [
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
TranscriptLinerProcessor(),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
)
]
ctx.pipeline = Pipeline(*processors)
# FIXME: warmup is not working well yet
# await ctx.pipeline.warmup()

View File

@@ -5,14 +5,20 @@ from fastapi import (
WebSocket,
WebSocketDisconnect,
)
import shutil
from datetime import datetime
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Optional
from uuid import uuid4

import av
from fastapi.responses import FileResponse
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool

from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.settings import settings

from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
router = APIRouter()
@@ -81,6 +87,44 @@ class Transcript(BaseModel):
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_mp3(self):
    """Convert the recorded wav audio to mp3, caching the result on disk.

    No-op if the mp3 already exists. The encode targets a temporary file
    created in the same directory as the destination so the final rename
    is atomic and never crosses a filesystem boundary (the system temp
    dir may live on a different mount, e.g. tmpfs). On failure the
    partial file is removed so a broken mp3 is never cached.
    """
    fn = self.audio_mp3_filename
    if fn.exists():
        return
    logger.info(f"Converting audio to mp3: {self.audio_filename}")
    inp = av.open(self.audio_filename.as_posix(), "r")
    try:
        # reserve a temporary file next to the final destination
        with NamedTemporaryFile(
            suffix=".mp3", dir=fn.parent.as_posix(), delete=False
        ) as tmp:
            tmp_path = Path(tmp.name)
        try:
            out = av.open(tmp_path.as_posix(), "w")
            try:
                stream = out.add_stream("mp3")
                for frame in inp.decode(audio=0):
                    # let the encoder recompute timestamps
                    frame.pts = None
                    for packet in stream.encode(frame):
                        out.mux(packet)
                # flush any buffered packets out of the encoder
                for packet in stream.encode(None):
                    out.mux(packet)
            finally:
                out.close()
            # move temporary file to final location (cache only on success)
            tmp_path.rename(fn)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise
    finally:
        # the original leaked this input container
        inp.close()
def unlink(self):
    """Delete all on-disk data for this transcript.

    ``data_path`` is a directory (it holds audio.wav / audio.mp3), so the
    previous ``Path.unlink(missing_ok=True)`` raised ``IsADirectoryError``
    whenever audio had actually been recorded — ``missing_ok`` only
    suppresses ``FileNotFoundError``. Remove the whole tree instead and
    ignore a missing directory.
    """
    shutil.rmtree(self.data_path, ignore_errors=True)
@property
def data_path(self) -> Path:
    # Per-transcript storage directory under DATA_DIR, keyed by transcript id.
    return Path(settings.DATA_DIR) / self.id

@property
def audio_filename(self) -> Path:
    # Raw wav recording written during capture (may not exist yet).
    return self.data_path / "audio.wav"

@property
def audio_mp3_filename(self) -> Path:
    # Cached mp3; created lazily by convert_audio_to_mp3().
    return self.data_path / "audio.mp3"
class TranscriptController:
async def get_all(self) -> list[Transcript]:
@@ -112,6 +156,10 @@ class TranscriptController:
setattr(transcript, key, value)
async def remove_by_id(self, transcript_id: str) -> None:
    """Delete a transcript and its on-disk data; no-op if it does not exist."""
    transcript = await self.get_by_id(transcript_id)
    if transcript is None:
        return
    # drop the files first, then the database row
    transcript.unlink()
    await database.execute(
        transcripts.delete().where(transcripts.c.id == transcript_id)
    )
@@ -202,8 +250,24 @@ async def transcript_get_audio(transcript_id: str):
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# TODO: Implement audio generation
return HTTPException(status_code=500, detail="Not implemented")
if not transcript.audio_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return FileResponse(transcript.audio_filename, media_type="audio/wav")
@router.get("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(transcript_id: str):
    """Serve the transcript audio as mp3, converting from wav on demand."""
    transcript = await transcripts_controller.get_by_id(transcript_id)
    if transcript is None:
        raise HTTPException(status_code=404, detail="Transcript not found")
    if not transcript.audio_filename.exists():
        raise HTTPException(status_code=404, detail="Audio not found")
    # conversion is CPU-bound: keep it off the event loop
    await run_in_threadpool(transcript.convert_audio_to_mp3)
    return FileResponse(transcript.audio_mp3_filename, media_type="audio/mp3")
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
@@ -371,4 +435,5 @@ async def transcript_record_webrtc(
request,
event_callback=handle_rtc_event,
event_callback_args=transcript_id,
audio_filename=transcript.audio_filename,
)

View File

@@ -70,11 +70,15 @@ async def dummy_llm():
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
# to be able to connect with aiortc
from reflector.settings import settings
settings.DATA_DIR = Path(tmpdir)
# start server
host = "127.0.0.1"
port = 1255
@@ -188,3 +192,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"
# check that audio is available
resp = await ac.get(f"/transcripts/{tid}/audio")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/wav"
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mp3"