From e5e1b70213aac9b45a194ea87fe319ca2c81a9c6 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Tue, 12 Dec 2023 20:39:15 +0100
Subject: [PATCH] server: add endpoint to upload an audio/video file

---
 server/poetry.lock                            | 18 ++++-
 server/pyproject.toml                         |  1 +
 server/reflector/app.py                       |  2 +
 .../reflector/pipelines/main_live_pipeline.py | 44 ++++++++++
 server/reflector/pipelines/runner.py          | 12 ++-
 server/reflector/views/transcripts_upload.py  | 79 ++++++++++++++++++
 server/reflector/worker/app.py                | 47 ++++++-----
 server/tests/test_transcripts_upload.py       | 81 +++++++++++++++++++
 8 files changed, 259 insertions(+), 25 deletions(-)
 create mode 100644 server/reflector/views/transcripts_upload.py
 create mode 100644 server/tests/test_transcripts_upload.py

diff --git a/server/poetry.lock b/server/poetry.lock
index b89cf400..5e52e681 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.

 [[package]]
 name = "aioboto3"
@@ -2881,6 +2881,20 @@ cryptography = ["cryptography (>=3.4.0)"]
 pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"]
 pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"]

+[[package]]
+name = "python-multipart"
+version = "0.0.6"
+description = "A streaming multipart parser for Python"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "python_multipart-0.0.6-py3-none-any.whl", hash = "sha256:ee698bab5ef148b0a760751c261902cd096e57e10558e11aca17646b74ee1c18"},
+    {file = "python_multipart-0.0.6.tar.gz", hash = "sha256:e9925a80bb668529f1b67c7fdb0a5dacdd7cbfc6fb0bff3ea443fe22bdd62132"},
+]
+
+[package.extras]
+dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pyyaml (==5.1)"]
+
 [[package]]
 name = "pyyaml"
 version = "6.0.1"
@@ -4219,4 +4233,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "91d85539f5093abad70e34aa4d533272d6a2e2bbdb539c7968fe79c28b50d01a"
+content-hash = "b823010302af2dcd2ece591eaf10d5cbf945f74bd0fc35fc69f3060c0c253d57"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 2a901918..a4b0fa43 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -36,6 +36,7 @@ profanityfilter = "^2.0.6"
 celery = "^5.3.4"
 redis = "^5.0.1"
 python-jose = {extras = ["cryptography"], version = "^3.3.0"}
+python-multipart = "^0.0.6"


 [tool.poetry.group.dev.dependencies]
diff --git a/server/reflector/app.py b/server/reflector/app.py
index 8f45efd5..9235a578 100644
--- a/server/reflector/app.py
+++ b/server/reflector/app.py
@@ -17,6 +17,7 @@ from reflector.views.transcripts_audio import router as transcripts_audio_router
 from reflector.views.transcripts_participants import (
     router as transcripts_participants_router,
 )
+from reflector.views.transcripts_upload import router as transcripts_upload_router
 from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
 from reflector.views.transcripts_websocket import router as transcripts_websocket_router
 from reflector.views.user import router as user_router
@@ -68,6 +69,7 @@ app.include_router(rtc_offer_router)
 app.include_router(transcripts_router, prefix="/v1")
 app.include_router(transcripts_audio_router, prefix="/v1")
 app.include_router(transcripts_participants_router, prefix="/v1")
+app.include_router(transcripts_upload_router, prefix="/v1")
 app.include_router(transcripts_websocket_router, prefix="/v1")
 app.include_router(transcripts_webrtc_router, prefix="/v1")
 app.include_router(user_router, prefix="/v1")
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index b182f421..8f45cafe 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -618,3 +618,47 @@ def pipeline_post(*, transcript_id: str):
         chain_final_summaries,
     )
     chain.delay()
+
+
+@get_transcript
+async def pipeline_upload(transcript: Transcript, logger: Logger):
+    import av
+
+    try:
+        # open audio
+        upload_filename = next(transcript.data_path.glob("upload.*"))
+        container = av.open(upload_filename.as_posix())
+
+        # create pipeline
+        pipeline = PipelineMainLive(transcript_id=transcript.id)
+        pipeline.start()
+
+        # push audio to pipeline
+        try:
+            logger.info("Start pushing audio into the pipeline")
+            for frame in container.decode(audio=0):
+                pipeline.push(frame)
+        finally:
+            logger.info("Flushing the pipeline")
+            pipeline.flush()
+
+        logger.info("Waiting for the pipeline to end")
+        await pipeline.join()
+
+    except Exception as exc:
+        logger.error("Pipeline error", exc_info=exc)
+        await transcripts_controller.update(
+            transcript,
+            {
+                "status": "error",
+            },
+        )
+        raise
+
+    logger.info("Pipeline ended")
+
+
+@shared_task
+@asynctask
+async def task_pipeline_upload(*, transcript_id: str):
+    return await pipeline_upload(transcript_id=transcript_id)
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index 708a4265..0edf156c 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -30,7 +30,8 @@ class PipelineRunner(BaseModel):

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self._q_cmd = asyncio.Queue()
+        self._task = None
+        self._q_cmd = asyncio.Queue(maxsize=4096)
         self._ev_done = asyncio.Event()
         self._is_first_push = True
         self._logger = logger.bind(
@@ -49,7 +50,14 @@ class PipelineRunner(BaseModel):
         """
         Start the pipeline as a coroutine task
         """
-        asyncio.get_event_loop().create_task(self.run())
+        self._task = asyncio.get_event_loop().create_task(self.run())
+
+    async def join(self):
+        """
+        Wait for the pipeline to finish
+        """
+        if self._task:
+            await self._task

     def start_sync(self):
         """
diff --git a/server/reflector/views/transcripts_upload.py b/server/reflector/views/transcripts_upload.py
new file mode 100644
index 00000000..96b82d78
--- /dev/null
+++ b/server/reflector/views/transcripts_upload.py
@@ -0,0 +1,79 @@
+from typing import Annotated, Optional
+
+import av
+import reflector.auth as auth
+from fastapi import APIRouter, Depends, HTTPException, UploadFile
+from pydantic import BaseModel
+from reflector.db.transcripts import transcripts_controller
+from reflector.pipelines.main_live_pipeline import task_pipeline_upload
+
+router = APIRouter()
+
+
+class UploadStatus(BaseModel):
+    status: str
+
+
+@router.post("/transcripts/{transcript_id}/record/upload")
+async def transcript_record_upload(
+    transcript_id: str,
+    file: UploadFile,
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+    user_id = user["sub"] if user else None
+    transcript = await transcripts_controller.get_by_id_for_http(
+        transcript_id, user_id=user_id
+    )
+
+    if transcript.locked:
+        raise HTTPException(status_code=400, detail="Transcript is locked")
detail="Transcript is locked") + + # ensure there is no other upload in the directory (searching data_path/upload.*) + if any(transcript.data_path.glob("upload.*")): + raise HTTPException( + status_code=400, detail="There is already an upload in progress" + ) + + # save the file to the transcript folder + extension = file.filename.split(".")[-1] + upload_filename = transcript.data_path / f"upload.{extension}" + upload_filename.parent.mkdir(parents=True, exist_ok=True) + + # ensure the file is back to the beginning + await file.seek(0) + + # save the file to the transcript folder + try: + with open(upload_filename, "wb") as f: + while True: + chunk = await file.read(16384) + if not chunk: + break + f.write(chunk) + except Exception: + upload_filename.unlink() + raise + + # ensure the file have audio part, using av + # XXX Trying to do this check on the initial UploadFile object is not + # possible, dunno why. UploadFile.file has no name. + # Trying to pass UploadFile.file with format=extension does not work + # it never detect audio stream... + container = av.open(upload_filename.as_posix()) + try: + if not len(container.streams.audio): + raise HTTPException(status_code=400, detail="File has no audio stream") + except Exception: + # delete the uploaded file + upload_filename.unlink() + raise + finally: + container.close() + + # set the status to "uploaded" + await transcripts_controller.update(transcript, {"status": "uploaded"}) + + # launch a background task to process the file + task_pipeline_upload.delay(transcript_id=transcript_id) + + return UploadStatus(status="ok") diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py index 689623ce..7df0423e 100644 --- a/server/reflector/worker/app.py +++ b/server/reflector/worker/app.py @@ -1,27 +1,32 @@ +import celery import structlog from celery import Celery from reflector.settings import settings logger = structlog.get_logger(__name__) -app = Celery(__name__) -app.conf.broker_url = settings.CELERY_BROKER_URL -app.conf.result_backend = settings.CELERY_RESULT_BACKEND -app.conf.broker_connection_retry_on_startup = True -app.autodiscover_tasks( - [ - "reflector.pipelines.main_live_pipeline", - "reflector.worker.healthcheck", - ] -) - -# crontab -app.conf.beat_schedule = {} - -if settings.HEALTHCHECK_URL: - app.conf.beat_schedule["healthcheck_ping"] = { - "task": "reflector.worker.healthcheck.healthcheck_ping", - "schedule": 60.0 * 10, - } - logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL) +if celery.current_app is not None: + logger.info(f"Celery already configured ({celery.current_app})") + app = celery.current_app else: - logger.warning("Healthcheck disabled, no url configured") + app = Celery(__name__) + app.conf.broker_url = settings.CELERY_BROKER_URL + app.conf.result_backend = settings.CELERY_RESULT_BACKEND + app.conf.broker_connection_retry_on_startup = True + app.autodiscover_tasks( + [ + "reflector.pipelines.main_live_pipeline", + "reflector.worker.healthcheck", + ] + ) + + # crontab + app.conf.beat_schedule = {} + + if settings.HEALTHCHECK_URL: + app.conf.beat_schedule["healthcheck_ping"] = { + "task": "reflector.worker.healthcheck.healthcheck_ping", + "schedule": 60.0 * 10, + } + logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL) + else: + logger.warning("Healthcheck disabled, no url configured") diff --git a/server/tests/test_transcripts_upload.py b/server/tests/test_transcripts_upload.py new file mode 100644 index 00000000..5414b498 --- /dev/null +++ 
@@ -0,0 +1,81 @@
+import pytest
+import asyncio
+from httpx import AsyncClient
+
+
+@pytest.fixture(scope="session")
+def celery_enable_logging():
+    return True
+
+
+@pytest.fixture(scope="session")
+def celery_config():
+    from tempfile import NamedTemporaryFile
+
+    with NamedTemporaryFile() as f:
+        yield {
+            "broker_url": "memory://",
+            "result_backend": f"db+sqlite:///{f.name}",
+        }
+
+
+@pytest.fixture(scope="session")
+def celery_includes():
+    return ["reflector.pipelines.main_live_pipeline"]
+
+
+@pytest.mark.usefixtures("setup_database")
+@pytest.mark.usefixtures("celery_app")
+@pytest.mark.usefixtures("celery_worker")
+@pytest.mark.asyncio
+async def test_transcript_upload_file(
+    tmpdir,
+    dummy_llm,
+    dummy_processors,
+    dummy_diarization,
+    dummy_storage,
+):
+    from reflector.app import app
+
+    ac = AsyncClient(app=app, base_url="http://test/v1")
+
+    # create a transcript
+    response = await ac.post("/transcripts", json={"name": "test"})
+    assert response.status_code == 200
+    assert response.json()["status"] == "idle"
+    tid = response.json()["id"]
+
+    # upload the mp3
+    response = await ac.post(
+        f"/transcripts/{tid}/record/upload",
+        files={
+            "file": (
+                "test_mathieu_hello.mp3",
+                open("tests/records/test_mathieu_hello.mp3", "rb"),
+                "audio/mpeg",
+            )
+        },
+    )
+    assert response.status_code == 200
+    assert response.json()["status"] == "ok"
+
+    # wait for the processing to finish
+    while True:
+        # fetch the transcript and check whether it has ended
+        resp = await ac.get(f"/transcripts/{tid}")
+        assert resp.status_code == 200
+        if resp.json()["status"] in ("ended", "error"):
+            break
+        await asyncio.sleep(1)
+
+    # check that the transcript ended successfully
+    transcript = resp.json()
+    assert transcript["status"] == "ended"
+    assert transcript["short_summary"] == "LLM SHORT SUMMARY"
+    assert transcript["title"] == "LLM TITLE"
+
+    # check the topics and transcript text
+    response = await ac.get(f"/transcripts/{tid}/topics")
+    assert response.status_code == 200
+    assert len(response.json()) == 1
+    assert "want to share" in response.json()[0]["transcript"]
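
Usage note: for manual testing against a running server, the new endpoint can be exercised with a small httpx script. This is a minimal sketch, not part of the patch; it assumes the API is reachable at http://localhost:8000 (adjust host/port to your deployment), that authentication is optional as in the view above, and it reuses the sample recording from the test suite.

import time

import httpx

BASE_URL = "http://localhost:8000/v1"  # assumed host/port, adjust as needed

with httpx.Client(base_url=BASE_URL, timeout=30) as client:
    # create a transcript to attach the upload to
    transcript = client.post("/transcripts", json={"name": "upload test"}).json()
    tid = transcript["id"]

    # upload an audio file as multipart/form-data (the field name must be "file")
    with open("tests/records/test_mathieu_hello.mp3", "rb") as f:
        resp = client.post(
            f"/transcripts/{tid}/record/upload",
            files={"file": ("test_mathieu_hello.mp3", f, "audio/mpeg")},
        )
    resp.raise_for_status()
    print(resp.json())  # expected: {"status": "ok"}

    # poll the transcript until the background pipeline finishes
    while True:
        status = client.get(f"/transcripts/{tid}").json()["status"]
        if status in ("ended", "error"):
            break
        time.sleep(1)
    print("final status:", status)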