diff --git a/server/reflector/app.py b/server/reflector/app.py index 1be71210..c952ebae 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -17,6 +17,7 @@ from reflector.views.transcripts_audio import router as transcripts_audio_router from reflector.views.transcripts_participants import ( router as transcripts_participants_router, ) +from reflector.views.transcripts_process import router as transcripts_process_router from reflector.views.transcripts_speaker import router as transcripts_speaker_router from reflector.views.transcripts_upload import router as transcripts_upload_router from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router @@ -74,6 +75,7 @@ app.include_router(transcripts_speaker_router, prefix="/v1") app.include_router(transcripts_upload_router, prefix="/v1") app.include_router(transcripts_websocket_router, prefix="/v1") app.include_router(transcripts_webrtc_router, prefix="/v1") +app.include_router(transcripts_process_router, prefix="/v1") app.include_router(user_router, prefix="/v1") add_pagination(app) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 2b87f23a..b1e0a2aa 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -637,13 +637,21 @@ def pipeline_post(*, transcript_id: str): @get_transcript -async def pipeline_upload(transcript: Transcript, logger: Logger): +async def pipeline_process(transcript: Transcript, logger: Logger): import av try: # open audio - upload_filename = next(transcript.data_path.glob("upload.*")) - container = av.open(upload_filename.as_posix()) + audio_filename = next(transcript.data_path.glob("upload.*"), None) + if audio_filename and transcript.status != "uploaded": + raise Exception("File upload is not completed") + + if not audio_filename: + audio_filename = next(transcript.data_path.glob("audio.*"), None) + if not audio_filename: + raise 
@shared_task
@asynctask
async def task_pipeline_process(*, transcript_id: str):
    """Await ``pipeline_process`` for one transcript and return its result.

    Args:
        transcript_id: Identifier forwarded, by keyword, to
            ``pipeline_process``.

    Returns:
        Whatever ``pipeline_process`` returns.
    """
    outcome = await pipeline_process(transcript_id=transcript_id)
    return outcome
def task_is_scheduled_or_active(task_name: str, **kwargs):
    """Tell whether a matching task is scheduled or active on any worker.

    A task matches when its name equals ``task_name`` and its keyword
    arguments equal ``kwargs`` exactly.

    Args:
        task_name: Fully qualified name of the task to look for.
        **kwargs: Keyword arguments the task must have been called with.

    Returns:
        True if a matching entry is found in the workers' scheduled or
        active replies, False otherwise (including when no worker replies).
    """
    inspect = celery.current_app.control.inspect()

    # Each inspect call yields None when no worker answers, so fall back to
    # an empty mapping instead of crashing. The two replies are scanned
    # separately rather than merged with ``|``: both are keyed by worker
    # name, so a union would drop a worker's scheduled list whenever that
    # same worker also reports active tasks.
    # NOTE(review): tasks that are prefetched but not yet started are
    # reported by ``inspect.reserved()`` and are not covered here -- confirm
    # whether they should count as "already running".
    for reply in (inspect.scheduled(), inspect.active()):
        for tasks in (reply or {}).values():
            for task in tasks:
                # Scheduled (ETA) entries wrap the task description in a
                # "request" key; active entries are the description itself.
                request = task.get("request", task)
                if (
                    request.get("name") == task_name
                    and request.get("kwargs") == kwargs
                ):
                    return True

    return False
ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["status"] == "idle" + tid = response.json()["id"] + + # upload mp3 + response = await ac.post( + f"/transcripts/{tid}/record/upload", + files={ + "file": ( + "test_short.wav", + open("tests/records/test_short.wav", "rb"), + "audio/mpeg", + ), + }, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ok" + + # wait for processing to finish + while True: + # fetch the transcript and check if it is ended + resp = await ac.get(f"/transcripts/{tid}") + assert resp.status_code == 200 + if resp.json()["status"] in ("ended", "error"): + break + await asyncio.sleep(1) + + # restart the processing + response = await ac.post( + f"/transcripts/{tid}/process", + ) + assert response.status_code == 200 + assert response.json()["status"] == "ok" + + # wait for processing to finish + while True: + # fetch the transcript and check if it is ended + resp = await ac.get(f"/transcripts/{tid}") + assert resp.status_code == 200 + if resp.json()["status"] in ("ended", "error"): + break + await asyncio.sleep(1) + + # check the transcript is ended + transcript = resp.json() + assert transcript["status"] == "ended" + assert transcript["short_summary"] == "LLM SHORT SUMMARY" + assert transcript["title"] == "LLM TITLE" + + # check topics and transcript + response = await ac.get(f"/transcripts/{tid}/topics") + assert response.status_code == 200 + assert len(response.json()) == 1 + assert "want to share" in response.json()[0]["transcript"] diff --git a/www/app/[domain]/browse/page.tsx b/www/app/[domain]/browse/page.tsx index 7cbb3d19..3c2529fe 100644 --- a/www/app/[domain]/browse/page.tsx +++ b/www/app/[domain]/browse/page.tsx @@ -4,7 +4,7 @@ import React, { useEffect, useState } from "react"; import { GetTranscript } from "../../api"; import Pagination from "./pagination"; import NextLink from "next/link"; -import { FaGear } from "react-icons/fa6"; +import { 
// Returns an event handler that asks the API to process `transcriptId`.
// Does nothing when `api` is not available. Reports through `setError` when
// the reply status is "already running" or when the request rejects.
const handleProcessTranscript = (transcriptId) => (e) => {
  if (!api) return;

  const alreadyRunningMessage = "Processing is already running, please wait";

  const checkStatus = (result) => {
    // The generated client types the response as unknown, hence the cast.
    if ((result as any).status !== "already running") return;
    setError(new Error(alreadyRunningMessage), alreadyRunningMessage);
  };

  const reportFailure = (err) => {
    setError(err, "There was an error processing the transcript");
  };

  api
    .v1TranscriptProcess({ transcriptId })
    .then(checkStatus)
    .catch(reportFailure);
};
setTranscriptToDeleteId(item.id)} icon={} > Delete + } + > + Process + { diff --git a/www/app/api/services.gen.ts b/www/app/api/services.gen.ts index e5a18eb8..c1690b81 100644 --- a/www/app/api/services.gen.ts +++ b/www/app/api/services.gen.ts @@ -46,6 +46,8 @@ import type { V1TranscriptGetWebsocketEventsResponse, V1TranscriptRecordWebrtcData, V1TranscriptRecordWebrtcResponse, + V1TranscriptProcessData, + V1TranscriptProcessResponse, V1UserMeResponse, } from "./types.gen"; @@ -571,6 +573,28 @@ export class DefaultService { }); } + /** + * Transcript Process + * @param data The data for the request. + * @param data.transcriptId + * @returns unknown Successful Response + * @throws ApiError + */ + public v1TranscriptProcess( + data: V1TranscriptProcessData, + ): CancelablePromise { + return this.httpRequest.request({ + method: "POST", + url: "/v1/transcripts/{transcript_id}/process", + path: { + transcript_id: data.transcriptId, + }, + errors: { + 422: "Validation Error", + }, + }); + } + /** * User Me * @returns unknown Successful Response diff --git a/www/app/api/types.gen.ts b/www/app/api/types.gen.ts index bc06d3c1..bb811560 100644 --- a/www/app/api/types.gen.ts +++ b/www/app/api/types.gen.ts @@ -317,6 +317,12 @@ export type V1TranscriptRecordWebrtcData = { export type V1TranscriptRecordWebrtcResponse = unknown; +export type V1TranscriptProcessData = { + transcriptId: string; +}; + +export type V1TranscriptProcessResponse = unknown; + export type V1UserMeResponse = UserInfo | null; export type $OpenApiTs = { @@ -631,6 +637,21 @@ export type $OpenApiTs = { }; }; }; + "/v1/transcripts/{transcript_id}/process": { + post: { + req: V1TranscriptProcessData; + res: { + /** + * Successful Response + */ + 200: unknown; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; "/v1/me": { get: { res: {