From c21d88b79796b6e216cc78beb3af258f5a1cc7b6 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 23 Aug 2023 12:32:43 +0200
Subject: [PATCH] server: implement audio waveform download as json

Closes #136
---
 server/reflector/utils/audio_waveform.py      | 82 +++++++++++++++++++
 server/reflector/views/transcripts.py         | 76 ++++++++++++++++-
 .../tests/test_transcripts_audio_download.py  | 12 +++
 3 files changed, 168 insertions(+), 2 deletions(-)
 create mode 100644 server/reflector/utils/audio_waveform.py

diff --git a/server/reflector/utils/audio_waveform.py b/server/reflector/utils/audio_waveform.py
new file mode 100644
index 00000000..c94a4f72
--- /dev/null
+++ b/server/reflector/utils/audio_waveform.py
@@ -0,0 +1,82 @@
+from pathlib import Path
+
+import av
+import numpy as np
+
+
+def get_audio_waveform(path: Path | str, segments_count: int = 1000) -> list[int]:
+    """Return the peak volume of each segment of an audio file, scaled to 0-255.
+
+    The first audio stream of ``path`` is decoded and its samples (all
+    channels together) are cut into chunks covering
+    ``duration / segments_count`` seconds each; the absolute peak of every
+    complete chunk becomes one value of the result.
+
+    A trailing incomplete chunk is dropped, so the result holds about
+    ``segments_count`` values. Files too short to provide one sample per
+    segment use chunks of a single sample and may return more values.
+
+    Raises ``ValueError`` if ``segments_count`` is not a positive number.
+    """
+    if segments_count <= 0:
+        raise ValueError("segments_count must be greater than 0")
+    if isinstance(path, Path):
+        path = path.as_posix()
+
+    peaks: list[float] = []
+    current_chunk_size = 0
+    current_chunk_volume = 0.0
+
+    # the context manager closes the container even if decoding fails
+    with av.open(path) as container:
+        stream = container.streams.get(audio=0)[0]
+        duration = container.duration / av.time_base
+
+        chunk_size_secs = duration / segments_count
+        chunk_size = int(chunk_size_secs * stream.rate * stream.channels)
+        if chunk_size == 0:
+            # there is not enough data to fill the chunks
+            # so basically we use chunk_size of 1.
+            chunk_size = 1
+
+        for frame in container.decode(stream):
+            # work on floats: abs() of the lowest integer sample overflows
+            # in its own dtype, and float samples must not be truncated
+            data = np.absolute(frame.to_ndarray().astype(float).ravel())
+
+            while len(data) > 0:
+                # take only what is missing to fill the current chunk
+                piece = data[: chunk_size - current_chunk_size]
+                data = data[len(piece) :]
+                current_chunk_volume = max(current_chunk_volume, float(piece.max()))
+                current_chunk_size += len(piece)
+
+                if current_chunk_size == chunk_size:
+                    # chunk is full, add it to the volumes
+                    peaks.append(current_chunk_volume)
+                    current_chunk_size = 0
+                    current_chunk_volume = 0.0
+
+    # pyav does not always decode the exact number of chunks we expect,
+    # so the peaks are collected in a list instead of a fixed-size array
+    volumes = np.asarray(peaks, dtype=float)
+    loudest = volumes.max() if volumes.size else 0.0
+    if not loudest > 0:
+        # no complete chunk or pure silence: nothing to normalize
+        return [0] * len(peaks)
+
+    # normalize the volumes 0-2**8
+    volumes = volumes * (2**8 - 1) / loudest
+
+    return volumes.astype("uint8").tolist()
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("path", type=Path)
+    parser.add_argument("--segments-count", type=int, default=1000)
+    args = parser.parse_args()
+
+    print(get_audio_waveform(args.path, args.segments_count))
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index f18cba6b..deaa0567 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -1,3 +1,5 @@
+import json
+import shutil
 from datetime import datetime
 from pathlib import Path
 from tempfile import NamedTemporaryFile
@@ -19,6 +21,7 @@ from pydantic import BaseModel, Field
 from reflector.db import database, transcripts
 from reflector.logger import logger
 from reflector.settings import settings
+from reflector.utils.audio_waveform import get_audio_waveform
 from starlette.concurrency import run_in_threadpool
 
 from ._range_requests_response import range_requests_response
@@ -40,6 +43,10 @@ def generate_transcript_name():
     return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
 
 
+class AudioWaveform(BaseModel):
+    data: list[int]
+
+
 class TranscriptText(BaseModel):
     text: str
 
@@ -112,10 +119,37 @@ class Transcript(BaseModel):
             out.close()
 
         # move temporary file to final location
-        import shutil
-
         shutil.move(tmp.name, fn.as_posix())
 
+    def convert_audio_to_waveform(self, segments_count=1000):
+        """Compute the audio waveform and cache it as JSON on disk.
+
+        Returns None without doing anything when the waveform file
+        already exists, otherwise returns the computed waveform.
+        """
+        fn = self.audio_waveform_filename
+        if fn.exists():
+            return
+        waveform = get_audio_waveform(
+            path=self.audio_filename, segments_count=segments_count
+        )
+        # write to a temporary file first and rename it afterwards, so
+        # that a concurrent reader never sees a partially written file
+        tmp = None
+        try:
+            with NamedTemporaryFile(
+                "w", dir=fn.parent, suffix=".tmp", delete=False
+            ) as fd:
+                tmp = Path(fd.name)
+                json.dump(waveform, fd)
+            tmp.replace(fn)
+        except Exception:
+            # remove the temporary file if anything happens during the write
+            if tmp is not None:
+                tmp.unlink(missing_ok=True)
+            raise
+        return waveform
+
     def unlink(self):
         self.data_path.unlink(missing_ok=True)
 
@@ -131,6 +165,26 @@ class Transcript(BaseModel):
     def audio_mp3_filename(self):
         return self.data_path / "audio.mp3"
 
+    @property
+    def audio_waveform_filename(self):
+        return self.data_path / "audio.json"
+
+    @property
+    def audio_waveform(self):
+        """Cached waveform as AudioWaveform, or None when not available."""
+        try:
+            with open(self.audio_waveform_filename) as fd:
+                data = json.load(fd)
+        except FileNotFoundError:
+            # the waveform has not been generated yet
+            return None
+        except json.JSONDecodeError:
+            # unlink file if it's corrupted
+            self.audio_waveform_filename.unlink(missing_ok=True)
+            return None
+
+        return AudioWaveform(data=data)
+
 
 class TranscriptController:
     async def get_all(self, user_id: str | None = None) -> list[Transcript]:
@@ -334,6 +388,24 @@ async def transcript_get_audio_mp3(
     )
 
 
+@router.get("/transcripts/{transcript_id}/audio/waveform")
+async def transcript_get_audio_waveform(
+    transcript_id: str,
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> AudioWaveform:
+    user_id = user["sub"] if user else None
+    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
+    if not transcript:
+        raise HTTPException(status_code=404, detail="Transcript not found")
+
+    if not transcript.audio_filename.exists():
+        raise HTTPException(status_code=404, detail="Audio not found")
+
+    await run_in_threadpool(transcript.convert_audio_to_waveform)
+
+    return transcript.audio_waveform
+
+
 @router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
 async def transcript_get_topics(
     transcript_id: str,
diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py
index 204ed90e..2af34ee7 100644
--- a/server/tests/test_transcripts_audio_download.py
+++ b/server/tests/test_transcripts_audio_download.py
@@ -93,3 +93,15 @@ async def test_transcript_audio_download_range_with_seek(
     assert response.status_code == 206
     assert response.headers["content-type"] == content_type
     assert response.headers["content-range"].startswith("bytes 100-")
+
+
+@pytest.mark.asyncio
+async def test_transcript_audio_download_waveform(fake_transcript):
+    from reflector.app import app
+
+    ac = AsyncClient(app=app, base_url="http://test/v1")
+    response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "application/json"
+    assert isinstance(response.json()["data"], list)
+    assert len(response.json()["data"]) == 1000