diff --git a/server/reflector/views/_range_requests_response.py b/server/reflector/views/_range_requests_response.py index f0c628e9..2fac632d 100644 --- a/server/reflector/views/_range_requests_response.py +++ b/server/reflector/views/_range_requests_response.py @@ -1,7 +1,7 @@ import os from typing import BinaryIO -from fastapi import HTTPException, Request, status +from fastapi import HTTPException, Request, Response, status from fastapi.responses import StreamingResponse @@ -57,6 +57,9 @@ def range_requests_response( ), } + if request.method == "HEAD": + return Response(headers=headers) + if content_disposition: headers["Content-Disposition"] = content_disposition diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index e3668ecb..a202f3a1 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -210,6 +210,7 @@ async def transcript_delete( @router.get("/transcripts/{transcript_id}/audio/mp3") +@router.head("/transcripts/{transcript_id}/audio/mp3") async def transcript_get_audio_mp3( request: Request, transcript_id: str, diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py index 79cb25bf..69ae5f65 100644 --- a/server/tests/test_transcripts_audio_download.py +++ b/server/tests/test_transcripts_audio_download.py @@ -46,6 +46,34 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty assert response.status_code == 200 assert response.headers["content-type"] == content_type + # test get 404 + ac = AsyncClient(app=app, base_url="http://test/v1") + response = await ac.get(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}") + assert response.status_code == 404 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url_suffix,content_type", + [ + ["/mp3", "audio/mpeg"], + ], +) +async def test_transcript_audio_download_head( + fake_transcript, url_suffix, content_type +): + from reflector.app import app + + ac = AsyncClient(app=app, base_url="http://test/v1") + response = await ac.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}") + assert response.status_code == 200 + assert response.headers["content-type"] == content_type + + # test head 404 + ac = AsyncClient(app=app, base_url="http://test/v1") + response = await ac.head(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}") + assert response.status_code == 404 + @pytest.mark.asyncio @pytest.mark.parametrize(