diff --git a/server/reflector/views/_range_requests_response.py b/server/reflector/views/_range_requests_response.py new file mode 100644 index 00000000..1a584a3c --- /dev/null +++ b/server/reflector/views/_range_requests_response.py @@ -0,0 +1,72 @@ +import os +from typing import BinaryIO + +from fastapi import HTTPException, Request, status +from fastapi.responses import StreamingResponse + + +def send_bytes_range_requests( + file_obj: BinaryIO, start: int, end: int, chunk_size: int = 10_000 +): + """Send a file in chunks using Range Requests specification RFC7233 + + `start` and `end` parameters are inclusive due to specification + """ + with file_obj as f: + f.seek(start) + while (pos := f.tell()) <= end: + read_size = min(chunk_size, end + 1 - pos) + yield f.read(read_size) + + +def _get_range_header(range_header: str, file_size: int) -> tuple[int, int]: + def _invalid_range(): + return HTTPException( + status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, + detail=f"Invalid request range (Range:{range_header!r})", + ) + + try: + h = range_header.replace("bytes=", "").split("-") + start = int(h[0]) if h[0] != "" else 0 + end = int(h[1]) if h[1] != "" else file_size - 1 + except ValueError: + raise _invalid_range() + + if start > end or start < 0 or end > file_size - 1: + raise _invalid_range() + return start, end + + +def range_requests_response(request: Request, file_path: str, content_type: str): + """Returns StreamingResponse using Range Requests of a given file""" + + file_size = os.stat(file_path).st_size + range_header = request.headers.get("range") + + headers = { + "content-type": content_type, + "accept-ranges": "bytes", + "content-encoding": "identity", + "content-length": str(file_size), + "access-control-expose-headers": ( + "content-type, accept-ranges, content-length, " + "content-range, content-encoding" + ), + } + start = 0 + end = file_size - 1 + status_code = status.HTTP_200_OK + + if range_header is not None: + start, end = _get_range_header(range_header, file_size) + size = end - start + 1 + headers["content-length"] = str(size) + headers["content-range"] = f"bytes {start}-{end}/{file_size}" + status_code = status.HTTP_206_PARTIAL_CONTENT + + return StreamingResponse( + send_bytes_range_requests(open(file_path, mode="rb"), start, end), + headers=headers, + status_code=status_code, + ) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index e542973a..7321de37 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -14,7 +14,6 @@ from fastapi import ( WebSocket, WebSocketDisconnect, ) -from fastapi.responses import FileResponse from fastapi_pagination import Page, paginate from pydantic import BaseModel, Field from reflector.db import database, transcripts @@ -22,6 +21,7 @@ from reflector.logger import logger from reflector.settings import settings from starlette.concurrency import run_in_threadpool +from ._range_requests_response import range_requests_response from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base router = APIRouter() @@ -281,6 +281,7 @@ async def transcript_delete( @router.get("/transcripts/{transcript_id}/audio") async def transcript_get_audio( + request: Request, transcript_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): @@ -292,11 +293,16 @@ async def transcript_get_audio( if not transcript.audio_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") - return FileResponse(transcript.audio_filename, media_type="audio/wav") + return range_requests_response( + request, + transcript.audio_filename, + content_type="audio/wav", + ) @router.get("/transcripts/{transcript_id}/audio/mp3") async def transcript_get_audio_mp3( + request: Request, transcript_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): @@ -310,7 +316,11 @@ async def transcript_get_audio_mp3( await run_in_threadpool(transcript.convert_audio_to_mp3) - return FileResponse(transcript.audio_mp3_filename, media_type="audio/mp3") + return range_requests_response( + request, + transcript.audio_mp3_filename, + content_type="audio/mp3", + ) @router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py new file mode 100644 index 00000000..204ed90e --- /dev/null +++ b/server/tests/test_transcripts_audio_download.py @@ -0,0 +1,95 @@ +import pytest +import shutil +from httpx import AsyncClient +from pathlib import Path + + +@pytest.fixture +async def fake_transcript(tmpdir): + from reflector.settings import settings + from reflector.app import app + from reflector.views.transcripts import transcripts_controller + + settings.DATA_DIR = Path(tmpdir) + + # create a transcript + ac = AsyncClient(app=app, base_url="http://test/v1") + response = await ac.post("/transcripts", json={"name": "Test audio download"}) + assert response.status_code == 200 + tid = response.json()["id"] + + transcript = await transcripts_controller.get_by_id(tid) + assert transcript is not None + + await transcripts_controller.update(transcript, {"status": "finished"}) + + # manually copy a file at the expected location + audio_filename = transcript.audio_filename + path = Path(__file__).parent / "records" / "test_mathieu_hello.wav" + audio_filename.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(path, audio_filename) + yield transcript + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url_suffix,content_type", + [ + ["", "audio/wav"], + ["/mp3", "audio/mp3"], + ], +) +async def test_transcript_audio_download(fake_transcript, url_suffix, content_type): + from reflector.app import app + + ac = AsyncClient(app=app, base_url="http://test/v1") + response = await ac.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}") + assert response.status_code == 200 + assert response.headers["content-type"] == content_type + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url_suffix,content_type", + [ + ["", "audio/wav"], + ["/mp3", "audio/mp3"], + ], +) +async def test_transcript_audio_download_range( + fake_transcript, url_suffix, content_type +): + from reflector.app import app + + ac = AsyncClient(app=app, base_url="http://test/v1") + response = await ac.get( + f"/transcripts/{fake_transcript.id}/audio{url_suffix}", + headers={"range": "bytes=0-100"}, + ) + assert response.status_code == 206 + assert response.headers["content-type"] == content_type + assert response.headers["content-range"].startswith("bytes 0-100/") + assert response.headers["content-length"] == "101" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url_suffix,content_type", + [ + ["", "audio/wav"], + ["/mp3", "audio/mp3"], + ], +) +async def test_transcript_audio_download_range_with_seek( + fake_transcript, url_suffix, content_type +): + from reflector.app import app + + ac = AsyncClient(app=app, base_url="http://test/v1") + response = await ac.get( + f"/transcripts/{fake_transcript.id}/audio{url_suffix}", + headers={"range": "bytes=100-"}, + ) + assert response.status_code == 206 + assert response.headers["content-type"] == content_type + assert response.headers["content-range"].startswith("bytes 100-")