mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
Merge pull request #375 from Monadical-SAS/restart-processing
Restart processing
This commit is contained in:
@@ -17,6 +17,7 @@ from reflector.views.transcripts_audio import router as transcripts_audio_router
|
||||
from reflector.views.transcripts_participants import (
|
||||
router as transcripts_participants_router,
|
||||
)
|
||||
from reflector.views.transcripts_process import router as transcripts_process_router
|
||||
from reflector.views.transcripts_speaker import router as transcripts_speaker_router
|
||||
from reflector.views.transcripts_upload import router as transcripts_upload_router
|
||||
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
|
||||
@@ -74,6 +75,7 @@ app.include_router(transcripts_speaker_router, prefix="/v1")
|
||||
app.include_router(transcripts_upload_router, prefix="/v1")
|
||||
app.include_router(transcripts_websocket_router, prefix="/v1")
|
||||
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
||||
app.include_router(transcripts_process_router, prefix="/v1")
|
||||
app.include_router(user_router, prefix="/v1")
|
||||
add_pagination(app)
|
||||
|
||||
|
||||
@@ -637,13 +637,21 @@ def pipeline_post(*, transcript_id: str):
|
||||
|
||||
|
||||
@get_transcript
|
||||
async def pipeline_upload(transcript: Transcript, logger: Logger):
|
||||
async def pipeline_process(transcript: Transcript, logger: Logger):
|
||||
import av
|
||||
|
||||
try:
|
||||
# open audio
|
||||
upload_filename = next(transcript.data_path.glob("upload.*"))
|
||||
container = av.open(upload_filename.as_posix())
|
||||
audio_filename = next(transcript.data_path.glob("upload.*"), None)
|
||||
if audio_filename and transcript.status != "uploaded":
|
||||
raise Exception("File upload is not completed")
|
||||
|
||||
if not audio_filename:
|
||||
audio_filename = next(transcript.data_path.glob("audio.*"), None)
|
||||
if not audio_filename:
|
||||
raise Exception("There is no file to process")
|
||||
|
||||
container = av.open(audio_filename.as_posix())
|
||||
|
||||
# create pipeline
|
||||
pipeline = PipelineMainLive(transcript_id=transcript.id)
|
||||
@@ -676,5 +684,5 @@ async def pipeline_upload(transcript: Transcript, logger: Logger):
|
||||
|
||||
@shared_task
|
||||
@asynctask
|
||||
async def task_pipeline_upload(*, transcript_id: str):
|
||||
return await pipeline_upload(transcript_id=transcript_id)
|
||||
async def task_pipeline_process(*, transcript_id: str):
|
||||
return await pipeline_process(transcript_id=transcript_id)
|
||||
|
||||
55
server/reflector/views/transcripts_process.py
Normal file
55
server/reflector/views/transcripts_process.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import celery
|
||||
import reflector.auth as auth
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.pipelines.main_live_pipeline import task_pipeline_process
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ProcessStatus(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
@router.post("/transcripts/{transcript_id}/process")
|
||||
async def transcript_process(
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.locked:
|
||||
raise HTTPException(status_code=400, detail="Transcript is locked")
|
||||
|
||||
if transcript.status == "idle":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Recording is not ready for processing"
|
||||
)
|
||||
|
||||
if task_is_scheduled_or_active(
|
||||
"reflector.pipelines.main_live_pipeline.task_pipeline_process",
|
||||
transcript_id=transcript_id,
|
||||
):
|
||||
return ProcessStatus(status="already running")
|
||||
|
||||
# schedule a background task process the file
|
||||
task_pipeline_process.delay(transcript_id=transcript_id)
|
||||
|
||||
return ProcessStatus(status="ok")
|
||||
|
||||
|
||||
def task_is_scheduled_or_active(task_name: str, **kwargs):
|
||||
inspect = celery.current_app.control.inspect()
|
||||
|
||||
for worker, tasks in (inspect.scheduled() | inspect.active()).items():
|
||||
for task in tasks:
|
||||
if task["name"] == task_name and task["kwargs"] == kwargs:
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -5,7 +5,7 @@ import reflector.auth as auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from reflector.db.transcripts import transcripts_controller
|
||||
from reflector.pipelines.main_live_pipeline import task_pipeline_upload
|
||||
from reflector.pipelines.main_live_pipeline import task_pipeline_process
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -91,6 +91,6 @@ async def transcript_record_upload(
|
||||
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||
|
||||
# launch a background task to process the file
|
||||
task_pipeline_upload.delay(transcript_id=transcript_id)
|
||||
task_pipeline_process.delay(transcript_id=transcript_id)
|
||||
|
||||
return UploadStatus(status="ok")
|
||||
|
||||
78
server/tests/test_transcripts_process.py
Normal file
78
server/tests/test_transcripts_process.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_database")
|
||||
@pytest.mark.usefixtures("celery_session_app")
|
||||
@pytest.mark.usefixtures("celery_session_worker")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_process(
|
||||
tmpdir,
|
||||
ensure_casing,
|
||||
dummy_llm,
|
||||
dummy_processors,
|
||||
dummy_diarization,
|
||||
dummy_storage,
|
||||
):
|
||||
from reflector.app import app
|
||||
|
||||
ac = AsyncClient(app=app, base_url="http://test/v1")
|
||||
|
||||
# create a transcript
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "idle"
|
||||
tid = response.json()["id"]
|
||||
|
||||
# upload mp3
|
||||
response = await ac.post(
|
||||
f"/transcripts/{tid}/record/upload",
|
||||
files={
|
||||
"file": (
|
||||
"test_short.wav",
|
||||
open("tests/records/test_short.wav", "rb"),
|
||||
"audio/mpeg",
|
||||
),
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
# wait for processing to finish
|
||||
while True:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# restart the processing
|
||||
response = await ac.post(
|
||||
f"/transcripts/{tid}/process",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
# wait for processing to finish
|
||||
while True:
|
||||
# fetch the transcript and check if it is ended
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
if resp.json()["status"] in ("ended", "error"):
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# check the transcript is ended
|
||||
transcript = resp.json()
|
||||
assert transcript["status"] == "ended"
|
||||
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
|
||||
assert transcript["title"] == "LLM TITLE"
|
||||
|
||||
# check topics and transcript
|
||||
response = await ac.get(f"/transcripts/{tid}/topics")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
assert "want to share" in response.json()[0]["transcript"]
|
||||
Reference in New Issue
Block a user