Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-20 20:29:06 +00:00
server: include endpoint to upload an audio/video file
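
As an illustration of the endpoint this commit adds, below is a minimal client-side sketch using httpx; the base URL, transcript id, and file name are assumptions and not part of the commit.

# Client-side sketch (assumptions: local dev server, an existing transcript id).
import httpx

BASE_URL = "http://localhost:8000/v1"         # assumption: local dev server
TRANSCRIPT_ID = "replace-with-transcript-id"  # hypothetical transcript id

with open("meeting.mp3", "rb") as f:          # hypothetical local file
    resp = httpx.post(
        f"{BASE_URL}/transcripts/{TRANSCRIPT_ID}/record/upload",
        files={"file": ("meeting.mp3", f, "audio/mpeg")},
    )
resp.raise_for_status()
print(resp.json())  # the endpoint returns {"status": "ok"} on success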
server/poetry.lock (generated, 18 lines changed)
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
 
 [[package]]
 name = "aioboto3"
@@ -2881,6 +2881,20 @@ cryptography = ["cryptography (>=3.4.0)"]
 pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"]
 pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"]
 
+[[package]]
+name = "python-multipart"
+version = "0.0.6"
+description = "A streaming multipart parser for Python"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "python_multipart-0.0.6-py3-none-any.whl", hash = "sha256:ee698bab5ef148b0a760751c261902cd096e57e10558e11aca17646b74ee1c18"},
+    {file = "python_multipart-0.0.6.tar.gz", hash = "sha256:e9925a80bb668529f1b67c7fdb0a5dacdd7cbfc6fb0bff3ea443fe22bdd62132"},
+]
+
+[package.extras]
+dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pyyaml (==5.1)"]
+
 [[package]]
 name = "pyyaml"
 version = "6.0.1"
@@ -4219,4 +4233,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "91d85539f5093abad70e34aa4d533272d6a2e2bbdb539c7968fe79c28b50d01a"
+content-hash = "b823010302af2dcd2ece591eaf10d5cbf945f74bd0fc35fc69f3060c0c253d57"
server/pyproject.toml (path inferred)
@@ -36,6 +36,7 @@ profanityfilter = "^2.0.6"
 celery = "^5.3.4"
 redis = "^5.0.1"
 python-jose = {extras = ["cryptography"], version = "^3.3.0"}
+python-multipart = "^0.0.6"
 
 
 [tool.poetry.group.dev.dependencies]
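
The new python-multipart dependency is what FastAPI relies on to parse multipart/form-data request bodies, which is what UploadFile uses; a minimal standalone sketch of that mechanism, with a hypothetical route that is not part of this commit:

# Standalone sketch: FastAPI needs python-multipart installed before a route
# can declare an UploadFile parameter (it parses the multipart/form-data body).
from fastapi import FastAPI, UploadFile

app = FastAPI()


@app.post("/echo-filename")  # hypothetical route, not part of this commit
async def echo_filename(file: UploadFile):
    return {"filename": file.filename, "content_type": file.content_type}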
server/reflector/app.py (path inferred)
@@ -17,6 +17,7 @@ from reflector.views.transcripts_audio import router as transcripts_audio_router
 from reflector.views.transcripts_participants import (
     router as transcripts_participants_router,
 )
+from reflector.views.transcripts_upload import router as transcripts_upload_router
 from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
 from reflector.views.transcripts_websocket import router as transcripts_websocket_router
 from reflector.views.user import router as user_router
@@ -68,6 +69,7 @@ app.include_router(rtc_offer_router)
 app.include_router(transcripts_router, prefix="/v1")
 app.include_router(transcripts_audio_router, prefix="/v1")
 app.include_router(transcripts_participants_router, prefix="/v1")
+app.include_router(transcripts_upload_router, prefix="/v1")
 app.include_router(transcripts_websocket_router, prefix="/v1")
 app.include_router(transcripts_webrtc_router, prefix="/v1")
 app.include_router(user_router, prefix="/v1")
server/reflector/pipelines/main_live_pipeline.py (path inferred)
@@ -618,3 +618,47 @@ def pipeline_post(*, transcript_id: str):
         chain_final_summaries,
     )
     chain.delay()
+
+
+@get_transcript
+async def pipeline_upload(transcript: Transcript, logger: Logger):
+    import av
+
+    try:
+        # open audio
+        upload_filename = next(transcript.data_path.glob("upload.*"))
+        container = av.open(upload_filename.as_posix())
+
+        # create pipeline
+        pipeline = PipelineMainLive(transcript_id=transcript.id)
+        pipeline.start()
+
+        # push audio to pipeline
+        try:
+            logger.info("Start pushing audio into the pipeline")
+            for frame in container.decode(audio=0):
+                pipeline.push(frame)
+        finally:
+            logger.info("Flushing the pipeline")
+            pipeline.flush()
+
+        logger.info("Waiting for the pipeline to end")
+        await pipeline.join()
+
+    except Exception as exc:
+        logger.error("Pipeline error", exc_info=exc)
+        await transcripts_controller.update(
+            transcript,
+            {
+                "status": "error",
+            },
+        )
+        raise
+
+    logger.info("Pipeline ended")
+
+
+@shared_task
+@asynctask
+async def task_pipeline_upload(*, transcript_id: str):
+    return await pipeline_upload(transcript_id=transcript_id)
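
For reference, the PyAV calls used by pipeline_upload (av.open and container.decode(audio=0)) can be exercised on their own; a small sketch that only counts decoded audio frames, with a hypothetical input path:

# Standalone sketch of the PyAV decode loop used by pipeline_upload:
# open a media file, iterate its first audio stream, count the frames.
import av

container = av.open("example.mp3")  # hypothetical input file
try:
    if not container.streams.audio:
        raise ValueError("file has no audio stream")
    frames = 0
    for frame in container.decode(audio=0):
        frames += 1  # pipeline_upload pushes each frame into the pipeline here
    print(f"decoded {frames} audio frames")
finally:
    container.close()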
(PipelineRunner module; file path not captured in this mirror view)
@@ -30,7 +30,8 @@ class PipelineRunner(BaseModel):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self._q_cmd = asyncio.Queue()
+        self._task = None
+        self._q_cmd = asyncio.Queue(maxsize=4096)
         self._ev_done = asyncio.Event()
         self._is_first_push = True
         self._logger = logger.bind(
@@ -49,7 +50,14 @@ class PipelineRunner(BaseModel):
         """
         Start the pipeline as a coroutine task
         """
-        asyncio.get_event_loop().create_task(self.run())
+        self._task = asyncio.get_event_loop().create_task(self.run())
+
+    async def join(self):
+        """
+        Wait for the pipeline to finish
+        """
+        if self._task:
+            await self._task
 
     def start_sync(self):
         """
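
The runner change keeps a handle to the task created in start() so that join() can await it; a reduced sketch of that asyncio pattern, with illustrative names that are not part of the project:

# Illustrative sketch of the start()/join() pattern, independent of PipelineRunner.
import asyncio


class Runner:
    def __init__(self):
        self._task = None
        self._queue = asyncio.Queue(maxsize=4096)  # bounded, as in the diff

    async def run(self):
        await asyncio.sleep(0.1)  # stand-in for the real pipeline work

    def start(self):
        # keep the task handle so join() can wait for completion
        self._task = asyncio.get_event_loop().create_task(self.run())

    async def join(self):
        if self._task:
            await self._task


async def main():
    runner = Runner()
    runner.start()
    await runner.join()


asyncio.run(main())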
server/reflector/views/transcripts_upload.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+from typing import Annotated, Optional
+
+import av
+import reflector.auth as auth
+from fastapi import APIRouter, Depends, HTTPException, UploadFile
+from pydantic import BaseModel
+from reflector.db.transcripts import transcripts_controller
+from reflector.pipelines.main_live_pipeline import task_pipeline_upload
+
+router = APIRouter()
+
+
+class UploadStatus(BaseModel):
+    status: str
+
+
+@router.post("/transcripts/{transcript_id}/record/upload")
+async def transcript_record_upload(
+    transcript_id: str,
+    file: UploadFile,
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+    user_id = user["sub"] if user else None
+    transcript = await transcripts_controller.get_by_id_for_http(
+        transcript_id, user_id=user_id
+    )
+
+    if transcript.locked:
+        raise HTTPException(status_code=400, detail="Transcript is locked")
+
+    # ensure there is no other upload in the directory (searching data_path/upload.*)
+    if any(transcript.data_path.glob("upload.*")):
+        raise HTTPException(
+            status_code=400, detail="There is already an upload in progress"
+        )
+
+    # build the destination path inside the transcript folder
+    extension = file.filename.split(".")[-1]
+    upload_filename = transcript.data_path / f"upload.{extension}"
+    upload_filename.parent.mkdir(parents=True, exist_ok=True)
+
+    # ensure the file is back to the beginning
+    await file.seek(0)
+
+    # save the file to the transcript folder
+    try:
+        with open(upload_filename, "wb") as f:
+            while True:
+                chunk = await file.read(16384)
+                if not chunk:
+                    break
+                f.write(chunk)
+    except Exception:
+        upload_filename.unlink()
+        raise
+
+    # ensure the file has an audio stream, using av
+    # XXX Trying to do this check on the initial UploadFile object is not
+    # possible, dunno why. UploadFile.file has no name.
+    # Trying to pass UploadFile.file with format=extension does not work,
+    # it never detects an audio stream...
+    container = av.open(upload_filename.as_posix())
+    try:
+        if not len(container.streams.audio):
+            raise HTTPException(status_code=400, detail="File has no audio stream")
+    except Exception:
+        # delete the uploaded file
+        upload_filename.unlink()
+        raise
+    finally:
+        container.close()
+
+    # set the status to "uploaded"
+    await transcripts_controller.update(transcript, {"status": "uploaded"})
+
+    # launch a background task to process the file
+    task_pipeline_upload.delay(transcript_id=transcript_id)
+
+    return UploadStatus(status="ok")
(Celery worker app module; file path not captured in this mirror view)
@@ -1,27 +1,32 @@
+import celery
 import structlog
 from celery import Celery
 from reflector.settings import settings
 
 logger = structlog.get_logger(__name__)
-app = Celery(__name__)
-app.conf.broker_url = settings.CELERY_BROKER_URL
-app.conf.result_backend = settings.CELERY_RESULT_BACKEND
-app.conf.broker_connection_retry_on_startup = True
-app.autodiscover_tasks(
+if celery.current_app is not None:
+    logger.info(f"Celery already configured ({celery.current_app})")
+    app = celery.current_app
+else:
+    app = Celery(__name__)
+    app.conf.broker_url = settings.CELERY_BROKER_URL
+    app.conf.result_backend = settings.CELERY_RESULT_BACKEND
+    app.conf.broker_connection_retry_on_startup = True
+    app.autodiscover_tasks(
     [
         "reflector.pipelines.main_live_pipeline",
         "reflector.worker.healthcheck",
     ]
 )
 
 # crontab
 app.conf.beat_schedule = {}
 
 if settings.HEALTHCHECK_URL:
     app.conf.beat_schedule["healthcheck_ping"] = {
         "task": "reflector.worker.healthcheck.healthcheck_ping",
         "schedule": 60.0 * 10,
     }
     logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
 else:
     logger.warning("Healthcheck disabled, no url configured")
server/tests/test_transcripts_upload.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+import pytest
+import asyncio
+from httpx import AsyncClient
+
+
+@pytest.fixture(scope="session")
+def celery_enable_logging():
+    return True
+
+
+@pytest.fixture(scope="session")
+def celery_config():
+    from tempfile import NamedTemporaryFile
+
+    with NamedTemporaryFile() as f:
+        yield {
+            "broker_url": "memory://",
+            "result_backend": f"db+sqlite:///{f.name}",
+        }
+
+
+@pytest.fixture(scope="session")
+def celery_includes():
+    return ["reflector.pipelines.main_live_pipeline"]
+
+
+@pytest.mark.usefixtures("setup_database")
+@pytest.mark.usefixtures("celery_app")
+@pytest.mark.usefixtures("celery_worker")
+@pytest.mark.asyncio
+async def test_transcript_upload_file(
+    tmpdir,
+    dummy_llm,
+    dummy_processors,
+    dummy_diarization,
+    dummy_storage,
+):
+    from reflector.app import app
+
+    ac = AsyncClient(app=app, base_url="http://test/v1")
+
+    # create a transcript
+    response = await ac.post("/transcripts", json={"name": "test"})
+    assert response.status_code == 200
+    assert response.json()["status"] == "idle"
+    tid = response.json()["id"]
+
+    # upload mp3
+    response = await ac.post(
+        f"/transcripts/{tid}/record/upload",
+        files={
+            "file": (
+                "test_mathieu_hello.mp3",
+                open("tests/records/test_mathieu_hello.mp3", "rb"),
+                "audio/mpeg",
+            )
+        },
+    )
+    assert response.status_code == 200
+    assert response.json()["status"] == "ok"
+
+    # wait for the processing to finish
+    while True:
+        # fetch the transcript and check if it is ended
+        resp = await ac.get(f"/transcripts/{tid}")
+        assert resp.status_code == 200
+        if resp.json()["status"] in ("ended", "error"):
+            break
+        await asyncio.sleep(1)
+
+    # check the transcript is ended
+    transcript = resp.json()
+    assert transcript["status"] == "ended"
+    assert transcript["short_summary"] == "LLM SHORT SUMMARY"
+    assert transcript["title"] == "LLM TITLE"
+
+    # check topics and transcript
+    response = await ac.get(f"/transcripts/{tid}/topics")
+    assert response.status_code == 200
+    assert len(response.json()) == 1
+    assert "want to share" in response.json()[0]["transcript"]
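
The polling loop in this test runs without an upper bound; if a deadline is wanted, a hedged variant of the same loop could look like this (helper name and timeout value are arbitrary, not part of the commit):

# Hedged variant of the test's polling loop with an overall deadline,
# so a stuck pipeline fails the test instead of hanging it.
import asyncio
import time


async def wait_for_status(ac, tid, timeout=120.0):
    deadline = time.monotonic() + timeout
    while True:
        resp = await ac.get(f"/transcripts/{tid}")
        assert resp.status_code == 200
        if resp.json()["status"] in ("ended", "error"):
            return resp.json()
        if time.monotonic() > deadline:
            raise TimeoutError("transcript processing did not finish in time")
        await asyncio.sleep(1)

# usage inside the test: transcript = await wait_for_status(ac, tid)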