Files
reflector/server/reflector/views/transcripts_audio.py
Mathieu Virbel df909363f5 fix: add missing db_session parameter to transcript audio endpoints
- Add db_session parameter to transcript_get_audio_mp3 endpoint
- Fix audio_mp3_filename path conversion with .as_posix()
- Add null check for audio_waveform before returning
- Update test fixtures to properly pass db_session parameter
- Fix transcript controller calls in test_transcripts_audio_download
2025-09-23 19:05:50 -06:00

117 lines
3.7 KiB
Python

"""
Transcripts audio related endpoints
===================================
"""
from typing import Annotated, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import AudioWaveform, transcripts_controller
from reflector.settings import settings
from reflector.views.transcripts import ALGORITHM
from ._range_requests_response import range_requests_response
# Router collecting the transcript audio endpoints declared in this module
# (its `.get` / `.head` decorators are applied to the handlers below).
# NOTE(review): presumably included into the FastAPI app elsewhere — confirm.
router = APIRouter()
# Connection-specific response headers that a proxy must not relay to its
# own client (RFC 9110 §7.6.1); they describe the upstream hop only.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    }
)


def _proxy_response_headers(upstream_headers: httpx.Headers) -> dict[str, str]:
    """Return the upstream response headers that are safe to relay downstream.

    Drops the standard hop-by-hop headers plus any header explicitly listed
    in the upstream ``Connection`` header, keeping everything else (content
    type, length, range, caching headers, ...) unchanged.
    """
    # Header names listed in Connection are also hop-by-hop for this response.
    listed_in_connection = {
        token.strip().lower()
        for token in upstream_headers.get("connection", "").split(",")
        if token.strip()
    }
    excluded = _HOP_BY_HOP_HEADERS | listed_in_connection
    return {
        name: value
        for name, value in upstream_headers.items()
        if name.lower() not in excluded
    }


@router.get(
    "/transcripts/{transcript_id}/audio/mp3",
    operation_id="transcript_get_audio_mp3",
)
@router.head(
    "/transcripts/{transcript_id}/audio/mp3",
    operation_id="transcript_head_audio_mp3",
)
async def transcript_get_audio_mp3(
    request: Request,
    transcript_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
    session: AsyncSession = Depends(get_session),
    token: str | None = None,
):
    """Serve (GET) or describe (HEAD) the MP3 audio of a transcript.

    The requester id is taken from the authenticated user when present;
    otherwise, if a ``token`` query parameter is given, it is decoded as a JWT
    signed with ``settings.SECRET_KEY`` and its ``sub`` claim is used. That id
    is passed to ``transcripts_controller.get_by_id_for_http`` to load the
    transcript.

    Audio whose ``audio_location`` is ``"storage"`` is proxied through this
    server (to avoid CORS issues), forwarding the client's ``Range`` and
    ``Accept-Encoding`` headers. Other audio is served from the local MP3 file
    with range-request support, as an attachment named
    ``recording_<first id segment>.mp3``.

    Raises:
        HTTPException: 401 if ``token`` is supplied but fails to decode;
            404 if the audio was deleted or the local MP3 file is missing.
    """
    user_id = user["sub"] if user else None
    if not user_id and token:
        unauthorized_exception = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        )
        try:
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
        except jwt.JWTError as err:
            raise unauthorized_exception from err
        # May be None when the token carries no "sub" claim.
        user_id = payload.get("sub")

    transcript = await transcripts_controller.get_by_id_for_http(
        session, transcript_id, user_id=user_id
    )

    # Checked before any backend is consulted: audio deleted for privacy
    # reasons must not be served, not even through the storage proxy.
    if transcript.audio_deleted:
        raise HTTPException(
            status_code=404, detail="Audio unavailable due to privacy settings"
        )

    if transcript.audio_location == "storage":
        # proxy S3 file, to prevent issue with CORS
        url = await transcript.get_audio_url()
        forwarded_headers = {
            name: request.headers[name]
            for name in ("range", "accept-encoding")
            if name in request.headers
        }
        async with httpx.AsyncClient() as client:
            async with client.stream(
                request.method, url, headers=forwarded_headers
            ) as resp:
                # aiter_raw() yields the body exactly as sent upstream (no
                # Content-Encoding decoding), so it stays consistent with the
                # relayed Content-Encoding / Content-Length headers.
                body = b"".join([chunk async for chunk in resp.aiter_raw()])
                status_code = resp.status_code
                response_headers = _proxy_response_headers(resp.headers)
        return Response(
            content=body,
            status_code=status_code,
            headers=response_headers,
        )

    if (
        not hasattr(transcript, "audio_mp3_filename")
        or not transcript.audio_mp3_filename
        or not transcript.audio_mp3_filename.exists()
    ):
        raise HTTPException(status_code=404, detail="Audio file not found")

    truncated_id = str(transcript.id).split("-")[0]
    filename = f"recording_{truncated_id}.mp3"
    return range_requests_response(
        request,
        transcript.audio_mp3_filename.as_posix(),
        content_type="audio/mpeg",
        content_disposition=f"attachment; filename={filename}",
    )
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
    transcript_id: str,
    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
    session: AsyncSession = Depends(get_session),
) -> AudioWaveform:
    """Return the stored waveform data for a transcript's audio.

    The transcript is loaded through
    ``transcripts_controller.get_by_id_for_http`` using the authenticated
    user's ``sub`` (or ``None`` for anonymous requests).

    Raises:
        HTTPException: 404 if the waveform file does not exist, or if
            ``transcript.audio_waveform`` is empty/falsy.
    """
    requester_id = None if not user else user["sub"]
    found = await transcripts_controller.get_by_id_for_http(
        session, transcript_id, user_id=requester_id
    )

    # The waveform file must exist before its contents are read.
    if not found.audio_waveform_filename.exists():
        raise HTTPException(status_code=404, detail="Audio not found")

    waveform = found.audio_waveform
    if waveform:
        return waveform
    raise HTTPException(status_code=404, detail="Audio waveform not found")