server: implement audio waveform download as json

Closes #136
This commit is contained in:
2023-08-23 12:32:43 +02:00
committed by Mathieu Virbel
parent 196aa8454f
commit c21d88b797
3 changed files with 142 additions and 2 deletions

View File

@@ -0,0 +1,73 @@
from pathlib import Path
import av
import numpy as np
def get_audio_waveform(path: Path | str, segments_count: int = 1000) -> list[int]:
    """Compute a coarse peak-amplitude waveform for an audio file.

    The first audio stream of ``path`` is decoded and cut into consecutive
    chunks of equal sample count, sized so that roughly ``segments_count``
    chunks cover the container's duration. Each chunk is reduced to its peak
    absolute sample value and the peaks are rescaled so that the loudest
    chunk maps to 255.

    Args:
        path: Location of the audio file.
        segments_count: Target number of waveform points. The result can be
            slightly longer or shorter: the chunk size is rounded down, the
            decoder does not always deliver the expected number of samples,
            and a trailing, partially filled chunk is dropped.

    Returns:
        One value in the range 0-255 per completed chunk. The list is empty
        when no chunk could be completed and all zeros when the audio is
        silent.
    """
    if isinstance(path, Path):
        path = path.as_posix()

    # Peaks are collected in a growable list rather than a fixed-size buffer:
    # the number of chunks actually produced is only known after decoding.
    peaks: list[float] = []

    # The context manager guarantees the container is closed, even when
    # decoding fails half-way.
    with av.open(path) as container:
        stream = container.streams.get(audio=0)[0]
        duration = container.duration / av.time_base
        chunk_size_secs = duration / segments_count
        # When there is not enough data to fill the chunks, fall back to one
        # sample per chunk so that the loop below always makes progress.
        chunk_size = max(1, int(chunk_size_secs * stream.rate * stream.channels))

        filled = 0
        peak = 0.0
        for frame in container.decode(stream):
            data = frame.to_ndarray().flatten()
            while len(data) > 0:
                # Take only what is still missing to complete the chunk.
                piece = data[: chunk_size - filled]
                data = data[len(piece):]
                # Compare min and max as Python floats: np.absolute() on the
                # most negative integer sample overflows, and float sample
                # formats must not be truncated to integers.
                peak = max(peak, abs(float(piece.min())), abs(float(piece.max())))
                filled += len(piece)
                if filled == chunk_size:
                    # chunk is full, record it and start the next one
                    peaks.append(peak)
                    filled = 0
                    peak = 0.0

    volumes = np.asarray(peaks, dtype=float)
    loudest = volumes.max() if volumes.size else 0.0
    if loudest > 0:
        # normalize the volumes to 0..2**8 - 1
        volumes = volumes * (2**8 - 1) / loudest
    return volumes.astype("uint8").tolist()
if __name__ == "__main__":
    # Command-line entry point: print the waveform of the given audio file.
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument("path", type=Path)
    cli.add_argument("--segments-count", type=int, default=1000)
    options = cli.parse_args()
    print(get_audio_waveform(options.path, options.segments_count))

View File

@@ -1,3 +1,5 @@
import json
import shutil
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
@@ -19,6 +21,7 @@ from pydantic import BaseModel, Field
from reflector.db import database, transcripts from reflector.db import database, transcripts
from reflector.logger import logger from reflector.logger import logger
from reflector.settings import settings from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response from ._range_requests_response import range_requests_response
@@ -40,6 +43,10 @@ def generate_transcript_name():
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class AudioWaveform(BaseModel):
    """Response body of the audio waveform endpoint."""

    # Waveform points, as loaded from the transcript's waveform JSON file.
    data: list[int]
class TranscriptText(BaseModel): class TranscriptText(BaseModel):
text: str text: str
@@ -112,10 +119,24 @@ class Transcript(BaseModel):
out.close() out.close()
# move temporary file to final location # move temporary file to final location
import shutil
shutil.move(tmp.name, fn.as_posix()) shutil.move(tmp.name, fn.as_posix())
def convert_audio_to_waveform(self, segments_count=1000):
    """Generate the cached waveform JSON for this transcript's audio.

    Does nothing when the waveform file already exists. Otherwise the
    waveform is computed from ``self.audio_filename`` and stored at
    ``self.audio_waveform_filename``.

    The JSON is written to a temporary file in the destination directory
    and then renamed into place, so a concurrent reader never observes a
    half-written file.

    Args:
        segments_count: Number of waveform points requested from
            ``get_audio_waveform``.

    Returns:
        The freshly computed waveform, or ``None`` when the file already
        existed and nothing was computed.

    Raises:
        Exception: Any error raised while computing or writing the
            waveform is propagated; the temporary file is removed first.
    """
    fn = self.audio_waveform_filename
    if fn.exists():
        return None

    waveform = get_audio_waveform(
        path=self.audio_filename, segments_count=segments_count
    )

    tmp_path = None
    try:
        # Same directory as the target so the final rename stays on one
        # filesystem and is atomic.
        with NamedTemporaryFile("w", dir=fn.parent, suffix=".tmp", delete=False) as fd:
            tmp_path = Path(fd.name)
            json.dump(waveform, fd)
        tmp_path.replace(fn)
    except Exception:
        # remove the temporary file if anything happens during the write
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        raise
    return waveform
def unlink(self): def unlink(self):
self.data_path.unlink(missing_ok=True) self.data_path.unlink(missing_ok=True)
@@ -131,6 +152,22 @@ class Transcript(BaseModel):
def audio_mp3_filename(self): def audio_mp3_filename(self):
return self.data_path / "audio.mp3" return self.data_path / "audio.mp3"
@property
def audio_waveform_filename(self):
    """Path of the JSON file that stores this transcript's audio waveform."""
    return self.data_path.joinpath("audio.json")
@property
def audio_waveform(self):
    """Load the stored waveform from disk.

    Returns an ``AudioWaveform`` built from the JSON file at
    ``self.audio_waveform_filename``. When the file does not contain valid
    JSON it is deleted and ``None`` is returned instead.
    """
    waveform_path = self.audio_waveform_filename
    try:
        points = json.loads(waveform_path.read_text())
    except json.JSONDecodeError:
        # a corrupted file is useless: drop it
        waveform_path.unlink(missing_ok=True)
        return None
    else:
        return AudioWaveform(data=points)
class TranscriptController: class TranscriptController:
async def get_all(self, user_id: str | None = None) -> list[Transcript]: async def get_all(self, user_id: str | None = None) -> list[Transcript]:
@@ -334,6 +371,24 @@ async def transcript_get_audio_mp3(
) )
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
    transcript_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
    """Return the audio waveform of a transcript as JSON.

    The waveform is generated on first request (in a worker thread) and
    served from the transcript's waveform file afterwards.

    Raises:
        HTTPException: 404 when the transcript or its audio file does not
            exist; 500 when no readable waveform could be produced.
    """
    user_id = user["sub"] if user else None
    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    if not transcript.audio_filename.exists():
        raise HTTPException(status_code=404, detail="Audio not found")

    await run_in_threadpool(transcript.convert_audio_to_waveform)
    waveform = transcript.audio_waveform
    if waveform is None:
        # The stored file was corrupted; reading it through the property
        # removed it, so rebuild it once before giving up.
        await run_in_threadpool(transcript.convert_audio_to_waveform)
        waveform = transcript.audio_waveform
    if waveform is None:
        # Never hand None to the response model: fail explicitly instead.
        raise HTTPException(status_code=500, detail="Waveform not available")
    return waveform
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) @router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
async def transcript_get_topics( async def transcript_get_topics(
transcript_id: str, transcript_id: str,

View File

@@ -93,3 +93,15 @@ async def test_transcript_audio_download_range_with_seek(
assert response.status_code == 206 assert response.status_code == 206
assert response.headers["content-type"] == content_type assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 100-") assert response.headers["content-range"].startswith("bytes 100-")
@pytest.mark.asyncio
async def test_transcript_audio_download_waveform(fake_transcript):
    """The waveform endpoint answers with a JSON body of 1000 data points."""
    from reflector.app import app

    # The context manager closes the client (and its transport) once the
    # request is done instead of leaving it open for the rest of the run.
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")

    assert response.status_code == 200
    assert response.headers["content-type"] == "application/json"
    data = response.json()["data"]
    assert isinstance(data, list)
    assert len(data) == 1000