mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
73
server/reflector/utils/audio_waveform.py
Normal file
73
server/reflector/utils/audio_waveform.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import av
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_waveform(path: Path | str, segments_count: int = 1000) -> list[int]:
|
||||||
|
if isinstance(path, Path):
|
||||||
|
path = path.as_posix()
|
||||||
|
|
||||||
|
container = av.open(path)
|
||||||
|
stream = container.streams.get(audio=0)[0]
|
||||||
|
duration = container.duration / av.time_base
|
||||||
|
|
||||||
|
chunk_size_secs = duration / segments_count
|
||||||
|
chunk_size = int(chunk_size_secs * stream.rate * stream.channels)
|
||||||
|
if chunk_size == 0:
|
||||||
|
# there is not enough data to fill the chunks
|
||||||
|
# so basically we use chunk_size of 1.
|
||||||
|
chunk_size = 1
|
||||||
|
|
||||||
|
# 1.1 is a safety margin as it seems that pyav decode
|
||||||
|
# does not always return the exact number of chunks
|
||||||
|
# that we expect.
|
||||||
|
volumes = np.zeros(int(segments_count * 1.1), dtype=int)
|
||||||
|
current_chunk_idx = 0
|
||||||
|
current_chunk_size = 0
|
||||||
|
current_chunk_volume = 0
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
frames = 0
|
||||||
|
samples = 0
|
||||||
|
for frame in container.decode(stream):
|
||||||
|
data = frame.to_ndarray().flatten()
|
||||||
|
count += len(data)
|
||||||
|
frames += 1
|
||||||
|
samples += frame.samples
|
||||||
|
|
||||||
|
while len(data) > 0:
|
||||||
|
datalen = len(data)
|
||||||
|
|
||||||
|
# check how much we need to fill the chunk
|
||||||
|
chunk_remaining = chunk_size - current_chunk_size
|
||||||
|
if chunk_remaining > 0:
|
||||||
|
volume = np.absolute(data[:chunk_remaining]).max()
|
||||||
|
data = data[chunk_remaining:]
|
||||||
|
current_chunk_volume = max(current_chunk_volume, volume)
|
||||||
|
current_chunk_size += min(chunk_remaining, datalen)
|
||||||
|
|
||||||
|
if current_chunk_size == chunk_size:
|
||||||
|
# chunk is full, add it to the volumes
|
||||||
|
volumes[current_chunk_idx] = current_chunk_volume
|
||||||
|
current_chunk_idx += 1
|
||||||
|
current_chunk_size = 0
|
||||||
|
current_chunk_volume = 0
|
||||||
|
|
||||||
|
volumes = volumes[:current_chunk_idx]
|
||||||
|
|
||||||
|
# normalize the volumes 0-2**8
|
||||||
|
volumes = volumes * (2**8 - 1) / volumes.max()
|
||||||
|
|
||||||
|
return volumes.astype("uint8").tolist()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("path", type=Path)
|
||||||
|
parser.add_argument("--segments-count", type=int, default=1000)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(get_audio_waveform(args.path, args.segments_count))
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import json
|
||||||
|
import shutil
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
@@ -19,6 +21,7 @@ from pydantic import BaseModel, Field
|
|||||||
from reflector.db import database, transcripts
|
from reflector.db import database, transcripts
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
from reflector.utils.audio_waveform import get_audio_waveform
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
|
||||||
from ._range_requests_response import range_requests_response
|
from ._range_requests_response import range_requests_response
|
||||||
@@ -40,6 +43,10 @@ def generate_transcript_name():
|
|||||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
|
|
||||||
|
|
||||||
|
class AudioWaveform(BaseModel):
|
||||||
|
data: list[int]
|
||||||
|
|
||||||
|
|
||||||
class TranscriptText(BaseModel):
|
class TranscriptText(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
@@ -112,10 +119,24 @@ class Transcript(BaseModel):
|
|||||||
out.close()
|
out.close()
|
||||||
|
|
||||||
# move temporary file to final location
|
# move temporary file to final location
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.move(tmp.name, fn.as_posix())
|
shutil.move(tmp.name, fn.as_posix())
|
||||||
|
|
||||||
|
def convert_audio_to_waveform(self, segments_count=1000):
|
||||||
|
fn = self.audio_waveform_filename
|
||||||
|
if fn.exists():
|
||||||
|
return
|
||||||
|
waveform = get_audio_waveform(
|
||||||
|
path=self.audio_filename, segments_count=segments_count
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with open(fn, "w") as fd:
|
||||||
|
json.dump(waveform, fd)
|
||||||
|
except Exception:
|
||||||
|
# remove file if anything happen during the write
|
||||||
|
fn.unlink(missing_ok=True)
|
||||||
|
raise
|
||||||
|
return waveform
|
||||||
|
|
||||||
def unlink(self):
|
def unlink(self):
|
||||||
self.data_path.unlink(missing_ok=True)
|
self.data_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
@@ -131,6 +152,22 @@ class Transcript(BaseModel):
|
|||||||
def audio_mp3_filename(self):
|
def audio_mp3_filename(self):
|
||||||
return self.data_path / "audio.mp3"
|
return self.data_path / "audio.mp3"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_waveform_filename(self):
|
||||||
|
return self.data_path / "audio.json"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_waveform(self):
|
||||||
|
try:
|
||||||
|
with open(self.audio_waveform_filename) as fd:
|
||||||
|
data = json.load(fd)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# unlink file if it's corrupted
|
||||||
|
self.audio_waveform_filename.unlink(missing_ok=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return AudioWaveform(data=data)
|
||||||
|
|
||||||
|
|
||||||
class TranscriptController:
|
class TranscriptController:
|
||||||
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
|
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
|
||||||
@@ -334,6 +371,24 @@ async def transcript_get_audio_mp3(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/transcripts/{transcript_id}/audio/waveform")
|
||||||
|
async def transcript_get_audio_waveform(
|
||||||
|
transcript_id: str,
|
||||||
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
) -> AudioWaveform:
|
||||||
|
user_id = user["sub"] if user else None
|
||||||
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||||
|
if not transcript:
|
||||||
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
|
if not transcript.audio_filename.exists():
|
||||||
|
raise HTTPException(status_code=404, detail="Audio not found")
|
||||||
|
|
||||||
|
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
||||||
|
|
||||||
|
return transcript.audio_waveform
|
||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
||||||
async def transcript_get_topics(
|
async def transcript_get_topics(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
|
|||||||
@@ -93,3 +93,15 @@ async def test_transcript_audio_download_range_with_seek(
|
|||||||
assert response.status_code == 206
|
assert response.status_code == 206
|
||||||
assert response.headers["content-type"] == content_type
|
assert response.headers["content-type"] == content_type
|
||||||
assert response.headers["content-range"].startswith("bytes 100-")
|
assert response.headers["content-range"].startswith("bytes 100-")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcript_audio_download_waveform(fake_transcript):
|
||||||
|
from reflector.app import app
|
||||||
|
|
||||||
|
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||||
|
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["content-type"] == "application/json"
|
||||||
|
assert isinstance(response.json()["data"], list)
|
||||||
|
assert len(response.json()["data"]) == 1000
|
||||||
|
|||||||
Reference in New Issue
Block a user