Merge pull request #375 from Monadical-SAS/restart-processing

Restart processing
This commit is contained in:
2024-07-19 12:00:30 +02:00
committed by GitHub
9 changed files with 227 additions and 25 deletions

View File

@@ -17,6 +17,7 @@ from reflector.views.transcripts_audio import router as transcripts_audio_router
from reflector.views.transcripts_participants import (
router as transcripts_participants_router,
)
from reflector.views.transcripts_process import router as transcripts_process_router
from reflector.views.transcripts_speaker import router as transcripts_speaker_router
from reflector.views.transcripts_upload import router as transcripts_upload_router
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
@@ -74,6 +75,7 @@ app.include_router(transcripts_speaker_router, prefix="/v1")
app.include_router(transcripts_upload_router, prefix="/v1")
app.include_router(transcripts_websocket_router, prefix="/v1")
app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(transcripts_process_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")
add_pagination(app)

View File

@@ -637,13 +637,21 @@ def pipeline_post(*, transcript_id: str):
@get_transcript
async def pipeline_upload(transcript: Transcript, logger: Logger):
async def pipeline_process(transcript: Transcript, logger: Logger):
import av
try:
# open audio
upload_filename = next(transcript.data_path.glob("upload.*"))
container = av.open(upload_filename.as_posix())
audio_filename = next(transcript.data_path.glob("upload.*"), None)
if audio_filename and transcript.status != "uploaded":
raise Exception("File upload is not completed")
if not audio_filename:
audio_filename = next(transcript.data_path.glob("audio.*"), None)
if not audio_filename:
raise Exception("There is no file to process")
container = av.open(audio_filename.as_posix())
# create pipeline
pipeline = PipelineMainLive(transcript_id=transcript.id)
@@ -676,5 +684,5 @@ async def pipeline_upload(transcript: Transcript, logger: Logger):
@shared_task
@asynctask
async def task_pipeline_upload(*, transcript_id: str):
return await pipeline_upload(transcript_id=transcript_id)
async def task_pipeline_process(*, transcript_id: str):
return await pipeline_process(transcript_id=transcript_id)

View File

@@ -0,0 +1,55 @@
from typing import Annotated, Optional
import celery
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_live_pipeline import task_pipeline_process
router = APIRouter()
class ProcessStatus(BaseModel):
status: str
@router.post("/transcripts/{transcript_id}/process")
async def transcript_process(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
if transcript.status == "idle":
raise HTTPException(
status_code=400, detail="Recording is not ready for processing"
)
if task_is_scheduled_or_active(
"reflector.pipelines.main_live_pipeline.task_pipeline_process",
transcript_id=transcript_id,
):
return ProcessStatus(status="already running")
# schedule a background task process the file
task_pipeline_process.delay(transcript_id=transcript_id)
return ProcessStatus(status="ok")
def task_is_scheduled_or_active(task_name: str, **kwargs):
inspect = celery.current_app.control.inspect()
for worker, tasks in (inspect.scheduled() | inspect.active()).items():
for task in tasks:
if task["name"] == task_name and task["kwargs"] == kwargs:
return True
return False

View File

@@ -5,7 +5,7 @@ import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from pydantic import BaseModel
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_live_pipeline import task_pipeline_upload
from reflector.pipelines.main_live_pipeline import task_pipeline_process
router = APIRouter()
@@ -91,6 +91,6 @@ async def transcript_record_upload(
await transcripts_controller.update(transcript, {"status": "uploaded"})
# launch a background task to process the file
task_pipeline_upload.delay(transcript_id=transcript_id)
task_pipeline_process.delay(transcript_id=transcript_id)
return UploadStatus(status="ok")

View File

@@ -0,0 +1,78 @@
import asyncio
import pytest
from httpx import AsyncClient
@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_process(
tmpdir,
ensure_casing,
dummy_llm,
dummy_processors,
dummy_diarization,
dummy_storage,
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
# create a transcript
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["status"] == "idle"
tid = response.json()["id"]
# upload mp3
response = await ac.post(
f"/transcripts/{tid}/record/upload",
files={
"file": (
"test_short.wav",
open("tests/records/test_short.wav", "rb"),
"audio/mpeg",
),
},
)
assert response.status_code == 200
assert response.json()["status"] == "ok"
# wait for processing to finish
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
# restart the processing
response = await ac.post(
f"/transcripts/{tid}/process",
)
assert response.status_code == 200
assert response.json()["status"] == "ok"
# wait for processing to finish
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
# check the transcript is ended
transcript = resp.json()
assert transcript["status"] == "ended"
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
assert transcript["title"] == "LLM TITLE"
# check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics")
assert response.status_code == 200
assert len(response.json()) == 1
assert "want to share" in response.json()[0]["transcript"]

View File

@@ -4,7 +4,7 @@ import React, { useEffect, useState } from "react";
import { GetTranscript } from "../../api";
import Pagination from "./pagination";
import NextLink from "next/link";
import { FaGear } from "react-icons/fa6";
import { FaArrowRotateRight, FaGear } from "react-icons/fa6";
import { FaCheck, FaTrash, FaStar, FaMicrophone } from "react-icons/fa";
import { MdError } from "react-icons/md";
import useTranscriptList from "../transcripts/useTranscriptList";
@@ -20,20 +20,10 @@ import {
Card,
Link,
CardBody,
CardFooter,
Stack,
Text,
Icon,
Grid,
Divider,
Popover,
PopoverTrigger,
PopoverContent,
PopoverArrow,
PopoverCloseButton,
PopoverHeader,
PopoverBody,
PopoverFooter,
IconButton,
Spacer,
Menu,
@@ -46,7 +36,6 @@ import {
AlertDialogHeader,
AlertDialogBody,
AlertDialogFooter,
keyframes,
Tooltip,
} from "@chakra-ui/react";
import { PlusSquareIcon } from "@chakra-ui/icons";
@@ -93,12 +82,12 @@ export default function TranscriptBrowser() {
);
const onCloseDeletion = () => setTranscriptToDeleteId(undefined);
const handleDeleteTranscript = (transcriptToDeleteId) => (e) => {
const handleDeleteTranscript = (transcriptId) => (e) => {
e.stopPropagation();
if (api && !deletionLoading) {
setDeletionLoading(true);
api
.v1TranscriptDelete(transcriptToDeleteId)
.v1TranscriptDelete({ transcriptId })
.then(() => {
refetch();
setDeletionLoading(false);
@@ -106,7 +95,7 @@ export default function TranscriptBrowser() {
onCloseDeletion();
setDeletedItemIds((deletedItemIds) => [
deletedItemIds,
...transcriptToDeleteId,
...transcriptId,
]);
})
.catch((err) => {
@@ -116,6 +105,24 @@ export default function TranscriptBrowser() {
}
};
const handleProcessTranscript = (transcriptId) => (e) => {
if (api) {
api
.v1TranscriptProcess({ transcriptId })
.then((result) => {
const status = (result as any).status;
if (status === "already running") {
setError(
new Error("Processing is already running, please wait"),
"Processing is already running, please wait",
);
}
})
.catch((err) => {
setError(err, "There was an error processing the transcript");
});
}
};
return (
<Flex
maxW="container.xl"
@@ -221,7 +228,7 @@ export default function TranscriptBrowser() {
</Heading>
<Spacer />
<Menu closeOnSelect={false}>
<Menu closeOnSelect={true}>
<MenuButton
as={IconButton}
icon={<FaEllipsisVertical />}
@@ -229,12 +236,19 @@ export default function TranscriptBrowser() {
/>
<MenuList>
<MenuItem
disabled={deletionLoading}
isDisabled={deletionLoading}
onClick={() => setTranscriptToDeleteId(item.id)}
icon={<FaTrash color={"red.500"} />}
>
Delete
</MenuItem>
<MenuItem
isDisabled={item.status === "idle"}
onClick={handleProcessTranscript(item.id)}
icon={<FaArrowRotateRight />}
>
Process
</MenuItem>
<AlertDialog
isOpen={transcriptToDeleteId === item.id}
leastDestructiveRef={cancelRef}

View File

@@ -148,7 +148,7 @@ const TranscriptCreate = () => {
<Button
colorScheme="blue"
onClick={uploadFile}
isDisabled={!permissionOk || loadingRecord || loadingUpload}
isDisabled={loadingRecord || loadingUpload}
>
{loadingUpload ? "Loading..." : "Upload File"}
</Button>

View File

@@ -46,6 +46,8 @@ import type {
V1TranscriptGetWebsocketEventsResponse,
V1TranscriptRecordWebrtcData,
V1TranscriptRecordWebrtcResponse,
V1TranscriptProcessData,
V1TranscriptProcessResponse,
V1UserMeResponse,
} from "./types.gen";
@@ -571,6 +573,28 @@ export class DefaultService {
});
}
/**
* Transcript Process
* @param data The data for the request.
* @param data.transcriptId
* @returns unknown Successful Response
* @throws ApiError
*/
public v1TranscriptProcess(
data: V1TranscriptProcessData,
): CancelablePromise<V1TranscriptProcessResponse> {
return this.httpRequest.request({
method: "POST",
url: "/v1/transcripts/{transcript_id}/process",
path: {
transcript_id: data.transcriptId,
},
errors: {
422: "Validation Error",
},
});
}
/**
* User Me
* @returns unknown Successful Response

View File

@@ -317,6 +317,12 @@ export type V1TranscriptRecordWebrtcData = {
export type V1TranscriptRecordWebrtcResponse = unknown;
export type V1TranscriptProcessData = {
transcriptId: string;
};
export type V1TranscriptProcessResponse = unknown;
export type V1UserMeResponse = UserInfo | null;
export type $OpenApiTs = {
@@ -631,6 +637,21 @@ export type $OpenApiTs = {
};
};
};
"/v1/transcripts/{transcript_id}/process": {
post: {
req: V1TranscriptProcessData;
res: {
/**
* Successful Response
*/
200: unknown;
/**
* Validation Error
*/
422: HTTPValidationError;
};
};
};
"/v1/me": {
get: {
res: {