mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
73
server/reflector/utils/audio_waveform.py
Normal file
73
server/reflector/utils/audio_waveform.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_audio_waveform(path: Path | str, segments_count: int = 1000) -> list[int]:
|
||||
if isinstance(path, Path):
|
||||
path = path.as_posix()
|
||||
|
||||
container = av.open(path)
|
||||
stream = container.streams.get(audio=0)[0]
|
||||
duration = container.duration / av.time_base
|
||||
|
||||
chunk_size_secs = duration / segments_count
|
||||
chunk_size = int(chunk_size_secs * stream.rate * stream.channels)
|
||||
if chunk_size == 0:
|
||||
# there is not enough data to fill the chunks
|
||||
# so basically we use chunk_size of 1.
|
||||
chunk_size = 1
|
||||
|
||||
# 1.1 is a safety margin as it seems that pyav decode
|
||||
# does not always return the exact number of chunks
|
||||
# that we expect.
|
||||
volumes = np.zeros(int(segments_count * 1.1), dtype=int)
|
||||
current_chunk_idx = 0
|
||||
current_chunk_size = 0
|
||||
current_chunk_volume = 0
|
||||
|
||||
count = 0
|
||||
frames = 0
|
||||
samples = 0
|
||||
for frame in container.decode(stream):
|
||||
data = frame.to_ndarray().flatten()
|
||||
count += len(data)
|
||||
frames += 1
|
||||
samples += frame.samples
|
||||
|
||||
while len(data) > 0:
|
||||
datalen = len(data)
|
||||
|
||||
# check how much we need to fill the chunk
|
||||
chunk_remaining = chunk_size - current_chunk_size
|
||||
if chunk_remaining > 0:
|
||||
volume = np.absolute(data[:chunk_remaining]).max()
|
||||
data = data[chunk_remaining:]
|
||||
current_chunk_volume = max(current_chunk_volume, volume)
|
||||
current_chunk_size += min(chunk_remaining, datalen)
|
||||
|
||||
if current_chunk_size == chunk_size:
|
||||
# chunk is full, add it to the volumes
|
||||
volumes[current_chunk_idx] = current_chunk_volume
|
||||
current_chunk_idx += 1
|
||||
current_chunk_size = 0
|
||||
current_chunk_volume = 0
|
||||
|
||||
volumes = volumes[:current_chunk_idx]
|
||||
|
||||
# normalize the volumes 0-2**8
|
||||
volumes = volumes * (2**8 - 1) / volumes.max()
|
||||
|
||||
return volumes.astype("uint8").tolist()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("path", type=Path)
|
||||
parser.add_argument("--segments-count", type=int, default=1000)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(get_audio_waveform(args.path, args.segments_count))
|
||||
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
@@ -19,6 +21,7 @@ from pydantic import BaseModel, Field
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
@@ -40,6 +43,10 @@ def generate_transcript_name():
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
|
||||
class AudioWaveform(BaseModel):
|
||||
data: list[int]
|
||||
|
||||
|
||||
class TranscriptText(BaseModel):
|
||||
text: str
|
||||
|
||||
@@ -112,10 +119,24 @@ class Transcript(BaseModel):
|
||||
out.close()
|
||||
|
||||
# move temporary file to final location
|
||||
import shutil
|
||||
|
||||
shutil.move(tmp.name, fn.as_posix())
|
||||
|
||||
def convert_audio_to_waveform(self, segments_count=1000):
|
||||
fn = self.audio_waveform_filename
|
||||
if fn.exists():
|
||||
return
|
||||
waveform = get_audio_waveform(
|
||||
path=self.audio_filename, segments_count=segments_count
|
||||
)
|
||||
try:
|
||||
with open(fn, "w") as fd:
|
||||
json.dump(waveform, fd)
|
||||
except Exception:
|
||||
# remove file if anything happen during the write
|
||||
fn.unlink(missing_ok=True)
|
||||
raise
|
||||
return waveform
|
||||
|
||||
def unlink(self):
|
||||
self.data_path.unlink(missing_ok=True)
|
||||
|
||||
@@ -131,6 +152,22 @@ class Transcript(BaseModel):
|
||||
def audio_mp3_filename(self):
|
||||
return self.data_path / "audio.mp3"
|
||||
|
||||
@property
|
||||
def audio_waveform_filename(self):
|
||||
return self.data_path / "audio.json"
|
||||
|
||||
@property
|
||||
def audio_waveform(self):
|
||||
try:
|
||||
with open(self.audio_waveform_filename) as fd:
|
||||
data = json.load(fd)
|
||||
except json.JSONDecodeError:
|
||||
# unlink file if it's corrupted
|
||||
self.audio_waveform_filename.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
return AudioWaveform(data=data)
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
|
||||
@@ -334,6 +371,24 @@ async def transcript_get_audio_mp3(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio/waveform")
|
||||
async def transcript_get_audio_waveform(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
) -> AudioWaveform:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
if not transcript.audio_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
|
||||
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
||||
|
||||
return transcript.audio_waveform
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
||||
async def transcript_get_topics(
|
||||
transcript_id: str,
|
||||
|
||||
@@ -93,3 +93,15 @@ async def test_transcript_audio_download_range_with_seek(
|
||||
assert response.status_code == 206
|
||||
assert response.headers["content-type"] == content_type
|
||||
assert response.headers["content-range"].startswith("bytes 100-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_audio_download_waveform(fake_transcript):
|
||||
from reflector.app import app
|
||||
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "application/json"
|
||||
assert isinstance(response.json()["data"], list)
|
||||
assert len(response.json()["data"]) == 1000
|
||||
|
||||
Reference in New Issue
Block a user