mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
Merge branch 'main' into feat-api-speaker-reassignment
This commit is contained in:
2
server/.gitignore
vendored
2
server/.gitignore
vendored
@@ -178,3 +178,5 @@ audio_*.wav
|
||||
# ignore local database
|
||||
reflector.sqlite3
|
||||
data/
|
||||
|
||||
dump.rdb
|
||||
|
||||
18
server/poetry.lock
generated
18
server/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aioboto3"
|
||||
@@ -2881,6 +2881,20 @@ cryptography = ["cryptography (>=3.4.0)"]
|
||||
pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"]
|
||||
pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "python-multipart"
|
||||
version = "0.0.6"
|
||||
description = "A streaming multipart parser for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "python_multipart-0.0.6-py3-none-any.whl", hash = "sha256:ee698bab5ef148b0a760751c261902cd096e57e10558e11aca17646b74ee1c18"},
|
||||
{file = "python_multipart-0.0.6.tar.gz", hash = "sha256:e9925a80bb668529f1b67c7fdb0a5dacdd7cbfc6fb0bff3ea443fe22bdd62132"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pyyaml (==5.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyyaml"
|
||||
version = "6.0.1"
|
||||
@@ -4219,4 +4233,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "91d85539f5093abad70e34aa4d533272d6a2e2bbdb539c7968fe79c28b50d01a"
|
||||
content-hash = "b823010302af2dcd2ece591eaf10d5cbf945f74bd0fc35fc69f3060c0c253d57"
|
||||
|
||||
@@ -36,6 +36,7 @@ profanityfilter = "^2.0.6"
|
||||
celery = "^5.3.4"
|
||||
redis = "^5.0.1"
|
||||
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
|
||||
python-multipart = "^0.0.6"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
@@ -18,6 +18,7 @@ from reflector.views.transcripts_participants import (
|
||||
router as transcripts_participants_router,
|
||||
)
|
||||
from reflector.views.transcripts_speaker import router as transcripts_speaker_router
|
||||
from reflector.views.transcripts_upload import router as transcripts_upload_router
|
||||
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
|
||||
from reflector.views.transcripts_websocket import router as transcripts_websocket_router
|
||||
from reflector.views.user import router as user_router
|
||||
@@ -70,6 +71,7 @@ app.include_router(transcripts_router, prefix="/v1")
|
||||
app.include_router(transcripts_audio_router, prefix="/v1")
|
||||
app.include_router(transcripts_participants_router, prefix="/v1")
|
||||
app.include_router(transcripts_speaker_router, prefix="/v1")
|
||||
app.include_router(transcripts_upload_router, prefix="/v1")
|
||||
app.include_router(transcripts_websocket_router, prefix="/v1")
|
||||
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
||||
app.include_router(user_router, prefix="/v1")
|
||||
|
||||
@@ -474,6 +474,15 @@ class PipelineMainWaveform(PipelineMainFromTopics):
|
||||
]
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting remove upload")
|
||||
uploads = transcript.data_path.glob("upload.*")
|
||||
for upload in uploads:
|
||||
upload.unlink(missing_ok=True)
|
||||
logger.info("Remove upload done")
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_waveform(transcript: Transcript, logger: Logger):
|
||||
logger.info("Starting waveform")
|
||||
@@ -560,6 +569,12 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
|
||||
# ===================================================================
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_remove_upload(*, transcript_id: str):
|
||||
await pipeline_remove_upload(transcript_id=transcript_id)
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_waveform(*, transcript_id: str):
|
||||
@@ -604,6 +619,7 @@ def pipeline_post(*, transcript_id: str):
|
||||
task_pipeline_waveform.si(transcript_id=transcript_id)
|
||||
| task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
|
||||
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
|
||||
| task_pipeline_remove_upload.si(transcript_id=transcript_id)
|
||||
| task_pipeline_diarization.si(transcript_id=transcript_id)
|
||||
)
|
||||
chain_title_preview = task_pipeline_title_and_short_summary.si(
|
||||
@@ -618,3 +634,47 @@ def pipeline_post(*, transcript_id: str):
|
||||
chain_final_summaries,
|
||||
)
|
||||
chain.delay()
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_upload(transcript: Transcript, logger: Logger):
|
||||
import av
|
||||
|
||||
try:
|
||||
# open audio
|
||||
upload_filename = next(transcript.data_path.glob("upload.*"))
|
||||
container = av.open(upload_filename.as_posix())
|
||||
|
||||
# create pipeline
|
||||
pipeline = PipelineMainLive(transcript_id=transcript.id)
|
||||
pipeline.start()
|
||||
|
||||
# push audio to pipeline
|
||||
try:
|
||||
logger.info("Start pushing audio into the pipeline")
|
||||
for frame in container.decode(audio=0):
|
||||
pipeline.push(frame)
|
||||
finally:
|
||||
logger.info("Flushing the pipeline")
|
||||
pipeline.flush()
|
||||
|
||||
logger.info("Waiting for the pipeline to end")
|
||||
await pipeline.join()
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Pipeline error", exc_info=exc)
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"status": "error",
|
||||
},
|
||||
)
|
||||
raise
|
||||
|
||||
logger.info("Pipeline ended")
|
||||
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_upload(*, transcript_id: str):
|
||||
return await pipeline_upload(transcript_id=transcript_id)
|
||||
|
||||
@@ -30,7 +30,8 @@ class PipelineRunner(BaseModel):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._q_cmd = asyncio.Queue()
|
||||
self._task = None
|
||||
self._q_cmd = asyncio.Queue(maxsize=4096)
|
||||
self._ev_done = asyncio.Event()
|
||||
self._is_first_push = True
|
||||
self._logger = logger.bind(
|
||||
@@ -49,7 +50,14 @@ class PipelineRunner(BaseModel):
|
||||
"""
|
||||
Start the pipeline as a coroutine task
|
||||
"""
|
||||
asyncio.get_event_loop().create_task(self.run())
|
||||
self._task = asyncio.get_event_loop().create_task(self.run())
|
||||
|
||||
async def join(self):
|
||||
"""
|
||||
Wait for the pipeline to finish
|
||||
"""
|
||||
if self._task:
|
||||
await self._task
|
||||
|
||||
def start_sync(self):
|
||||
"""
|
||||
|
||||
79
server/reflector/views/transcripts_upload.py
Normal file
79
server/reflector/views/transcripts_upload.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import av
|
||||
import reflector.auth as auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.pipelines.main_live_pipeline import task_pipeline_upload
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class UploadStatus(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
@router.post("/transcripts/{transcript_id}/record/upload")
|
||||
async def transcript_record_upload(
|
||||
transcript_id: str,
|
||||
file: UploadFile,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.locked:
|
||||
raise HTTPException(status_code=400, detail="Transcript is locked")
|
||||
|
||||
# ensure there is no other upload in the directory (searching data_path/upload.*)
|
||||
if any(transcript.data_path.glob("upload.*")):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="There is already an upload in progress"
|
||||
)
|
||||
|
||||
# save the file to the transcript folder
|
||||
extension = file.filename.split(".")[-1]
|
||||
upload_filename = transcript.data_path / f"upload.{extension}"
|
||||
upload_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ensure the file is back to the beginning
|
||||
await file.seek(0)
|
||||
|
||||
# save the file to the transcript folder
|
||||
try:
|
||||
with open(upload_filename, "wb") as f:
|
||||
while True:
|
||||
chunk = await file.read(16384)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
except Exception:
|
||||
upload_filename.unlink()
|
||||
raise
|
||||
|
||||
# ensure the file have audio part, using av
|
||||
# XXX Trying to do this check on the initial UploadFile object is not
|
||||
# possible, dunno why. UploadFile.file has no name.
|
||||
# Trying to pass UploadFile.file with format=extension does not work
|
||||
# it never detect audio stream...
|
||||
container = av.open(upload_filename.as_posix())
|
||||
try:
|
||||
if not len(container.streams.audio):
|
||||
raise HTTPException(status_code=400, detail="File has no audio stream")
|
||||
except Exception:
|
||||
# delete the uploaded file
|
||||
upload_filename.unlink()
|
||||
raise
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
# set the status to "uploaded"
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
|
||||
# launch a background task to process the file
|
||||
task_pipeline_upload.delay(transcript_id=transcript_id)
|
||||
|
||||
return UploadStatus(status="ok")
|
||||
@@ -1,27 +1,32 @@
|
||||
import celery
|
||||
import structlog
|
||||
from celery import Celery
|
||||
from reflector.settings import settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
app = Celery(__name__)
|
||||
app.conf.broker_url = settings.CELERY_BROKER_URL
|
||||
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
||||
app.conf.broker_connection_retry_on_startup = True
|
||||
app.autodiscover_tasks(
|
||||
[
|
||||
"reflector.pipelines.main_live_pipeline",
|
||||
"reflector.worker.healthcheck",
|
||||
]
|
||||
)
|
||||
|
||||
# crontab
|
||||
app.conf.beat_schedule = {}
|
||||
|
||||
if settings.HEALTHCHECK_URL:
|
||||
app.conf.beat_schedule["healthcheck_ping"] = {
|
||||
"task": "reflector.worker.healthcheck.healthcheck_ping",
|
||||
"schedule": 60.0 * 10,
|
||||
}
|
||||
logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
|
||||
if celery.current_app.main != "default":
|
||||
logger.info(f"Celery already configured ({celery.current_app})")
|
||||
app = celery.current_app
|
||||
else:
|
||||
logger.warning("Healthcheck disabled, no url configured")
|
||||
app = Celery(__name__)
|
||||
app.conf.broker_url = settings.CELERY_BROKER_URL
|
||||
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
||||
app.conf.broker_connection_retry_on_startup = True
|
||||
app.autodiscover_tasks(
|
||||
[
|
||||
"reflector.pipelines.main_live_pipeline",
|
||||
"reflector.worker.healthcheck",
|
||||
]
|
||||
)
|
||||
|
||||
# crontab
|
||||
app.conf.beat_schedule = {}
|
||||
|
||||
if settings.HEALTHCHECK_URL:
|
||||
app.conf.beat_schedule["healthcheck_ping"] = {
|
||||
"task": "reflector.worker.healthcheck.healthcheck_ping",
|
||||
"schedule": 60.0 * 10,
|
||||
}
|
||||
logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
|
||||
else:
|
||||
logger.warning("Healthcheck disabled, no url configured")
|
||||
|
||||
@@ -165,6 +165,11 @@ def celery_config():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def celery_includes():
|
||||
return ["reflector.pipelines.main_live_pipeline"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fake_mp3_upload():
|
||||
with patch(
|
||||
|
||||
@@ -32,7 +32,7 @@ class ThreadedUvicorn:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def appserver(tmpdir, celery_session_app, celery_session_worker):
|
||||
async def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
|
||||
from reflector.settings import settings
|
||||
from reflector.app import app
|
||||
|
||||
@@ -57,6 +57,7 @@ def celery_includes():
|
||||
return ["reflector.pipelines.main_live_pipeline"]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
@pytest.mark.usefixtures("celery_session_worker")
|
||||
@pytest.mark.asyncio
|
||||
@@ -213,6 +214,7 @@ async def test_transcript_rtc_and_websocket(
|
||||
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
@pytest.mark.usefixtures("celery_session_worker")
|
||||
@pytest.mark.asyncio
|
||||
|
||||
61
server/tests/test_transcripts_upload.py
Normal file
61
server/tests/test_transcripts_upload.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
@pytest.mark.usefixtures("celery_session_worker")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_upload_file(
|
||||
tmpdir,
|
||||
ensure_casing,
|
||||
dummy_llm,
|
||||
dummy_processors,
|
||||
dummy_diarization,
|
||||
dummy_storage,
|
||||
):
|
||||
from reflector.app import app
|
||||
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
|
||||
# create a transcript
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "idle"
|
||||
tid = response.json()["id"]
|
||||
|
||||
# upload mp3
|
||||
response = await ac.post(
|
||||
f"/transcripts/{tid}/record/upload",
|
||||
files={
|
||||
"file": (
|
||||
"test_short.wav",
|
||||
open("tests/records/test_short.wav", "rb"),
|
||||
"audio/mpeg",
|
||||
)
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
# wait the processing to finish
|
||||
while True:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# check the transcript is ended
|
||||
transcript = resp.json()
|
||||
assert transcript["status"] == "ended"
|
||||
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
|
||||
assert transcript["title"] == "LLM TITLE"
|
||||
|
||||
# check topics and transcript
|
||||
response = await ac.get(f"/transcripts/{tid}/topics")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
assert "want to share" in response.json()[0]["transcript"]
|
||||
Reference in New Issue
Block a user