server: support Range HTTP header for audio download

Closes #178
This commit is contained in:
2023-08-22 11:37:49 +02:00
committed by Mathieu Virbel
parent a91c453e41
commit 466d3670a1
3 changed files with 180 additions and 3 deletions

View File

@@ -0,0 +1,72 @@
import os
from typing import BinaryIO
from fastapi import HTTPException, Request, status
from fastapi.responses import StreamingResponse
def send_bytes_range_requests(
file_obj: BinaryIO, start: int, end: int, chunk_size: int = 10_000
):
"""Send a file in chunks using Range Requests specification RFC7233
`start` and `end` parameters are inclusive due to specification
"""
with file_obj as f:
f.seek(start)
while (pos := f.tell()) <= end:
read_size = min(chunk_size, end + 1 - pos)
yield f.read(read_size)
def _get_range_header(range_header: str, file_size: int) -> tuple[int, int]:
def _invalid_range():
return HTTPException(
status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
detail=f"Invalid request range (Range:{range_header!r})",
)
try:
h = range_header.replace("bytes=", "").split("-")
start = int(h[0]) if h[0] != "" else 0
end = int(h[1]) if h[1] != "" else file_size - 1
except ValueError:
raise _invalid_range()
if start > end or start < 0 or end > file_size - 1:
raise _invalid_range()
return start, end
def range_requests_response(request: Request, file_path: str, content_type: str):
"""Returns StreamingResponse using Range Requests of a given file"""
file_size = os.stat(file_path).st_size
range_header = request.headers.get("range")
headers = {
"content-type": content_type,
"accept-ranges": "bytes",
"content-encoding": "identity",
"content-length": str(file_size),
"access-control-expose-headers": (
"content-type, accept-ranges, content-length, "
"content-range, content-encoding"
),
}
start = 0
end = file_size - 1
status_code = status.HTTP_200_OK
if range_header is not None:
start, end = _get_range_header(range_header, file_size)
size = end - start + 1
headers["content-length"] = str(size)
headers["content-range"] = f"bytes {start}-{end}/{file_size}"
status_code = status.HTTP_206_PARTIAL_CONTENT
return StreamingResponse(
send_bytes_range_requests(open(file_path, mode="rb"), start, end),
headers=headers,
status_code=status_code,
)

View File

@@ -14,7 +14,6 @@ from fastapi import (
WebSocket,
WebSocketDisconnect,
)
from fastapi.responses import FileResponse
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from reflector.db import database, transcripts
@@ -22,6 +21,7 @@ from reflector.logger import logger
from reflector.settings import settings
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
router = APIRouter()
@@ -281,6 +281,7 @@ async def transcript_delete(
@router.get("/transcripts/{transcript_id}/audio")
async def transcript_get_audio(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
@@ -292,11 +293,16 @@ async def transcript_get_audio(
if not transcript.audio_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return FileResponse(transcript.audio_filename, media_type="audio/wav")
return range_requests_response(
request,
transcript.audio_filename,
content_type="audio/wav",
)
@router.get("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
@@ -310,7 +316,11 @@ async def transcript_get_audio_mp3(
await run_in_threadpool(transcript.convert_audio_to_mp3)
return FileResponse(transcript.audio_mp3_filename, media_type="audio/mp3")
return range_requests_response(
request,
transcript.audio_mp3_filename,
content_type="audio/mp3",
)
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])

View File

@@ -0,0 +1,95 @@
import pytest
import shutil
from httpx import AsyncClient
from pathlib import Path
@pytest.fixture
async def fake_transcript(tmpdir):
from reflector.settings import settings
from reflector.app import app
from reflector.views.transcripts import transcripts_controller
settings.DATA_DIR = Path(tmpdir)
# create a transcript
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200
tid = response.json()["id"]
transcript = await transcripts_controller.get_by_id(tid)
assert transcript is not None
await transcripts_controller.update(transcript, {"status": "finished"})
# manually copy a file at the expected location
audio_filename = transcript.audio_filename
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
audio_filename.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(path, audio_filename)
yield transcript
@pytest.mark.asyncio
@pytest.mark.parametrize(
"url_suffix,content_type",
[
["", "audio/wav"],
["/mp3", "audio/mp3"],
],
)
async def test_transcript_audio_download(fake_transcript, url_suffix, content_type):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
assert response.status_code == 200
assert response.headers["content-type"] == content_type
@pytest.mark.asyncio
@pytest.mark.parametrize(
"url_suffix,content_type",
[
["", "audio/wav"],
["/mp3", "audio/mp3"],
],
)
async def test_transcript_audio_download_range(
fake_transcript, url_suffix, content_type
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
headers={"range": "bytes=0-100"},
)
assert response.status_code == 206
assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 0-100/")
assert response.headers["content-length"] == "101"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"url_suffix,content_type",
[
["", "audio/wav"],
["/mp3", "audio/mp3"],
],
)
async def test_transcript_audio_download_range_with_seek(
fake_transcript, url_suffix, content_type
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
headers={"range": "bytes=100-"},
)
assert response.status_code == 206
assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 100-")