mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
72
server/reflector/views/_range_requests_response.py
Normal file
72
server/reflector/views/_range_requests_response.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import os
|
||||||
|
from typing import BinaryIO
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request, status
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
|
||||||
|
def send_bytes_range_requests(
    file_obj: BinaryIO, start: int, end: int, chunk_size: int = 10_000
):
    """Yield the bytes ``start``..``end`` of a file in chunks.

    Produces the body of an HTTP Range response (RFC 7233). Both ``start``
    and ``end`` are inclusive byte offsets, as they are in the ``Range``
    header.

    ``file_obj`` is used as a context manager, so it is closed once the range
    has been sent, when the generator is closed, or when an error occurs.

    Args:
        file_obj: Seekable binary file to read from; closed by this generator.
        start: Offset of the first byte to send.
        end: Offset of the last byte to send (inclusive).
        chunk_size: Maximum number of bytes per yielded chunk; must be >= 1.

    Yields:
        Consecutive non-empty ``bytes`` chunks of at most ``chunk_size`` bytes.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1 (raised on the first
            iteration, after which the file is closed).
    """
    with file_obj as f:
        if chunk_size < 1:
            # A zero-sized read never advances the position and a negative
            # size reads to EOF, ignoring `end`.
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        f.seek(start)
        while (pos := f.tell()) <= end:
            chunk = f.read(min(chunk_size, end + 1 - pos))
            if not chunk:
                # EOF reached before `end` (e.g. the file shrank after its
                # size was measured): stop instead of yielding b"" forever.
                break
            yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
def _get_range_header(range_header: str, file_size: int) -> tuple[int, int]:
    """Parse a single-range ``Range`` header into inclusive byte offsets.

    Supported forms (RFC 7233 section 2.1), for a file of ``file_size`` bytes:

    * ``bytes=A-B``: bytes A through B; B is clamped to the last byte.
    * ``bytes=A-``: byte A through the end of the file.
    * ``bytes=-N``: the final N bytes (the whole file if N exceeds its size).

    The ``bytes`` unit is matched case-insensitively.

    Args:
        range_header: Raw value of the request's ``Range`` header.
        file_size: Size in bytes of the file being served.

    Returns:
        ``(start, end)`` with ``0 <= start <= end < file_size``.

    Raises:
        HTTPException: 416 when the header is malformed, uses another unit,
            lists several ranges, or cannot be satisfied for ``file_size``.
    """

    def _invalid_range():
        return HTTPException(
            status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
            detail=f"Invalid request range (Range:{range_header!r})",
        )

    def _parse_position(text: str) -> int:
        # int() alone would also accept forms such as "+5" or "1_0".
        if not (text.isascii() and text.isdigit()):
            raise ValueError(text)
        return int(text)

    unit, equals, spec = range_header.partition("=")
    first, dash, last = spec.partition("-")
    if not equals or not dash or unit.strip().lower() != "bytes":
        raise _invalid_range()
    first, last = first.strip(), last.strip()

    try:
        if first:
            start = _parse_position(first)
            end = _parse_position(last) if last else file_size - 1
        else:
            # Suffix form "-N": the last N bytes of the file.
            suffix_length = _parse_position(last)
            if suffix_length == 0:
                raise ValueError(last)
            start = max(file_size - suffix_length, 0)
            end = file_size - 1
    except ValueError:
        raise _invalid_range() from None

    # A last-byte-pos beyond the file means "up to the final byte".
    end = min(end, file_size - 1)
    if start > end:
        raise _invalid_range()
    return start, end
|
||||||
|
|
||||||
|
|
||||||
|
def range_requests_response(request: Request, file_path: str, content_type: str):
    """Build a StreamingResponse for a file, honouring the ``Range`` header.

    Without a ``Range`` header the whole file is streamed with status 200.
    With one, the range is resolved by ``_get_range_header`` and the response
    is a 206 carrying a ``content-range`` header and the length of the range.

    Args:
        request: Incoming request; only its ``range`` header is read.
        file_path: Path of the file to stream.
        content_type: Value sent in the ``content-type`` header.

    Returns:
        A ``StreamingResponse`` whose body is produced by
        ``send_bytes_range_requests``.
    """
    total_bytes = os.stat(file_path).st_size
    requested_range = request.headers.get("range")

    if requested_range is None:
        first_byte, last_byte = 0, total_bytes - 1
        http_status = status.HTTP_200_OK
    else:
        first_byte, last_byte = _get_range_header(requested_range, total_bytes)
        http_status = status.HTTP_206_PARTIAL_CONTENT

    response_headers = {
        "content-type": content_type,
        "accept-ranges": "bytes",
        "content-encoding": "identity",
        # Equals the full file size when no range was requested.
        "content-length": str(last_byte - first_byte + 1),
        "access-control-expose-headers": (
            "content-type, accept-ranges, content-length, "
            "content-range, content-encoding"
        ),
    }
    if requested_range is not None:
        response_headers["content-range"] = (
            f"bytes {first_byte}-{last_byte}/{total_bytes}"
        )

    # The file is opened here; the generator closes it once it is consumed.
    body = send_bytes_range_requests(
        open(file_path, mode="rb"), first_byte, last_byte
    )
    return StreamingResponse(
        body,
        headers=response_headers,
        status_code=http_status,
    )
|
||||||
@@ -14,7 +14,6 @@ from fastapi import (
|
|||||||
WebSocket,
|
WebSocket,
|
||||||
WebSocketDisconnect,
|
WebSocketDisconnect,
|
||||||
)
|
)
|
||||||
from fastapi.responses import FileResponse
|
|
||||||
from fastapi_pagination import Page, paginate
|
from fastapi_pagination import Page, paginate
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from reflector.db import database, transcripts
|
from reflector.db import database, transcripts
|
||||||
@@ -22,6 +21,7 @@ from reflector.logger import logger
|
|||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
|
||||||
|
from ._range_requests_response import range_requests_response
|
||||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -281,6 +281,7 @@ async def transcript_delete(
|
|||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}/audio")
|
@router.get("/transcripts/{transcript_id}/audio")
|
||||||
async def transcript_get_audio(
|
async def transcript_get_audio(
|
||||||
|
request: Request,
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
):
|
):
|
||||||
@@ -292,11 +293,16 @@ async def transcript_get_audio(
|
|||||||
if not transcript.audio_filename.exists():
|
if not transcript.audio_filename.exists():
|
||||||
raise HTTPException(status_code=404, detail="Audio not found")
|
raise HTTPException(status_code=404, detail="Audio not found")
|
||||||
|
|
||||||
return FileResponse(transcript.audio_filename, media_type="audio/wav")
|
return range_requests_response(
|
||||||
|
request,
|
||||||
|
transcript.audio_filename,
|
||||||
|
content_type="audio/wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
||||||
async def transcript_get_audio_mp3(
|
async def transcript_get_audio_mp3(
|
||||||
|
request: Request,
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
):
|
):
|
||||||
@@ -310,7 +316,11 @@ async def transcript_get_audio_mp3(
|
|||||||
|
|
||||||
await run_in_threadpool(transcript.convert_audio_to_mp3)
|
await run_in_threadpool(transcript.convert_audio_to_mp3)
|
||||||
|
|
||||||
return FileResponse(transcript.audio_mp3_filename, media_type="audio/mp3")
|
return range_requests_response(
|
||||||
|
request,
|
||||||
|
transcript.audio_mp3_filename,
|
||||||
|
content_type="audio/mp3",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
||||||
|
|||||||
95
server/tests/test_transcripts_audio_download.py
Normal file
95
server/tests/test_transcripts_audio_download.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import pytest
|
||||||
|
import shutil
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def fake_transcript(tmpdir):
    """Yield a finished transcript whose WAV audio file exists on disk.

    Points ``settings.DATA_DIR`` at ``tmpdir``, creates a transcript through
    the API, marks it as finished and copies a recorded sample to the
    transcript's ``audio_filename``.
    """
    from reflector.settings import settings
    from reflector.app import app
    from reflector.views.transcripts import transcripts_controller

    # NOTE(review): DATA_DIR is not restored after the test - confirm whether
    # another fixture resets it.
    settings.DATA_DIR = Path(tmpdir)

    # create a transcript; the client is closed as soon as the call returns
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.post(
            "/transcripts", json={"name": "Test audio download"}
        )
    assert response.status_code == 200
    tid = response.json()["id"]

    transcript = await transcripts_controller.get_by_id(tid)
    assert transcript is not None

    await transcripts_controller.update(transcript, {"status": "finished"})

    # manually copy a file at the expected location
    audio_filename = transcript.audio_filename
    path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
    audio_filename.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(path, audio_filename)
    yield transcript
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "url_suffix,content_type",
    [
        ["", "audio/wav"],
        ["/mp3", "audio/mp3"],
    ],
)
async def test_transcript_audio_download(fake_transcript, url_suffix, content_type):
    """A GET without a Range header answers 200 with the expected content type."""
    from reflector.app import app

    # Use the client as a context manager so it is closed after the request.
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.get(
            f"/transcripts/{fake_transcript.id}/audio{url_suffix}"
        )
    assert response.status_code == 200
    assert response.headers["content-type"] == content_type
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "url_suffix,content_type",
    [
        ["", "audio/wav"],
        ["/mp3", "audio/mp3"],
    ],
)
async def test_transcript_audio_download_range(
    fake_transcript, url_suffix, content_type
):
    """A bounded range (bytes=0-100) answers 206 with exactly 101 bytes."""
    from reflector.app import app

    # Use the client as a context manager so it is closed after the request.
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.get(
            f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
            headers={"range": "bytes=0-100"},
        )
    assert response.status_code == 206
    assert response.headers["content-type"] == content_type
    assert response.headers["content-range"].startswith("bytes 0-100/")
    assert response.headers["content-length"] == "101"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "url_suffix,content_type",
    [
        ["", "audio/wav"],
        ["/mp3", "audio/mp3"],
    ],
)
async def test_transcript_audio_download_range_with_seek(
    fake_transcript, url_suffix, content_type
):
    """An open-ended range (bytes=100-) answers 206 starting at byte 100."""
    from reflector.app import app

    # Use the client as a context manager so it is closed after the request.
    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        response = await ac.get(
            f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
            headers={"range": "bytes=100-"},
        )
    assert response.status_code == 206
    assert response.headers["content-type"] == content_type
    assert response.headers["content-range"].startswith("bytes 100-")
|
||||||
Reference in New Issue
Block a user