server: implement wav/mp3 audio download

If set, will save audio transcription to disk.
MP3 conversion is done on request, and the result is cached to disk only if the conversion is successful.

Closes #148
This commit is contained in:
2023-08-15 19:01:39 +02:00
committed by Mathieu Virbel
parent 290b552479
commit a809e5e734
6 changed files with 134 additions and 6 deletions

View File

@@ -1,5 +1,6 @@
from .base import Processor, ThreadedProcessor, Pipeline # noqa: F401 from .base import Processor, ThreadedProcessor, Pipeline # noqa: F401
from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary # noqa: F401 from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_chunker import AudioChunkerProcessor # noqa: F401 from .audio_chunker import AudioChunkerProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401 from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401 from .audio_transcript import AudioTranscriptProcessor # noqa: F401

View File

@@ -0,0 +1,35 @@
from reflector.processors.base import Processor
import av
import wave
from pathlib import Path
class AudioFileWriterProcessor(Processor):
    """
    Pass-through processor that records incoming audio frames to a WAV file.

    The output file is opened lazily on the first frame, since the channel
    count, sample width and sample rate are read from that frame. Every
    frame is re-emitted unchanged once it has been written.
    """

    INPUT_TYPE = av.AudioFrame
    OUTPUT_TYPE = av.AudioFrame

    def __init__(self, path: Path | str):
        super().__init__()
        self.path = Path(path) if isinstance(path, str) else path
        # wave writer handle; stays None until the first frame arrives
        self.fd = None

    def _open_wave_file(self, frame: av.AudioFrame):
        """Create the parent directory and open the WAV writer for `frame`."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.fd = writer = wave.open(self.path.as_posix(), "wb")
        writer.setnchannels(len(frame.layout.channels))
        writer.setsampwidth(frame.format.bytes)
        writer.setframerate(frame.sample_rate)

    async def _push(self, data: av.AudioFrame):
        if self.fd is None:
            self._open_wave_file(data)
        # NOTE(review): the raw ndarray bytes are written as-is, which
        # presumably expects packed (interleaved) samples - TODO confirm
        # against the frames produced upstream.
        samples = data.to_ndarray()
        self.fd.writeframes(samples.tobytes())
        await self.emit(data)

    async def _flush(self):
        if not self.fd:
            return
        self.fd.close()
        self.fd = None

View File

@@ -9,6 +9,9 @@ class Settings(BaseSettings):
# Database # Database
DATABASE_URL: str = "sqlite:///./reflector.sqlite3" DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
# local data directory (audio for now)
DATA_DIR: str = "./data"
# Whisper # Whisper
WHISPER_MODEL_SIZE: str = "tiny" WHISPER_MODEL_SIZE: str = "tiny"
WHISPER_REAL_TIME_MODEL_SIZE: str = "tiny" WHISPER_REAL_TIME_MODEL_SIZE: str = "tiny"

View File

@@ -7,12 +7,14 @@ from reflector.logger import logger
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
from json import loads, dumps from json import loads, dumps
from enum import StrEnum from enum import StrEnum
from pathlib import Path
import av import av
from reflector.processors import ( from reflector.processors import (
Pipeline, Pipeline,
AudioChunkerProcessor, AudioChunkerProcessor,
AudioMergeProcessor, AudioMergeProcessor,
AudioTranscriptAutoProcessor, AudioTranscriptAutoProcessor,
AudioFileWriterProcessor,
TranscriptLinerProcessor, TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor, TranscriptTopicDetectorProcessor,
TranscriptFinalSummaryProcessor, TranscriptFinalSummaryProcessor,
@@ -64,7 +66,11 @@ class PipelineEvent(StrEnum):
async def rtc_offer_base( async def rtc_offer_base(
params: RtcOffer, request: Request, event_callback=None, event_callback_args=None params: RtcOffer,
request: Request,
event_callback=None,
event_callback_args=None,
audio_filename: Path | None = None,
): ):
# build an rtc session # build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type) offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -151,14 +157,18 @@ async def rtc_offer_base(
# create a context for the whole rtc transaction # create a context for the whole rtc transaction
# add a customised logger to the context # add a customised logger to the context
ctx.pipeline = Pipeline( processors = []
if audio_filename is not None:
processors += [AudioFileWriterProcessor(path=audio_filename)]
processors += [
AudioChunkerProcessor(), AudioChunkerProcessor(),
AudioMergeProcessor(), AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript), AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
TranscriptLinerProcessor(), TranscriptLinerProcessor(),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic), TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary), TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
) ]
ctx.pipeline = Pipeline(*processors)
# FIXME: warmup is not working well yet # FIXME: warmup is not working well yet
# await ctx.pipeline.warmup() # await ctx.pipeline.warmup()

View File

@@ -5,14 +5,20 @@ from fastapi import (
WebSocket, WebSocket,
WebSocketDisconnect, WebSocketDisconnect,
) )
from fastapi.responses import FileResponse
from starlette.concurrency import run_in_threadpool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from uuid import uuid4 from uuid import uuid4
from datetime import datetime from datetime import datetime
from fastapi_pagination import Page, paginate from fastapi_pagination import Page, paginate
from reflector.logger import logger from reflector.logger import logger
from reflector.db import database, transcripts from reflector.db import database, transcripts
from reflector.settings import settings
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
from typing import Optional from typing import Optional
from pathlib import Path
from tempfile import NamedTemporaryFile
import av
router = APIRouter() router = APIRouter()
@@ -81,6 +87,44 @@ class Transcript(BaseModel):
def topics_dump(self, mode="json"): def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics] return [topic.model_dump(mode=mode) for topic in self.topics]
def convert_audio_to_mp3(self):
    """
    Convert the recorded WAV audio to MP3 and cache the result on disk.

    Does nothing if the MP3 file already exists. The encoding is written
    to a temporary file first and only moved to `audio_mp3_filename` once
    it has completed, so a failed conversion never leaves a partial MP3 at
    the final location.

    Raises whatever `av` or the filesystem raises if the source audio
    cannot be read or the output cannot be written.
    """
    fn = self.audio_mp3_filename
    if fn.exists():
        return

    logger.info(f"Converting audio to mp3: {self.audio_filename}")

    # Reserve a unique temporary name in the destination directory: a
    # rename is only possible within one filesystem, and the system temp
    # directory may live on a different one than DATA_DIR. The handle is
    # closed right away so the encoder (and the final move) operate on a
    # file that is not held open here.
    with NamedTemporaryFile(suffix=".mp3", dir=fn.parent, delete=False) as tmp:
        tmp_path = Path(tmp.name)

    try:
        # context managers guarantee both containers are closed, even if
        # decoding or encoding fails midway
        with av.open(self.audio_filename.as_posix(), "r") as inp:
            with av.open(tmp.name, "w") as out:
                stream = out.add_stream("mp3")
                for frame in inp.decode(audio=0):
                    # drop the source timestamps before re-encoding
                    frame.pts = None
                    for packet in stream.encode(frame):
                        out.mux(packet)
                # flush the packets still buffered in the encoder
                for packet in stream.encode(None):
                    out.mux(packet)

        # move temporary file to final location; replace() also succeeds
        # if a concurrent request produced the file in the meantime
        tmp_path.replace(fn)
    finally:
        # no-op after a successful move, cleanup after a failure
        tmp_path.unlink(missing_ok=True)
def unlink(self):
    """
    Delete the on-disk data of this transcript.

    `data_path` is a directory (it holds `audio.wav` / `audio.mp3`), so it
    is removed recursively; `Path.unlink()` only works on files and would
    raise for it. A directory that does not exist is silently ignored.
    """
    import shutil

    try:
        shutil.rmtree(self.data_path)
    except FileNotFoundError:
        # nothing was ever recorded for this transcript
        pass
@property
def data_path(self):
    """Directory under `settings.DATA_DIR` named after this transcript's id."""
    base_dir = Path(settings.DATA_DIR)
    return base_dir.joinpath(self.id)
@property
def audio_filename(self):
    """Path of the WAV recording inside this transcript's data directory."""
    return self.data_path.joinpath("audio.wav")
@property
def audio_mp3_filename(self):
    """Path of the cached MP3 inside this transcript's data directory."""
    return self.data_path.joinpath("audio.mp3")
class TranscriptController: class TranscriptController:
async def get_all(self) -> list[Transcript]: async def get_all(self) -> list[Transcript]:
@@ -112,6 +156,10 @@ class TranscriptController:
setattr(transcript, key, value) setattr(transcript, key, value)
async def remove_by_id(self, transcript_id: str) -> None: async def remove_by_id(self, transcript_id: str) -> None:
transcript = await self.get_by_id(transcript_id)
if not transcript:
return
transcript.unlink()
query = transcripts.delete().where(transcripts.c.id == transcript_id) query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query) await database.execute(query)
@@ -202,8 +250,24 @@ async def transcript_get_audio(transcript_id: str):
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
# TODO: Implement audio generation if not transcript.audio_filename.exists():
return HTTPException(status_code=500, detail="Not implemented") raise HTTPException(status_code=404, detail="Audio not found")
return FileResponse(transcript.audio_filename, media_type="audio/wav")
@router.get("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(transcript_id: str):
    """
    Serve the MP3 version of a transcript's audio.

    Responds with 404 if the transcript is unknown or no WAV audio exists
    for it. The conversion is run in the threadpool before the file is
    returned.
    """
    found = await transcripts_controller.get_by_id(transcript_id)
    if not found:
        raise HTTPException(status_code=404, detail="Transcript not found")

    wav_path = found.audio_filename
    if not wav_path.exists():
        raise HTTPException(status_code=404, detail="Audio not found")

    # the conversion is synchronous, so it is handed to the threadpool
    await run_in_threadpool(found.convert_audio_to_mp3)

    mp3_path = found.audio_mp3_filename
    return FileResponse(mp3_path, media_type="audio/mp3")
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) @router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
@@ -371,4 +435,5 @@ async def transcript_record_webrtc(
request, request,
event_callback=handle_rtc_event, event_callback=handle_rtc_event,
event_callback_args=transcript_id, event_callback_args=transcript_id,
audio_filename=transcript.audio_filename,
) )

View File

@@ -70,11 +70,15 @@ async def dummy_llm():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm): async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm):
# goal: start the server, exchange RTC, receive websocket events # goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread # because of that, we need to start the server in a thread
# to be able to connect with aiortc # to be able to connect with aiortc
from reflector.settings import settings
settings.DATA_DIR = Path(tmpdir)
# start server # start server
host = "127.0.0.1" host = "127.0.0.1"
port = 1255 port = 1255
@@ -188,3 +192,13 @@ async def test_transcript_rtc_and_websocket(dummy_transcript, dummy_llm):
resp = await ac.get(f"/transcripts/{tid}") resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["status"] == "ended" assert resp.json()["status"] == "ended"
# check that audio is available
resp = await ac.get(f"/transcripts/{tid}/audio")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/wav"
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mp3"