server: include endpoint to upload a audio/video file

This commit is contained in:
2023-12-12 20:39:15 +01:00
parent bcbd990958
commit e5e1b70213
8 changed files with 259 additions and 25 deletions

18
server/poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
[[package]] [[package]]
name = "aioboto3" name = "aioboto3"
@@ -2881,6 +2881,20 @@ cryptography = ["cryptography (>=3.4.0)"]
pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"] pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"]
pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"]
[[package]]
name = "python-multipart"
version = "0.0.6"
description = "A streaming multipart parser for Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "python_multipart-0.0.6-py3-none-any.whl", hash = "sha256:ee698bab5ef148b0a760751c261902cd096e57e10558e11aca17646b74ee1c18"},
{file = "python_multipart-0.0.6.tar.gz", hash = "sha256:e9925a80bb668529f1b67c7fdb0a5dacdd7cbfc6fb0bff3ea443fe22bdd62132"},
]
[package.extras]
dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pyyaml (==5.1)"]
[[package]] [[package]]
name = "pyyaml" name = "pyyaml"
version = "6.0.1" version = "6.0.1"
@@ -4219,4 +4233,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "91d85539f5093abad70e34aa4d533272d6a2e2bbdb539c7968fe79c28b50d01a" content-hash = "b823010302af2dcd2ece591eaf10d5cbf945f74bd0fc35fc69f3060c0c253d57"

View File

@@ -36,6 +36,7 @@ profanityfilter = "^2.0.6"
celery = "^5.3.4" celery = "^5.3.4"
redis = "^5.0.1" redis = "^5.0.1"
python-jose = {extras = ["cryptography"], version = "^3.3.0"} python-jose = {extras = ["cryptography"], version = "^3.3.0"}
python-multipart = "^0.0.6"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]

View File

@@ -17,6 +17,7 @@ from reflector.views.transcripts_audio import router as transcripts_audio_router
from reflector.views.transcripts_participants import ( from reflector.views.transcripts_participants import (
router as transcripts_participants_router, router as transcripts_participants_router,
) )
from reflector.views.transcripts_upload import router as transcripts_upload_router
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
from reflector.views.transcripts_websocket import router as transcripts_websocket_router from reflector.views.transcripts_websocket import router as transcripts_websocket_router
from reflector.views.user import router as user_router from reflector.views.user import router as user_router
@@ -68,6 +69,7 @@ app.include_router(rtc_offer_router)
app.include_router(transcripts_router, prefix="/v1") app.include_router(transcripts_router, prefix="/v1")
app.include_router(transcripts_audio_router, prefix="/v1") app.include_router(transcripts_audio_router, prefix="/v1")
app.include_router(transcripts_participants_router, prefix="/v1") app.include_router(transcripts_participants_router, prefix="/v1")
app.include_router(transcripts_upload_router, prefix="/v1")
app.include_router(transcripts_websocket_router, prefix="/v1") app.include_router(transcripts_websocket_router, prefix="/v1")
app.include_router(transcripts_webrtc_router, prefix="/v1") app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(user_router, prefix="/v1") app.include_router(user_router, prefix="/v1")

View File

@@ -618,3 +618,47 @@ def pipeline_post(*, transcript_id: str):
chain_final_summaries, chain_final_summaries,
) )
chain.delay() chain.delay()
@get_transcript
async def pipeline_upload(transcript: Transcript, logger: Logger):
import av
try:
# open audio
upload_filename = next(transcript.data_path.glob("upload.*"))
container = av.open(upload_filename.as_posix())
# create pipeline
pipeline = PipelineMainLive(transcript_id=transcript.id)
pipeline.start()
# push audio to pipeline
try:
logger.info("Start pushing audio into the pipeline")
for frame in container.decode(audio=0):
pipeline.push(frame)
finally:
logger.info("Flushing the pipeline")
pipeline.flush()
logger.info("Waiting for the pipeline to end")
await pipeline.join()
except Exception as exc:
logger.error("Pipeline error", exc_info=exc)
await transcripts_controller.update(
transcript,
{
"status": "error",
},
)
raise
logger.info("Pipeline ended")
@shared_task
@asynctask
async def task_pipeline_upload(*, transcript_id: str):
return await pipeline_upload(transcript_id=transcript_id)

View File

@@ -30,7 +30,8 @@ class PipelineRunner(BaseModel):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._q_cmd = asyncio.Queue() self._task = None
self._q_cmd = asyncio.Queue(maxsize=4096)
self._ev_done = asyncio.Event() self._ev_done = asyncio.Event()
self._is_first_push = True self._is_first_push = True
self._logger = logger.bind( self._logger = logger.bind(
@@ -49,7 +50,14 @@ class PipelineRunner(BaseModel):
""" """
Start the pipeline as a coroutine task Start the pipeline as a coroutine task
""" """
asyncio.get_event_loop().create_task(self.run()) self._task = asyncio.get_event_loop().create_task(self.run())
async def join(self):
"""
Wait for the pipeline to finish
"""
if self._task:
await self._task
def start_sync(self): def start_sync(self):
""" """

View File

@@ -0,0 +1,79 @@
from typing import Annotated, Optional
import av
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from pydantic import BaseModel
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_live_pipeline import task_pipeline_upload
router = APIRouter()
class UploadStatus(BaseModel):
status: str
@router.post("/transcripts/{transcript_id}/record/upload")
async def transcript_record_upload(
transcript_id: str,
file: UploadFile,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# ensure there is no other upload in the directory (searching data_path/upload.*)
if any(transcript.data_path.glob("upload.*")):
raise HTTPException(
status_code=400, detail="There is already an upload in progress"
)
# save the file to the transcript folder
extension = file.filename.split(".")[-1]
upload_filename = transcript.data_path / f"upload.{extension}"
upload_filename.parent.mkdir(parents=True, exist_ok=True)
# ensure the file is back to the beginning
await file.seek(0)
# save the file to the transcript folder
try:
with open(upload_filename, "wb") as f:
while True:
chunk = await file.read(16384)
if not chunk:
break
f.write(chunk)
except Exception:
upload_filename.unlink()
raise
# ensure the file have audio part, using av
# XXX Trying to do this check on the initial UploadFile object is not
# possible, dunno why. UploadFile.file has no name.
# Trying to pass UploadFile.file with format=extension does not work
# it never detect audio stream...
container = av.open(upload_filename.as_posix())
try:
if not len(container.streams.audio):
raise HTTPException(status_code=400, detail="File has no audio stream")
except Exception:
# delete the uploaded file
upload_filename.unlink()
raise
finally:
container.close()
# set the status to "uploaded"
await transcripts_controller.update(transcript, {"status": "uploaded"})
# launch a background task to process the file
task_pipeline_upload.delay(transcript_id=transcript_id)
return UploadStatus(status="ok")

View File

@@ -1,8 +1,13 @@
import celery
import structlog import structlog
from celery import Celery from celery import Celery
from reflector.settings import settings from reflector.settings import settings
logger = structlog.get_logger(__name__) logger = structlog.get_logger(__name__)
if celery.current_app is not None:
logger.info(f"Celery already configured ({celery.current_app})")
app = celery.current_app
else:
app = Celery(__name__) app = Celery(__name__)
app.conf.broker_url = settings.CELERY_BROKER_URL app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND app.conf.result_backend = settings.CELERY_RESULT_BACKEND

View File

@@ -0,0 +1,81 @@
import pytest
import asyncio
from httpx import AsyncClient
@pytest.fixture(scope="session")
def celery_enable_logging():
return True
@pytest.fixture(scope="session")
def celery_config():
from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as f:
yield {
"broker_url": "memory://",
"result_backend": f"db+sqlite:///{f.name}",
}
@pytest.fixture(scope="session")
def celery_includes():
return ["reflector.pipelines.main_live_pipeline"]
@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_app")
@pytest.mark.usefixtures("celery_worker")
@pytest.mark.asyncio
async def test_transcript_upload_file(
tmpdir,
dummy_llm,
dummy_processors,
dummy_diarization,
dummy_storage,
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
# create a transcript
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["status"] == "idle"
tid = response.json()["id"]
# upload mp3
response = await ac.post(
f"/transcripts/{tid}/record/upload",
files={
"file": (
"test_mathieu_hello.mp3",
open("tests/records/test_mathieu_hello.mp3", "rb"),
"audio/mpeg",
)
},
)
assert response.status_code == 200
assert response.json()["status"] == "ok"
# wait the processing to finish
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
# check the transcript is ended
transcript = resp.json()
assert transcript["status"] == "ended"
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
assert transcript["title"] == "LLM TITLE"
# check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics")
assert response.status_code == 200
assert len(response.json()) == 1
assert "want to share" in response.json()[0]["transcript"]