mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-22 05:09:05 +00:00
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
@@ -19,6 +21,7 @@ from pydantic import BaseModel, Field
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
@@ -40,6 +43,10 @@ def generate_transcript_name():
|
||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
|
||||
# Response model returned by the
# ``/transcripts/{transcript_id}/audio/waveform`` endpoint; it wraps the
# contents of the cached waveform JSON file.
class AudioWaveform(BaseModel):
    # Waveform values as loaded from the cached JSON.
    # NOTE(review): presumably one value per segment produced by
    # get_audio_waveform (segments_count entries) — confirm scale/units there.
    data: list[int]
|
||||
|
||||
|
||||
class TranscriptText(BaseModel):
|
||||
text: str
|
||||
|
||||
@@ -112,10 +119,24 @@ class Transcript(BaseModel):
|
||||
out.close()
|
||||
|
||||
# move temporary file to final location
|
||||
import shutil
|
||||
|
||||
shutil.move(tmp.name, fn.as_posix())
|
||||
|
||||
def convert_audio_to_waveform(self, segments_count=1000):
    """Compute and cache the waveform of this transcript's audio, once.

    If the waveform file already exists, nothing is computed and ``None`` is
    returned. Otherwise the waveform is computed with ``get_audio_waveform``
    from ``self.audio_filename``, serialized as JSON to
    ``self.audio_waveform_filename`` and returned.

    The JSON is first written to a temporary file in the destination
    directory and then renamed over the final path. A concurrent caller
    therefore either sees no file or a complete one, never a partially
    written file that a reader would treat as corrupted.

    Args:
        segments_count: forwarded unchanged to ``get_audio_waveform``.

    Returns:
        The freshly computed waveform, or ``None`` when a cached file was
        already present.

    Raises:
        Exception: anything raised while computing or writing the waveform
            is propagated; the temporary file is removed first.
    """
    fn = self.audio_waveform_filename
    if fn.exists():
        return None

    waveform = get_audio_waveform(
        path=self.audio_filename, segments_count=segments_count
    )

    tmp_path = None
    try:
        # same directory as the target so the final rename stays on one
        # filesystem and is atomic
        with NamedTemporaryFile(
            "w", dir=fn.parent, prefix=f"{fn.name}.", suffix=".tmp", delete=False
        ) as tmp:
            tmp_path = Path(tmp.name)
            json.dump(waveform, tmp)
        tmp_path.replace(fn)
    except Exception:
        # remove the temporary file if anything happens during the write
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        raise

    return waveform
|
||||
|
||||
def unlink(self):
    """Remove this transcript's data from disk.

    ``data_path`` is used as a directory elsewhere in this class
    (``self.data_path / "audio.mp3"``), and ``Path.unlink`` raises when
    pointed at a directory, so a directory is removed recursively instead.
    A plain file is still unlinked, and a path that does not exist is
    silently ignored, as before.
    """
    target = self.data_path
    if target.is_dir():
        try:
            shutil.rmtree(target)
        except FileNotFoundError:
            # removed concurrently: same outcome as missing_ok=True
            pass
    else:
        target.unlink(missing_ok=True)
|
||||
|
||||
@@ -131,6 +152,22 @@ class Transcript(BaseModel):
|
||||
def audio_mp3_filename(self):
|
||||
return self.data_path / "audio.mp3"
|
||||
|
||||
@property
def audio_waveform_filename(self):
    """Location of the JSON file that caches the audio waveform."""
    waveform_basename = "audio.json"
    return self.data_path / waveform_basename
|
||||
|
||||
@property
def audio_waveform(self):
    """Load the cached waveform from ``audio_waveform_filename``.

    Returns:
        An ``AudioWaveform`` built from the JSON content, or ``None`` when
        no usable cache is available: either the file does not exist, or it
        contains invalid JSON (in which case the file is deleted).
    """
    fn = self.audio_waveform_filename
    try:
        with open(fn) as fd:
            data = json.load(fd)
    except FileNotFoundError:
        # no cache yet: report it the same way as an unusable cache
        return None
    except json.JSONDecodeError:
        # unlink file if it's corrupted
        fn.unlink(missing_ok=True)
        return None

    return AudioWaveform(data=data)
|
||||
|
||||
|
||||
class TranscriptController:
|
||||
async def get_all(self, user_id: str | None = None) -> list[Transcript]:
|
||||
@@ -334,6 +371,24 @@ async def transcript_get_audio_mp3(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
    transcript_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
    """Return the audio waveform of a transcript, generating it on demand.

    Raises:
        HTTPException: 404 when the transcript or its audio file does not
            exist; 500 when no waveform could be produced.
    """
    user_id = user["sub"] if user else None
    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    if not transcript.audio_filename.exists():
        raise HTTPException(status_code=404, detail="Audio not found")

    # waveform extraction is blocking work: keep it off the event loop
    await run_in_threadpool(transcript.convert_audio_to_waveform)

    waveform = transcript.audio_waveform
    if waveform is None:
        # audio_waveform returns None after deleting a corrupted cache file;
        # regenerate once instead of returning None for a non-optional model
        await run_in_threadpool(transcript.convert_audio_to_waveform)
        waveform = transcript.audio_waveform

    if waveform is None:
        raise HTTPException(
            status_code=500, detail="Unable to generate audio waveform"
        )

    return waveform
|
||||
|
||||
|
||||
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
||||
async def transcript_get_topics(
|
||||
transcript_id: str,
|
||||
|
||||
Reference in New Issue
Block a user