server: implement audio waveform download as json

Closes #136
This commit is contained in:
2023-08-23 12:32:43 +02:00
committed by Mathieu Virbel
parent 196aa8454f
commit c21d88b797
3 changed files with 142 additions and 2 deletions

View File

@@ -0,0 +1,73 @@
from pathlib import Path
import av
import numpy as np
def get_audio_waveform(path: Path | str, segments_count: int = 1000) -> list[int]:
if isinstance(path, Path):
path = path.as_posix()
container = av.open(path)
stream = container.streams.get(audio=0)[0]
duration = container.duration / av.time_base
chunk_size_secs = duration / segments_count
chunk_size = int(chunk_size_secs * stream.rate * stream.channels)
if chunk_size == 0:
# there is not enough data to fill the chunks
# so basically we use chunk_size of 1.
chunk_size = 1
# 1.1 is a safety margin as it seems that pyav decode
# does not always return the exact number of chunks
# that we expect.
volumes = np.zeros(int(segments_count * 1.1), dtype=int)
current_chunk_idx = 0
current_chunk_size = 0
current_chunk_volume = 0
count = 0
frames = 0
samples = 0
for frame in container.decode(stream):
data = frame.to_ndarray().flatten()
count += len(data)
frames += 1
samples += frame.samples
while len(data) > 0:
datalen = len(data)
# check how much we need to fill the chunk
chunk_remaining = chunk_size - current_chunk_size
if chunk_remaining > 0:
volume = np.absolute(data[:chunk_remaining]).max()
data = data[chunk_remaining:]
current_chunk_volume = max(current_chunk_volume, volume)
current_chunk_size += min(chunk_remaining, datalen)
if current_chunk_size == chunk_size:
# chunk is full, add it to the volumes
volumes[current_chunk_idx] = current_chunk_volume
current_chunk_idx += 1
current_chunk_size = 0
current_chunk_volume = 0
volumes = volumes[:current_chunk_idx]
# normalize the volumes 0-2**8
volumes = volumes * (2**8 - 1) / volumes.max()
return volumes.astype("uint8").tolist()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("path", type=Path)
parser.add_argument("--segments-count", type=int, default=1000)
args = parser.parse_args()
print(get_audio_waveform(args.path, args.segments_count))

View File

@@ -1,3 +1,5 @@
import json
import shutil
from datetime import datetime
from pathlib import Path
from tempfile import NamedTemporaryFile
@@ -19,6 +21,7 @@ from pydantic import BaseModel, Field
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
@@ -40,6 +43,10 @@ def generate_transcript_name():
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
class AudioWaveform(BaseModel):
data: list[int]
class TranscriptText(BaseModel):
text: str
@@ -112,10 +119,24 @@ class Transcript(BaseModel):
out.close()
# move temporary file to final location
import shutil
shutil.move(tmp.name, fn.as_posix())
def convert_audio_to_waveform(self, segments_count=1000):
fn = self.audio_waveform_filename
if fn.exists():
return
waveform = get_audio_waveform(
path=self.audio_filename, segments_count=segments_count
)
try:
with open(fn, "w") as fd:
json.dump(waveform, fd)
except Exception:
# remove file if anything happen during the write
fn.unlink(missing_ok=True)
raise
return waveform
def unlink(self):
self.data_path.unlink(missing_ok=True)
@@ -131,6 +152,22 @@ class Transcript(BaseModel):
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@property
def audio_waveform_filename(self):
return self.data_path / "audio.json"
@property
def audio_waveform(self):
try:
with open(self.audio_waveform_filename) as fd:
data = json.load(fd)
except json.JSONDecodeError:
# unlink file if it's corrupted
self.audio_waveform_filename.unlink(missing_ok=True)
return None
return AudioWaveform(data=data)
class TranscriptController:
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
@@ -334,6 +371,24 @@ async def transcript_get_audio_mp3(
)
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_waveform)
return transcript.audio_waveform
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
async def transcript_get_topics(
transcript_id: str,

View File

@@ -93,3 +93,15 @@ async def test_transcript_audio_download_range_with_seek(
assert response.status_code == 206
assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 100-")
@pytest.mark.asyncio
async def test_transcript_audio_download_waveform(fake_transcript):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
assert isinstance(response.json()["data"], list)
assert len(response.json()["data"]) == 1000