From 226b92c3474e05787643a77f2efde7abe132a346 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 7 Nov 2023 12:39:48 +0100 Subject: [PATCH 01/27] www/server: introduce share mode --- .../versions/0fea6d96b096_add_share_mode.py | 30 +++++++ server/reflector/db/transcripts.py | 55 +++++++++++- server/reflector/views/transcripts.py | 56 ++++++++----- .../transcripts/[transcriptId]/page.tsx | 7 +- www/app/[domain]/transcripts/shareLink.tsx | 84 ++++++++++++++++--- www/app/api/models/GetTranscript.ts | 17 ++++ www/app/api/models/UpdateTranscript.ts | 8 ++ www/app/styles/form.scss | 5 ++ 8 files changed, 228 insertions(+), 34 deletions(-) create mode 100644 server/migrations/versions/0fea6d96b096_add_share_mode.py diff --git a/server/migrations/versions/0fea6d96b096_add_share_mode.py b/server/migrations/versions/0fea6d96b096_add_share_mode.py new file mode 100644 index 00000000..52a72d48 --- /dev/null +++ b/server/migrations/versions/0fea6d96b096_add_share_mode.py @@ -0,0 +1,30 @@ +"""add share_mode + +Revision ID: 0fea6d96b096 +Revises: 38a927dcb099 +Create Date: 2023-11-07 11:12:21.614198 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '0fea6d96b096' +down_revision: Union[str, None] = '38a927dcb099' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('transcript', sa.Column('share_mode', sa.String(), server_default='private', nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('transcript', 'share_mode') + # ### end Alembic commands ### diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 6ac2e32a..4b91423a 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -2,10 +2,11 @@ import json from contextlib import asynccontextmanager from datetime import datetime from pathlib import Path -from typing import Any +from typing import Any, Literal from uuid import uuid4 import sqlalchemy +from fastapi import HTTPException from pydantic import BaseModel, Field from reflector.db import database, metadata from reflector.processors.types import Word as ProcessorWord @@ -30,6 +31,12 @@ transcripts = sqlalchemy.Table( sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), # with user attached, optional sqlalchemy.Column("user_id", sqlalchemy.String), + sqlalchemy.Column( + "share_mode", + sqlalchemy.String, + nullable=False, + server_default="private", + ), ) @@ -99,6 +106,7 @@ class Transcript(BaseModel): events: list[TranscriptEvent] = [] source_language: str = "en" target_language: str = "en" + share_mode: Literal["private", "semi-private", "public"] = "private" def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: ev = TranscriptEvent(event=event, data=data.model_dump()) @@ -169,6 +177,7 @@ class TranscriptController: order_by: str | None = None, filter_empty: bool | None = False, filter_recording: bool | None = False, + return_query: bool = False, ) -> list[Transcript]: """ Get all transcripts @@ -195,6 +204,9 @@ class TranscriptController: if filter_recording: query = query.filter(transcripts.c.status != "recording") + if return_query: + return query + results = await database.fetch_all(query) return results @@ -210,6 +222,47 @@ class TranscriptController: return None return Transcript(**result) + async def get_by_id_for_http( + self, + transcript_id: str, + user_id: str | None, + ) -> Transcript: + """ + Get a transcript by ID for HTTP request. + + If not found, it will raise a 404 error. + If the user is not allowed to access the transcript, it will raise a 403 error. + + This method checks the share mode of the transcript and the user_id + to determine if the user can access the transcript. + """ + query = transcripts.select().where(transcripts.c.id == transcript_id) + result = await database.fetch_one(query) + if not result: + raise HTTPException(status_code=404, detail="Transcript not found") + + # if the transcript is anonymous, share mode is not checked + transcript = Transcript(**result) + if transcript.user_id is None: + return transcript + + if transcript.share_mode == "private": + # in private mode, only the owner can access the transcript + if transcript.user_id == user_id: + return transcript + + elif transcript.share_mode == "semi-private": + # in semi-private mode, only the owner and the users with the link + # can access the transcript + if user_id is not None: + return transcript + + elif transcript.share_mode == "public": + # in public mode, everyone can access the transcript + return transcript + + raise HTTPException(status_code=403, detail="Transcript access denied") + async def add( self, name: str, diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index e3668ecb..5f1d7831 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Annotated, Optional +from typing import Annotated, Literal, Optional import reflector.auth as auth from fastapi import ( @@ -11,7 +11,8 @@ from fastapi import ( WebSocketDisconnect, status, ) -from fastapi_pagination import Page, paginate +from fastapi_pagination import Page +from fastapi_pagination.ext.databases import paginate from jose import jwt from pydantic import BaseModel, Field from reflector.db.transcripts import ( @@ -48,6 +49,7 @@ def create_access_token(data: dict, expires_delta: timedelta): class GetTranscript(BaseModel): id: str + user_id: str | None name: str status: str locked: bool @@ -56,6 +58,7 @@ class GetTranscript(BaseModel): short_summary: str | None long_summary: str | None created_at: datetime + share_mode: str = Field("private") source_language: str | None target_language: str | None @@ -72,6 +75,7 @@ class UpdateTranscript(BaseModel): title: Optional[str] = Field(None) short_summary: Optional[str] = Field(None) long_summary: Optional[str] = Field(None) + share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None) class DeletionStatus(BaseModel): @@ -82,12 +86,19 @@ class DeletionStatus(BaseModel): async def transcripts_list( user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): + from reflector.db import database + if not user and not settings.PUBLIC_MODE: raise HTTPException(status_code=401, detail="Not authenticated") user_id = user["sub"] if user else None - return paginate( - await transcripts_controller.get_all(user_id=user_id, order_by="-created_at") + return await paginate( + database, + await transcripts_controller.get_all( + user_id=user_id, + order_by="-created_at", + return_query=True, + ), ) @@ -165,10 +176,9 @@ async def transcript_get( user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") - return transcript + return await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) @router.patch("/transcripts/{transcript_id}", response_model=GetTranscript) @@ -192,6 +202,8 @@ async def transcript_update( values["short_summary"] = info.short_summary if info.title is not None: values["title"] = info.title + if info.share_mode is not None: + values["share_mode"] = info.share_mode await transcripts_controller.update(transcript, values) return transcript @@ -229,12 +241,12 @@ async def transcript_get_audio_mp3( except jwt.JWTError: raise unauthorized_exception - transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) if not transcript.audio_mp3_filename.exists(): - raise HTTPException(status_code=404, detail="Audio not found") + raise HTTPException(status_code=500, detail="Audio not found") truncated_id = str(transcript.id).split("-")[0] filename = f"recording_{truncated_id}.mp3" @@ -253,12 +265,12 @@ async def transcript_get_audio_waveform( user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ) -> AudioWaveform: user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) if not transcript.audio_mp3_filename.exists(): - raise HTTPException(status_code=404, detail="Audio not found") + raise HTTPException(status_code=500, detail="Audio not found") await run_in_threadpool(transcript.convert_audio_to_waveform) @@ -274,9 +286,9 @@ async def transcript_get_topics( user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) # convert to GetTranscriptTopic return [ @@ -345,9 +357,9 @@ async def transcript_record_webrtc( user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) if transcript.locked: raise HTTPException(status_code=400, detail="Transcript is locked") diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index 9f9348c8..7b57ff2a 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -99,7 +99,12 @@ export default function TranscriptDetails(details: TranscriptDetails) { />
- +
diff --git a/www/app/[domain]/transcripts/shareLink.tsx b/www/app/[domain]/transcripts/shareLink.tsx index 49163a5b..44c57053 100644 --- a/www/app/[domain]/transcripts/shareLink.tsx +++ b/www/app/[domain]/transcripts/shareLink.tsx @@ -1,15 +1,37 @@ import React, { useState, useRef, useEffect, use } from "react"; import { featureEnabled } from "../domainContext"; +import getApi from "../../lib/getApi"; +import { useFiefUserinfo } from "@fief/fief/nextjs/react"; +import SelectSearch from "react-select-search"; +import "react-select-search/style.css"; +import "../../styles/button.css"; +import "../../styles/form.scss"; -const ShareLink = () => { +type ShareLinkProps = { + protectedPath: boolean; + transcriptId: string; + userId: string | null; + shareMode: string; +}; + +const ShareLink = (props: ShareLinkProps) => { const [isCopied, setIsCopied] = useState(false); const inputRef = useRef(null); const [currentUrl, setCurrentUrl] = useState(""); + const requireLogin = featureEnabled("requireLogin"); + const [isOwner, setIsOwner] = useState(false); + const [shareMode, setShareMode] = useState(props.shareMode); + const api = getApi(props.protectedPath); + const userinfo = useFiefUserinfo(); useEffect(() => { setCurrentUrl(window.location.href); }, []); + useEffect(() => { + setIsOwner(!!(requireLogin && userinfo?.sub === props.userId)); + }, [userinfo, props.userId]); + const handleCopyClick = () => { if (inputRef.current) { let text_to_copy = inputRef.current.value; @@ -23,6 +45,16 @@ const ShareLink = () => { } }; + const updateShareMode = async (selectedShareMode: string) => { + if (!api) return; + const updatedTranscript = await api.v1TranscriptUpdate({ + transcriptId: props.transcriptId, + updateTranscript: { + shareMode: selectedShareMode, + }, + }); + setShareMode(updatedTranscript.shareMode); + }; const privacyEnabled = featureEnabled("privacy"); return ( @@ -30,18 +62,50 @@ const ShareLink = () => { className="p-2 md:p-4 rounded" style={{ background: "rgba(96, 165, 250, 0.2)" }} > - {privacyEnabled ? ( + {requireLogin && (

- You can share this link with others. Anyone with the link will have - access to the page, including the full audio recording, for the next 7 - days. -

- ) : ( -

- You can share this link with others. Anyone with the link will have - access to the page, including the full audio recording. + {shareMode === "private" && ( +

This transcript is only accessible by you.

+ )} + {shareMode === "semi-private" && ( +

This transcript is accessible by any authenticated users.

+ )} + {shareMode === "public" && ( +

This transcript is accessible by anyone.

+ )} + + {isOwner && api && ( +

+ +

+ )}

)} + {!requireLogin && ( + <> + {privacyEnabled ? ( +

+ You can share this link with others. Anyone with the link will + have access to the page, including the full audio recording, for + the next 7 days. +

+ ) : ( +

+ You can share this link with others. Anyone with the link will + have access to the page, including the full audio recording. +

+ )} + + )}
Date: Tue, 7 Nov 2023 18:41:51 +0100 Subject: [PATCH 02/27] www: edit from andreas feedback --- www/app/[domain]/transcripts/shareLink.tsx | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/www/app/[domain]/transcripts/shareLink.tsx b/www/app/[domain]/transcripts/shareLink.tsx index 44c57053..82ef52c9 100644 --- a/www/app/[domain]/transcripts/shareLink.tsx +++ b/www/app/[domain]/transcripts/shareLink.tsx @@ -65,13 +65,15 @@ const ShareLink = (props: ShareLinkProps) => { {requireLogin && (

{shareMode === "private" && ( -

This transcript is only accessible by you.

+

This transcript is private and can only be accessed by you.

)} {shareMode === "semi-private" && ( -

This transcript is accessible by any authenticated users.

+

+ This transcript is secure. Only authenticated users can access it. +

)} {shareMode === "public" && ( -

This transcript is accessible by anyone.

+

This transcript is public. Everyone can access it.

)} {isOwner && api && ( @@ -80,7 +82,7 @@ const ShareLink = (props: ShareLinkProps) => { className="select-search--top select-search" options={[ { name: "Private", value: "private" }, - { name: "Semi-private", value: "semi-private" }, + { name: "Secure", value: "semi-private" }, { name: "Public", value: "public" }, ]} value={shareMode} @@ -94,14 +96,14 @@ const ShareLink = (props: ShareLinkProps) => { <> {privacyEnabled ? (

- You can share this link with others. Anyone with the link will - have access to the page, including the full audio recording, for - the next 7 days. + Share this link to grant others access to this page. The link + includes the full audio recording and is valid for the next 7 + days.

) : (

- You can share this link with others. Anyone with the link will - have access to the page, including the full audio recording. + Share this link to allow others to view this page and listen to + the full audio recording.

)} From 86b3b3c0e4f61d98d1b4a0dbbb343286db503371 Mon Sep 17 00:00:00 2001 From: Sara Date: Mon, 13 Nov 2023 17:23:27 +0100 Subject: [PATCH 03/27] split recorder player --- .../transcripts/[transcriptId]/page.tsx | 10 +- www/app/[domain]/transcripts/player.tsx | 167 ++++++++++++++++++ www/app/[domain]/transcripts/recorder.tsx | 109 +----------- 3 files changed, 174 insertions(+), 112 deletions(-) create mode 100644 www/app/[domain]/transcripts/player.tsx diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index 56201c3c..bebfe261 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -5,7 +5,6 @@ import useTopics from "../useTopics"; import useWaveform from "../useWaveform"; import useMp3 from "../useMp3"; import { TopicList } from "../topicList"; -import Recorder from "../recorder"; import { Topic } from "../webSocketTypes"; import React, { useState } from "react"; import "../../../styles/button.css"; @@ -13,6 +12,7 @@ import FinalSummary from "../finalSummary"; import ShareLink from "../shareLink"; import QRCode from "react-qr-code"; import TranscriptTitle from "../transcriptTitle"; +import Player from "../player"; type TranscriptDetails = { params: { @@ -62,14 +62,12 @@ export default function TranscriptDetails(details: TranscriptDetails) { /> )} {!waveform?.loading && ( - )}
diff --git a/www/app/[domain]/transcripts/player.tsx b/www/app/[domain]/transcripts/player.tsx new file mode 100644 index 00000000..6143836e --- /dev/null +++ b/www/app/[domain]/transcripts/player.tsx @@ -0,0 +1,167 @@ +import React, { useRef, useEffect, useState } from "react"; + +import WaveSurfer from "wavesurfer.js"; +import CustomRegionsPlugin from "../../lib/custom-plugins/regions"; + +import { formatTime } from "../../lib/time"; +import { Topic } from "./webSocketTypes"; +import { AudioWaveform } from "../../api"; +import { waveSurferStyles } from "../../styles/recorder"; + +type PlayerProps = { + topics: Topic[]; + useActiveTopic: [ + Topic | null, + React.Dispatch>, + ]; + waveform: AudioWaveform; + media: HTMLMediaElement; + mediaDuration: number; +}; + +export default function Player(props: PlayerProps) { + const waveformRef = useRef(null); + const [wavesurfer, setWavesurfer] = useState(null); + const [isPlaying, setIsPlaying] = useState(false); + const [currentTime, setCurrentTime] = useState(0); + const [waveRegions, setWaveRegions] = useState( + null, + ); + const [activeTopic, setActiveTopic] = props.useActiveTopic; + const topicsRef = useRef(props.topics); + + // Waveform setup + useEffect(() => { + if (waveformRef.current) { + // XXX duration is required to prevent recomputing peaks from audio + // However, the current waveform returns only the peaks, and no duration + // And the backend does not save duration properly. + // So at the moment, we deduct the duration from the topics. + // This is not ideal, but it works for now. + const _wavesurfer = WaveSurfer.create({ + container: waveformRef.current, + peaks: props.waveform.data, + hideScrollbar: true, + autoCenter: true, + barWidth: 2, + height: "auto", + duration: props.mediaDuration, + + ...waveSurferStyles.player, + }); + + // styling + const wsWrapper = _wavesurfer.getWrapper(); + wsWrapper.style.cursor = waveSurferStyles.playerStyle.cursor; + wsWrapper.style.backgroundColor = + waveSurferStyles.playerStyle.backgroundColor; + wsWrapper.style.borderRadius = waveSurferStyles.playerStyle.borderRadius; + + _wavesurfer.on("play", () => { + setIsPlaying(true); + }); + _wavesurfer.on("pause", () => { + setIsPlaying(false); + }); + _wavesurfer.on("timeupdate", setCurrentTime); + + setWaveRegions(_wavesurfer.registerPlugin(CustomRegionsPlugin.create())); + + _wavesurfer.toggleInteraction(true); + + _wavesurfer.setMediaElement(props.media); + + setWavesurfer(_wavesurfer); + + return () => { + _wavesurfer.destroy(); + setIsPlaying(false); + setCurrentTime(0); + }; + } + }, []); + + useEffect(() => { + if (!wavesurfer) return; + if (!props.media) return; + wavesurfer.setMediaElement(props.media); + }, [props.media, wavesurfer]); + + useEffect(() => { + topicsRef.current = props.topics; + renderMarkers(); + }, [props.topics, waveRegions]); + + const renderMarkers = () => { + if (!waveRegions) return; + + waveRegions.clearRegions(); + + for (let topic of topicsRef.current) { + const content = document.createElement("div"); + content.setAttribute("style", waveSurferStyles.marker); + content.onmouseover = () => { + content.style.backgroundColor = + waveSurferStyles.markerHover.backgroundColor; + content.style.zIndex = "999"; + content.style.width = "300px"; + }; + content.onmouseout = () => { + content.setAttribute("style", waveSurferStyles.marker); + }; + content.textContent = topic.title; + + const region = waveRegions.addRegion({ + start: topic.timestamp, + content, + color: "f00", + drag: false, + }); + region.on("click", (e) => { + e.stopPropagation(); + setActiveTopic(topic); + wavesurfer?.setTime(region.start); + }); + } + }; + + useEffect(() => { + if (activeTopic) { + wavesurfer?.setTime(activeTopic.timestamp); + } + }, [activeTopic]); + + const handlePlayClick = () => { + wavesurfer?.playPause(); + }; + + const timeLabel = () => { + if (props.mediaDuration) + return `${formatTime(currentTime)}/${formatTime(props.mediaDuration)}`; + return ""; + }; + + return ( +
+
+
+
{timeLabel()}
+
+ + +
+ ); +} diff --git a/www/app/[domain]/transcripts/recorder.tsx b/www/app/[domain]/transcripts/recorder.tsx index 8db32ff7..5b4420c4 100644 --- a/www/app/[domain]/transcripts/recorder.tsx +++ b/www/app/[domain]/transcripts/recorder.tsx @@ -6,11 +6,8 @@ import CustomRegionsPlugin from "../../lib/custom-plugins/regions"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import { faMicrophone } from "@fortawesome/free-solid-svg-icons"; -import { faDownload } from "@fortawesome/free-solid-svg-icons"; import { formatTime } from "../../lib/time"; -import { Topic } from "./webSocketTypes"; -import { AudioWaveform } from "../../api"; import AudioInputsDropdown from "./audioInputsDropdown"; import { Option } from "react-dropdown"; import { waveSurferStyles } from "../../styles/recorder"; @@ -19,17 +16,8 @@ import { useError } from "../../(errors)/errorContext"; type RecorderProps = { setStream?: React.Dispatch>; onStop?: () => void; - topics: Topic[]; getAudioStream?: (deviceId) => Promise; audioDevices?: Option[]; - useActiveTopic: [ - Topic | null, - React.Dispatch>, - ]; - waveform?: AudioWaveform | null; - isPastMeeting: boolean; - transcriptId?: string | null; - media?: HTMLMediaElement | null; mediaDuration?: number | null; }; @@ -38,7 +26,7 @@ export default function Recorder(props: RecorderProps) { const [wavesurfer, setWavesurfer] = useState(null); const [record, setRecord] = useState(null); const [isRecording, setIsRecording] = useState(false); - const [hasRecorded, setHasRecorded] = useState(props.isPastMeeting); + const [hasRecorded, setHasRecorded] = useState(false); const [isPlaying, setIsPlaying] = useState(false); const [currentTime, setCurrentTime] = useState(0); const [timeInterval, setTimeInterval] = useState(null); @@ -48,8 +36,6 @@ export default function Recorder(props: RecorderProps) { ); const [deviceId, setDeviceId] = useState(null); const [recordStarted, setRecordStarted] = useState(false); - const [activeTopic, setActiveTopic] = props.useActiveTopic; - const topicsRef = useRef(props.topics); const [showDevices, setShowDevices] = useState(false); const { setError } = useError(); @@ -73,8 +59,6 @@ export default function Recorder(props: RecorderProps) { if (!record.isRecording()) return; handleRecClick(); break; - case "^": - throw new Error("Unhandled Exception thrown by '^' shortcut"); case "(": location.href = "/login"; break; @@ -104,14 +88,8 @@ export default function Recorder(props: RecorderProps) { // Waveform setup useEffect(() => { if (waveformRef.current) { - // XXX duration is required to prevent recomputing peaks from audio - // However, the current waveform returns only the peaks, and no duration - // And the backend does not save duration properly. - // So at the moment, we deduct the duration from the topics. - // This is not ideal, but it works for now. const _wavesurfer = WaveSurfer.create({ container: waveformRef.current, - peaks: props.waveform?.data, hideScrollbar: true, autoCenter: true, barWidth: 2, @@ -121,10 +99,8 @@ export default function Recorder(props: RecorderProps) { ...waveSurferStyles.player, }); - if (!props.transcriptId) { - const _wshack: any = _wavesurfer; - _wshack.renderer.renderSingleCanvas = () => {}; - } + const _wshack: any = _wavesurfer; + _wshack.renderer.renderSingleCanvas = () => {}; // styling const wsWrapper = _wavesurfer.getWrapper(); @@ -144,12 +120,6 @@ export default function Recorder(props: RecorderProps) { setRecord(_wavesurfer.registerPlugin(RecordPlugin.create())); setWaveRegions(_wavesurfer.registerPlugin(CustomRegionsPlugin.create())); - if (props.isPastMeeting) _wavesurfer.toggleInteraction(true); - - if (props.media) { - _wavesurfer.setMediaElement(props.media); - } - setWavesurfer(_wavesurfer); return () => { @@ -161,58 +131,6 @@ export default function Recorder(props: RecorderProps) { } }, []); - useEffect(() => { - if (!wavesurfer) return; - if (!props.media) return; - wavesurfer.setMediaElement(props.media); - }, [props.media, wavesurfer]); - - useEffect(() => { - topicsRef.current = props.topics; - if (!isRecording) renderMarkers(); - }, [props.topics, waveRegions]); - - const renderMarkers = () => { - if (!waveRegions) return; - - waveRegions.clearRegions(); - - for (let topic of topicsRef.current) { - const content = document.createElement("div"); - content.setAttribute("style", waveSurferStyles.marker); - content.onmouseover = () => { - content.style.backgroundColor = - waveSurferStyles.markerHover.backgroundColor; - content.style.zIndex = "999"; - content.style.width = "300px"; - }; - content.onmouseout = () => { - content.setAttribute("style", waveSurferStyles.marker); - }; - content.textContent = topic.title; - - const region = waveRegions.addRegion({ - start: topic.timestamp, - content, - color: "f00", - drag: false, - }); - region.on("click", (e) => { - e.stopPropagation(); - setActiveTopic(topic); - wavesurfer?.setTime(region.start); - }); - } - }; - - useEffect(() => { - if (!record) return; - - return record.on("stopRecording", () => { - renderMarkers(); - }); - }, [record]); - useEffect(() => { if (isRecording) { const interval = window.setInterval(() => { @@ -229,12 +147,6 @@ export default function Recorder(props: RecorderProps) { } }, [isRecording]); - useEffect(() => { - if (activeTopic) { - wavesurfer?.setTime(activeTopic.timestamp); - } - }, [activeTopic]); - const handleRecClick = async () => { if (!record) return console.log("no record"); @@ -320,7 +232,6 @@ export default function Recorder(props: RecorderProps) { if (!record) return; if (!destinationStream) return; if (props.setStream) props.setStream(destinationStream); - waveRegions?.clearRegions(); if (destinationStream) { record.startRecording(destinationStream); setIsRecording(true); @@ -379,23 +290,9 @@ export default function Recorder(props: RecorderProps) { } text-white ml-2 md:ml:4 md:h-[78px] md:min-w-[100px] text-lg`} id="play-btn" onClick={handlePlayClick} - disabled={isRecording} > {isPlaying ? "Pause" : "Play"} - - {props.transcriptId && ( - - - - )} )} {!hasRecorded && ( From 14ebfa53a8c2f9abad2342ebe7d3ab4a018e52f5 Mon Sep 17 00:00:00 2001 From: Sara Date: Mon, 13 Nov 2023 17:33:12 +0100 Subject: [PATCH 04/27] wip loading and redirects --- .../transcripts/[transcriptId]/page.tsx | 38 +++++++++++++++---- .../[transcriptId]/record/page.tsx | 26 ++++++++----- www/app/[domain]/transcripts/useTranscript.ts | 27 ++++++++++--- www/app/[domain]/transcripts/useWebSockets.ts | 30 +++++++++------ 4 files changed, 87 insertions(+), 34 deletions(-) diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index bebfe261..add5a824 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -6,7 +6,7 @@ import useWaveform from "../useWaveform"; import useMp3 from "../useMp3"; import { TopicList } from "../topicList"; import { Topic } from "../webSocketTypes"; -import React, { useState } from "react"; +import React, { useEffect, useState } from "react"; import "../../../styles/button.css"; import FinalSummary from "../finalSummary"; import ShareLink from "../shareLink"; @@ -31,7 +31,7 @@ export default function TranscriptDetails(details: TranscriptDetails) { const useActiveTopic = useState(null); const mp3 = useMp3(protectedPath, transcriptId); - if (transcript?.error /** || topics?.error || waveform?.error **/) { + if (transcript?.error || topics?.error) { return ( { + const statusToRedirect = ["idle", "recording", "processing"]; + if (statusToRedirect.includes(transcript.response?.status)) { + const newUrl = "/transcripts/" + details.params.transcriptId + "/record"; + // Shallow redirection does not work on NextJS 13 + // https://github.com/vercel/next.js/discussions/48110 + // https://github.com/vercel/next.js/discussions/49540 + // router.push(newUrl, undefined, { shallow: true }); + history.replaceState({}, "", newUrl); + } + }, [transcript.response?.status]); + const fullTranscript = topics.topics ?.map((topic) => topic.transcript) .join("\n\n") .replace(/ +/g, " ") .trim() || ""; + console.log("calf full transcript"); return ( <> - {!transcriptId || transcript?.loading || topics?.loading ? ( + {transcript?.loading || topics?.loading ? ( ) : ( <> @@ -61,7 +74,7 @@ export default function TranscriptDetails(details: TranscriptDetails) { transcriptId={transcript.response.id} /> )} - {!waveform?.loading && ( + {waveform.waveform && mp3.media ? ( + ) : mp3.error || waveform.error ? ( + "error loading this recording" + ) : ( + "Loading Recording" )}
+
- {transcript?.response?.longSummary && ( + {transcript.response.longSummary ? ( + ) : transcript.response.status == "processing" ? ( + "Loading Transcript" + ) : ( + "error final summary" )}
diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx index 41a2d053..10297f5c 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx @@ -8,12 +8,12 @@ import { useWebSockets } from "../../useWebSockets"; import useAudioDevice from "../../useAudioDevice"; import "../../../../styles/button.css"; import { Topic } from "../../webSocketTypes"; -import getApi from "../../../../lib/getApi"; import LiveTrancription from "../../liveTranscription"; import DisconnectedIndicator from "../../disconnectedIndicator"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import { faGear } from "@fortawesome/free-solid-svg-icons"; import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock"; +import { useRouter } from "next/navigation"; type TranscriptDetails = { params: { @@ -45,21 +45,31 @@ const TranscriptRecord = (details: TranscriptDetails) => { const [hasRecorded, setHasRecorded] = useState(false); const [transcriptStarted, setTranscriptStarted] = useState(false); + const router = useRouter(); + useEffect(() => { if (!transcriptStarted && webSockets.transcriptText.length !== 0) setTranscriptStarted(true); }, [webSockets.transcriptText]); useEffect(() => { - if (transcript?.response?.longSummary) { - const newUrl = `/transcripts/${transcript.response.id}`; + const statusToRedirect = ["ended", "error"]; + console.log(webSockets.status, "hey"); + console.log(transcript.response, "ho"); + + //TODO if has no topic and is error, get back to new + if ( + statusToRedirect.includes(transcript.response?.status) || + statusToRedirect.includes(webSockets.status.value) + ) { + const newUrl = "/transcripts/" + details.params.transcriptId; // Shallow redirection does not work on NextJS 13 // https://github.com/vercel/next.js/discussions/48110 // https://github.com/vercel/next.js/discussions/49540 - // router.push(newUrl, undefined, { shallow: true }); - history.replaceState({}, "", newUrl); + router.push(newUrl, undefined); + // history.replaceState({}, "", newUrl); } - }); + }, [webSockets.status.value, transcript.response?.status]); useEffect(() => { lockWakeState(); @@ -77,10 +87,7 @@ const TranscriptRecord = (details: TranscriptDetails) => { setHasRecorded(true); webRTC?.send(JSON.stringify({ cmd: "STOP" })); }} - topics={webSockets.topics} getAudioStream={getAudioStream} - useActiveTopic={useActiveTopic} - isPastMeeting={false} audioDevices={audioDevices} /> @@ -128,6 +135,7 @@ const TranscriptRecord = (details: TranscriptDetails) => { couple of minutes. Please do not navigate away from the page during this time.

+ {/* TODO If login required remove last sentence */}
)} diff --git a/www/app/[domain]/transcripts/useTranscript.ts b/www/app/[domain]/transcripts/useTranscript.ts index af60cd3b..987e57f3 100644 --- a/www/app/[domain]/transcripts/useTranscript.ts +++ b/www/app/[domain]/transcripts/useTranscript.ts @@ -5,16 +5,28 @@ import { useError } from "../../(errors)/errorContext"; import getApi from "../../lib/getApi"; import { shouldShowError } from "../../lib/errorUtils"; -type Transcript = { - response: GetTranscript | null; - loading: boolean; - error: Error | null; +type ErrorTranscript = { + error: Error; + loading: false; + response: any; +}; + +type LoadingTranscript = { + response: any; + loading: true; + error: false; +}; + +type SuccessTranscript = { + response: GetTranscript; + loading: false; + error: null; }; const useTranscript = ( protectedPath: boolean, id: string | null, -): Transcript => { +): ErrorTranscript | LoadingTranscript | SuccessTranscript => { const [response, setResponse] = useState(null); const [loading, setLoading] = useState(true); const [error, setErrorState] = useState(null); @@ -46,7 +58,10 @@ const useTranscript = ( }); }, [id, !api]); - return { response, loading, error }; + return { response, loading, error } as + | ErrorTranscript + | LoadingTranscript + | SuccessTranscript; }; export default useTranscript; diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts index bcf6b163..cb74f86d 100644 --- a/www/app/[domain]/transcripts/useWebSockets.ts +++ b/www/app/[domain]/transcripts/useWebSockets.ts @@ -4,9 +4,10 @@ import { useError } from "../../(errors)/errorContext"; import { useRouter } from "next/navigation"; import { DomainContext } from "../domainContext"; -type UseWebSockets = { +export type UseWebSockets = { transcriptText: string; translateText: string; + title: string; topics: Topic[]; finalSummary: FinalSummary; status: Status; @@ -15,6 +16,7 @@ type UseWebSockets = { export const useWebSockets = (transcriptId: string | null): UseWebSockets => { const [transcriptText, setTranscriptText] = useState(""); const [translateText, setTranslateText] = useState(""); + const [title, setTitle] = useState(""); const [textQueue, setTextQueue] = useState([]); const [translationQueue, setTranslationQueue] = useState([]); const [isProcessing, setIsProcessing] = useState(false); @@ -24,7 +26,6 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { }); const [status, setStatus] = useState({ value: "initial" }); const { setError } = useError(); - const router = useRouter(); const { websocket_url } = useContext(DomainContext); @@ -294,7 +295,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { if (!transcriptId) return; const url = `${websocket_url}/v1/transcripts/${transcriptId}/events`; - const ws = new WebSocket(url); + let ws = new WebSocket(url); ws.onopen = () => { console.debug("WebSocket connection opened"); @@ -343,24 +344,23 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { case "FINAL_TITLE": console.debug("FINAL_TITLE event:", message.data); + if (message.data) { + setTitle(message.data.title); + } break; case "STATUS": console.log("STATUS event:", message.data); - if (message.data.value === "ended") { - const newUrl = "/transcripts/" + transcriptId; - router.push(newUrl); - console.debug("FINAL_LONG_SUMMARY event:", message.data); - } if (message.data.value === "error") { - const newUrl = "/transcripts/" + transcriptId; - router.push(newUrl); setError( Error("Websocket error status"), "There was an error processing this meeting.", ); } setStatus(message.data); + if (message.data.value === "ended") { + ws.close(); + } break; default: @@ -388,8 +388,16 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { default: setError( new Error(`WebSocket closed unexpectedly with code: ${event.code}`), + "Disconnected", ); } + console.log( + "Socket is closed. Reconnect will be attempted in 1 second.", + event.reason, + ); + setTimeout(function () { + ws = new WebSocket(url); + }, 1000); }; return () => { @@ -397,5 +405,5 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { }; }, [transcriptId]); - return { transcriptText, translateText, topics, finalSummary, status }; + return { transcriptText, translateText, topics, finalSummary, title, status }; }; From e98f1bf4bc7bed06fddecd22537ea819761b83b7 Mon Sep 17 00:00:00 2001 From: Sara Date: Mon, 13 Nov 2023 18:33:24 +0100 Subject: [PATCH 05/27] loading and redirecting front-end --- .../transcripts/[transcriptId]/page.tsx | 18 +++++-- .../[transcriptId]/record/page.tsx | 50 +++++++++++++------ www/app/[domain]/transcripts/finalSummary.tsx | 2 +- www/app/[domain]/transcripts/recorder.tsx | 13 +++-- www/app/[domain]/transcripts/useWebSockets.ts | 14 +++++- .../[domain]/transcripts/waveformLoading.tsx | 11 ++++ 6 files changed, 77 insertions(+), 31 deletions(-) create mode 100644 www/app/[domain]/transcripts/waveformLoading.tsx diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index add5a824..079b41e8 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -13,6 +13,7 @@ import ShareLink from "../shareLink"; import QRCode from "react-qr-code"; import TranscriptTitle from "../transcriptTitle"; import Player from "../player"; +import WaveformLoading from "../waveformLoading"; type TranscriptDetails = { params: { @@ -83,9 +84,9 @@ export default function TranscriptDetails(details: TranscriptDetails) { mediaDuration={transcript.response.duration} /> ) : mp3.error || waveform.error ? ( - "error loading this recording" +
"error loading this recording"
) : ( - "Loading Recording" + )}
@@ -104,10 +105,17 @@ export default function TranscriptDetails(details: TranscriptDetails) { summary={transcript.response.longSummary} transcriptId={transcript.response.id} /> - ) : transcript.response.status == "processing" ? ( - "Loading Transcript" ) : ( - "error final summary" +
+ {transcript.response.status == "processing" ? ( +

Loading Transcript

+ ) : ( +

+ There was an error generating the final summary, please + come back later +

+ )} +
)} diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx index 10297f5c..36f4bbe9 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx @@ -11,9 +11,12 @@ import { Topic } from "../../webSocketTypes"; import LiveTrancription from "../../liveTranscription"; import DisconnectedIndicator from "../../disconnectedIndicator"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; -import { faGear } from "@fortawesome/free-solid-svg-icons"; +import { faGear, faSpinner } from "@fortawesome/free-solid-svg-icons"; import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock"; import { useRouter } from "next/navigation"; +import Player from "../../player"; +import useMp3 from "../../useMp3"; +import WaveformLoading from "../../waveformLoading"; type TranscriptDetails = { params: { @@ -42,8 +45,10 @@ const TranscriptRecord = (details: TranscriptDetails) => { const { audioDevices, getAudioStream } = useAudioDevice(); - const [hasRecorded, setHasRecorded] = useState(false); + const [recordedTime, setRecordedTime] = useState(0); + const [startTime, setStartTime] = useState(0); const [transcriptStarted, setTranscriptStarted] = useState(false); + const mp3 = useMp3(true, ""); const router = useRouter(); @@ -54,8 +59,6 @@ const TranscriptRecord = (details: TranscriptDetails) => { useEffect(() => { const statusToRedirect = ["ended", "error"]; - console.log(webSockets.status, "hey"); - console.log(transcript.response, "ho"); //TODO if has no topic and is error, get back to new if ( @@ -80,16 +83,31 @@ const TranscriptRecord = (details: TranscriptDetails) => { return ( <> - { - setStream(null); - setHasRecorded(true); - webRTC?.send(JSON.stringify({ cmd: "STOP" })); - }} - getAudioStream={getAudioStream} - audioDevices={audioDevices} - /> + {webSockets.waveform && mp3.media ? ( + + ) : recordedTime ? ( + + ) : ( + { + setStream(null); + setRecordedTime(Date.now() - startTime); + webRTC?.send(JSON.stringify({ cmd: "STOP" })); + }} + onRecord={() => { + setStartTime(Date.now()); + }} + getAudioStream={getAudioStream} + audioDevices={audioDevices} + /> + )}
{
- {!hasRecorded ? ( + {!recordedTime ? ( <> {transcriptStarted && (

Transcription

@@ -135,7 +153,7 @@ const TranscriptRecord = (details: TranscriptDetails) => { couple of minutes. Please do not navigate away from the page during this time.

- {/* TODO If login required remove last sentence */} + {/* NTH If login required remove last sentence */}
)} diff --git a/www/app/[domain]/transcripts/finalSummary.tsx b/www/app/[domain]/transcripts/finalSummary.tsx index 463f6100..e0d0f1c9 100644 --- a/www/app/[domain]/transcripts/finalSummary.tsx +++ b/www/app/[domain]/transcripts/finalSummary.tsx @@ -87,7 +87,7 @@ export default function FinalSummary(props: FinalSummaryProps) {
diff --git a/www/app/[domain]/transcripts/recorder.tsx b/www/app/[domain]/transcripts/recorder.tsx index 5b4420c4..e7c016a7 100644 --- a/www/app/[domain]/transcripts/recorder.tsx +++ b/www/app/[domain]/transcripts/recorder.tsx @@ -14,11 +14,11 @@ import { waveSurferStyles } from "../../styles/recorder"; import { useError } from "../../(errors)/errorContext"; type RecorderProps = { - setStream?: React.Dispatch>; - onStop?: () => void; - getAudioStream?: (deviceId) => Promise; - audioDevices?: Option[]; - mediaDuration?: number | null; + setStream: React.Dispatch>; + onStop: () => void; + onRecord?: () => void; + getAudioStream: (deviceId) => Promise; + audioDevices: Option[]; }; export default function Recorder(props: RecorderProps) { @@ -94,7 +94,6 @@ export default function Recorder(props: RecorderProps) { autoCenter: true, barWidth: 2, height: "auto", - duration: props.mediaDuration || 1, ...waveSurferStyles.player, }); @@ -161,10 +160,10 @@ export default function Recorder(props: RecorderProps) { setScreenMediaStream(null); setDestinationStream(null); } else { + if (props.onRecord) props.onRecord(); const stream = await getCurrentStream(); if (props.setStream) props.setStream(stream); - waveRegions?.clearRegions(); if (stream) { await record.startRecording(stream); setIsRecording(true); diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts index cb74f86d..3f3d20fc 100644 --- a/www/app/[domain]/transcripts/useWebSockets.ts +++ b/www/app/[domain]/transcripts/useWebSockets.ts @@ -1,8 +1,8 @@ import { useContext, useEffect, useState } from "react"; import { Topic, FinalSummary, Status } from "./webSocketTypes"; import { useError } from "../../(errors)/errorContext"; -import { useRouter } from "next/navigation"; import { DomainContext } from "../domainContext"; +import { AudioWaveform } from "../../api"; export type UseWebSockets = { transcriptText: string; @@ -11,6 +11,7 @@ export type UseWebSockets = { topics: Topic[]; finalSummary: FinalSummary; status: Status; + waveform: AudioWaveform | null; }; export const useWebSockets = (transcriptId: string | null): UseWebSockets => { @@ -21,6 +22,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { const [translationQueue, setTranslationQueue] = useState([]); const [isProcessing, setIsProcessing] = useState(false); const [topics, setTopics] = useState([]); + const [waveform, setWaveForm] = useState(null); const [finalSummary, setFinalSummary] = useState({ summary: "", }); @@ -405,5 +407,13 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { }; }, [transcriptId]); - return { transcriptText, translateText, topics, finalSummary, title, status }; + return { + transcriptText, + translateText, + topics, + finalSummary, + title, + status, + waveform, + }; }; diff --git a/www/app/[domain]/transcripts/waveformLoading.tsx b/www/app/[domain]/transcripts/waveformLoading.tsx new file mode 100644 index 00000000..68e0c80f --- /dev/null +++ b/www/app/[domain]/transcripts/waveformLoading.tsx @@ -0,0 +1,11 @@ +import { faSpinner } from "@fortawesome/free-solid-svg-icons"; +import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; + +export default () => ( +
+ +
+); From 1fc261a66999df014820f07836d6300977977fe4 Mon Sep 17 00:00:00 2001 From: Sara Date: Wed, 15 Nov 2023 20:30:00 +0100 Subject: [PATCH 06/27] try to move waveform to pipeline --- server/reflector/db/transcripts.py | 25 +++++--------- .../reflector/pipelines/main_live_pipeline.py | 26 +++++++++++++-- .../processors/audio_waveform_processor.py | 33 +++++++++++++++++++ server/reflector/views/transcripts.py | 3 +- server/tests/test_transcripts_rtc_ws.py | 4 +++ www/app/lib/edgeConfig.ts | 4 +-- 6 files changed, 72 insertions(+), 23 deletions(-) create mode 100644 server/reflector/processors/audio_waveform_processor.py diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 6ac2e32a..f0dbc277 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -10,7 +10,6 @@ from pydantic import BaseModel, Field from reflector.db import database, metadata from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings -from reflector.utils.audio_waveform import get_audio_waveform transcripts = sqlalchemy.Table( "transcript", @@ -79,6 +78,14 @@ class TranscriptFinalTitle(BaseModel): title: str +class TranscriptDuration(BaseModel): + duration: float + + +class TranscriptWaveform(BaseModel): + waveform: list[float] + + class TranscriptEvent(BaseModel): event: str data: dict @@ -118,22 +125,6 @@ class Transcript(BaseModel): def topics_dump(self, mode="json"): return [topic.model_dump(mode=mode) for topic in self.topics] - def convert_audio_to_waveform(self, segments_count=256): - fn = self.audio_waveform_filename - if fn.exists(): - return - waveform = get_audio_waveform( - path=self.audio_mp3_filename, segments_count=segments_count - ) - try: - with open(fn, "w") as fd: - json.dump(waveform, fd) - except Exception: - # remove file if anything happen during the write - fn.unlink(missing_ok=True) - raise - return waveform - def unlink(self): self.data_path.unlink(missing_ok=True) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 8c78c48f..b0576b92 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -21,11 +21,13 @@ from pydantic import BaseModel from reflector.app import app from reflector.db.transcripts import ( Transcript, + TranscriptDuration, TranscriptFinalLongSummary, TranscriptFinalShortSummary, TranscriptFinalTitle, TranscriptText, TranscriptTopic, + TranscriptWaveform, transcripts_controller, ) from reflector.logger import logger @@ -45,6 +47,7 @@ from reflector.processors import ( TranscriptTopicDetectorProcessor, TranscriptTranslatorProcessor, ) +from reflector.processors.audio_waveform_processor import AudioWaveformProcessor from reflector.processors.types import AudioDiarizationInput from reflector.processors.types import ( TitleSummaryWithId as TitleSummaryWithIdProcessorType, @@ -230,15 +233,29 @@ class PipelineMainBase(PipelineRunner): data=final_short_summary, ) - async def on_duration(self, duration: float): + async def on_duration(self, data): async with self.transaction(): + duration = TranscriptDuration(duration=data) + transcript = await self.get_transcript() await transcripts_controller.update( transcript, { - "duration": duration, + "duration": duration.duration, }, ) + return await transcripts_controller.append_event( + transcript=transcript, event="DURATION", data=duration + ) + + async def on_waveform(self, data): + waveform = TranscriptWaveform(waveform=data) + + transcript = await self.get_transcript() + + return await transcripts_controller.append_event( + transcript=transcript, event="WAVEFORM", data=waveform + ) class PipelineMainLive(PipelineMainBase): @@ -266,6 +283,11 @@ class PipelineMainLive(PipelineMainBase): BroadcastProcessor( processors=[ TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), + AudioWaveformProcessor( + audio_path=transcript.audio_mp3_filename, + waveform_path=transcript.audio_waveform_filename, + on_waveform=self.on_waveform, + ), ] ), ] diff --git a/server/reflector/processors/audio_waveform_processor.py b/server/reflector/processors/audio_waveform_processor.py new file mode 100644 index 00000000..acce904a --- /dev/null +++ b/server/reflector/processors/audio_waveform_processor.py @@ -0,0 +1,33 @@ +import json +from pathlib import Path + +from reflector.processors.base import Processor +from reflector.processors.types import TitleSummary +from reflector.utils.audio_waveform import get_audio_waveform + + +class AudioWaveformProcessor(Processor): + """ + Write the waveform for the final audio + """ + + INPUT_TYPE = TitleSummary + + def __init__(self, audio_path: Path | str, waveform_path: str, **kwargs): + super().__init__(**kwargs) + if isinstance(audio_path, str): + audio_path = Path(audio_path) + if audio_path.suffix not in (".mp3", ".wav"): + raise ValueError("Only mp3 and wav files are supported") + self.audio_path = audio_path + self.waveform_path = waveform_path + + async def _push(self, _data): + self.waveform_path.parent.mkdir(parents=True, exist_ok=True) + self.logger.info("Waveform Processing Started") + waveform = get_audio_waveform(path=self.audio_path, segments_count=255) + + with open(self.waveform_path, "w") as fd: + json.dump(waveform, fd) + self.logger.info("Waveform Processing Finished") + await self.emit(waveform, name="waveform") diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 5de9ced3..77e4b0b7 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -22,7 +22,6 @@ from reflector.db.transcripts import ( from reflector.processors.types import Transcript as ProcessorTranscript from reflector.settings import settings from reflector.ws_manager import get_ws_manager -from starlette.concurrency import run_in_threadpool from ._range_requests_response import range_requests_response from .rtc_offer import RtcOffer, rtc_offer_base @@ -261,7 +260,7 @@ async def transcript_get_audio_waveform( if not transcript.audio_mp3_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") - await run_in_threadpool(transcript.convert_audio_to_waveform) + # await run_in_threadpool(transcript.convert_audio_to_waveform) return transcript.audio_waveform diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index cf2ea304..65660a5e 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -182,6 +182,10 @@ async def test_transcript_rtc_and_websocket( ev = events[eventnames.index("FINAL_TITLE")] assert ev["data"]["title"] == "LLM TITLE" + assert "WAVEFORM" in eventnames + ev = events[eventnames.index("FINAL_TITLE")] + assert ev["data"]["title"] == "LLM TITLE" + # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] assert statuses.index("recording") < statuses.index("processing") diff --git a/www/app/lib/edgeConfig.ts b/www/app/lib/edgeConfig.ts index 5527121a..1140e555 100644 --- a/www/app/lib/edgeConfig.ts +++ b/www/app/lib/edgeConfig.ts @@ -3,9 +3,9 @@ import { isDevelopment } from "./utils"; const localConfig = { features: { - requireLogin: true, + requireLogin: false, privacy: true, - browse: true, + browse: false, }, api_url: "http://127.0.0.1:1250", websocket_url: "ws://127.0.0.1:1250", From 8b8e92ceac7a121bd16909ea2d8401a889ae44cb Mon Sep 17 00:00:00 2001 From: Sara Date: Wed, 15 Nov 2023 20:34:45 +0100 Subject: [PATCH 07/27] replace history instead of pushing --- www/app/[domain]/transcripts/[transcriptId]/record/page.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx index 36f4bbe9..147f8827 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx @@ -69,9 +69,9 @@ const TranscriptRecord = (details: TranscriptDetails) => { // Shallow redirection does not work on NextJS 13 // https://github.com/vercel/next.js/discussions/48110 // https://github.com/vercel/next.js/discussions/49540 - router.push(newUrl, undefined); + router.replace(newUrl); // history.replaceState({}, "", newUrl); - } + } // history.replaceState({}, "", newUrl); }, [webSockets.status.value, transcript.response?.status]); useEffect(() => { From a846e38fbdeb77aad23282a8068b08e4cd6b3921 Mon Sep 17 00:00:00 2001 From: Sara Date: Fri, 17 Nov 2023 13:38:32 +0100 Subject: [PATCH 08/27] fix waveform in pipeline --- .../reflector/pipelines/main_live_pipeline.py | 15 +++++++++------ .../processors/audio_waveform_processor.py | 5 ++++- server/tests/test_transcripts_audio_download.py | 12 ------------ server/tests/test_transcripts_rtc_ws.py | 17 ++++++++++++----- 4 files changed, 25 insertions(+), 24 deletions(-) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index b0576b92..fece6da5 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -233,6 +233,7 @@ class PipelineMainBase(PipelineRunner): data=final_short_summary, ) + @broadcast_to_sockets async def on_duration(self, data): async with self.transaction(): duration = TranscriptDuration(duration=data) @@ -248,14 +249,16 @@ class PipelineMainBase(PipelineRunner): transcript=transcript, event="DURATION", data=duration ) + @broadcast_to_sockets async def on_waveform(self, data): - waveform = TranscriptWaveform(waveform=data) + async with self.transaction(): + waveform = TranscriptWaveform(waveform=data) - transcript = await self.get_transcript() + transcript = await self.get_transcript() - return await transcripts_controller.append_event( - transcript=transcript, event="WAVEFORM", data=waveform - ) + return await transcripts_controller.append_event( + transcript=transcript, event="WAVEFORM", data=waveform + ) class PipelineMainLive(PipelineMainBase): @@ -283,7 +286,7 @@ class PipelineMainLive(PipelineMainBase): BroadcastProcessor( processors=[ TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), - AudioWaveformProcessor( + AudioWaveformProcessor.as_threaded( audio_path=transcript.audio_mp3_filename, waveform_path=transcript.audio_waveform_filename, on_waveform=self.on_waveform, diff --git a/server/reflector/processors/audio_waveform_processor.py b/server/reflector/processors/audio_waveform_processor.py index acce904a..f1a24ffd 100644 --- a/server/reflector/processors/audio_waveform_processor.py +++ b/server/reflector/processors/audio_waveform_processor.py @@ -22,7 +22,7 @@ class AudioWaveformProcessor(Processor): self.audio_path = audio_path self.waveform_path = waveform_path - async def _push(self, _data): + async def _flush(self): self.waveform_path.parent.mkdir(parents=True, exist_ok=True) self.logger.info("Waveform Processing Started") waveform = get_audio_waveform(path=self.audio_path, segments_count=255) @@ -31,3 +31,6 @@ class AudioWaveformProcessor(Processor): json.dump(waveform, fd) self.logger.info("Waveform Processing Finished") await self.emit(waveform, name="waveform") + + async def _push(_self, _data): + return diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py index 69ae5f65..28f83fff 100644 --- a/server/tests/test_transcripts_audio_download.py +++ b/server/tests/test_transcripts_audio_download.py @@ -118,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek( assert response.status_code == 206 assert response.headers["content-type"] == content_type assert response.headers["content-range"].startswith("bytes 100-") - - -@pytest.mark.asyncio -async def test_transcript_audio_download_waveform(fake_transcript): - from reflector.app import app - - ac = AsyncClient(app=app, base_url="http://test/v1") - response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform") - assert response.status_code == 200 - assert response.headers["content-type"] == "application/json" - assert isinstance(response.json()["data"], list) - assert len(response.json()["data"]) >= 255 diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 65660a5e..b33b1db5 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -183,8 +183,14 @@ async def test_transcript_rtc_and_websocket( assert ev["data"]["title"] == "LLM TITLE" assert "WAVEFORM" in eventnames - ev = events[eventnames.index("FINAL_TITLE")] - assert ev["data"]["title"] == "LLM TITLE" + ev = events[eventnames.index("WAVEFORM")] + assert isinstance(ev["data"]["waveform"], list) + assert len(ev["data"]["waveform"]) >= 250 + waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform") + assert waveform_resp.status_code == 200 + assert waveform_resp.headers["content-type"] == "application/json" + assert isinstance(waveform_resp.json()["data"], list) + assert len(waveform_resp.json()["data"]) >= 250 # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] @@ -197,11 +203,12 @@ async def test_transcript_rtc_and_websocket( # check on the latest response that the audio duration is > 0 assert resp.json()["duration"] > 0 + assert "DURATION" in eventnames # check that audio/mp3 is available - resp = await ac.get(f"/transcripts/{tid}/audio/mp3") - assert resp.status_code == 200 - assert resp.headers["Content-Type"] == "audio/mpeg" + audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3") + assert audio_resp.status_code == 200 + assert audio_resp.headers["Content-Type"] == "audio/mpeg" @pytest.mark.usefixtures("celery_session_app") From a816e00769faba40098acd36085469a23089f614 Mon Sep 17 00:00:00 2001 From: Sara Date: Fri, 17 Nov 2023 15:16:53 +0100 Subject: [PATCH 09/27] get waveform from socket --- .../transcripts/[transcriptId]/page.tsx | 7 ++--- .../[transcriptId]/record/page.tsx | 16 ++++++---- www/app/[domain]/transcripts/player.tsx | 5 ++-- www/app/[domain]/transcripts/useMp3.ts | 29 ++++++++----------- www/app/[domain]/transcripts/useWebSockets.ts | 21 +++++++++++++- 5 files changed, 48 insertions(+), 30 deletions(-) diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index 079b41e8..b0986e94 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -30,7 +30,7 @@ export default function TranscriptDetails(details: TranscriptDetails) { const topics = useTopics(protectedPath, transcriptId); const waveform = useWaveform(protectedPath, transcriptId); const useActiveTopic = useState(null); - const mp3 = useMp3(protectedPath, transcriptId); + const mp3 = useMp3(transcriptId); if (transcript?.error || topics?.error) { return ( @@ -59,7 +59,6 @@ export default function TranscriptDetails(details: TranscriptDetails) { .join("\n\n") .replace(/ +/g, " ") .trim() || ""; - console.log("calf full transcript"); return ( <> @@ -79,11 +78,11 @@ export default function TranscriptDetails(details: TranscriptDetails) { - ) : mp3.error || waveform.error ? ( + ) : waveform.error ? (
"error loading this recording"
) : ( diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx index 147f8827..2c5b73e0 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx @@ -11,11 +11,11 @@ import { Topic } from "../../webSocketTypes"; import LiveTrancription from "../../liveTranscription"; import DisconnectedIndicator from "../../disconnectedIndicator"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; -import { faGear, faSpinner } from "@fortawesome/free-solid-svg-icons"; +import { faGear } from "@fortawesome/free-solid-svg-icons"; import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock"; import { useRouter } from "next/navigation"; import Player from "../../player"; -import useMp3 from "../../useMp3"; +import useMp3, { Mp3Response } from "../../useMp3"; import WaveformLoading from "../../waveformLoading"; type TranscriptDetails = { @@ -48,7 +48,7 @@ const TranscriptRecord = (details: TranscriptDetails) => { const [recordedTime, setRecordedTime] = useState(0); const [startTime, setStartTime] = useState(0); const [transcriptStarted, setTranscriptStarted] = useState(false); - const mp3 = useMp3(true, ""); + let mp3 = useMp3(details.params.transcriptId, true); const router = useRouter(); @@ -74,6 +74,12 @@ const TranscriptRecord = (details: TranscriptDetails) => { } // history.replaceState({}, "", newUrl); }, [webSockets.status.value, transcript.response?.status]); + useEffect(() => { + if (webSockets.duration) { + mp3.getNow(); + } + }, [webSockets.duration]); + useEffect(() => { lockWakeState(); return () => { @@ -83,13 +89,13 @@ const TranscriptRecord = (details: TranscriptDetails) => { return ( <> - {webSockets.waveform && mp3.media ? ( + {webSockets.waveform && webSockets.duration && mp3?.media ? ( ) : recordedTime ? ( diff --git a/www/app/[domain]/transcripts/player.tsx b/www/app/[domain]/transcripts/player.tsx index 6143836e..02151a68 100644 --- a/www/app/[domain]/transcripts/player.tsx +++ b/www/app/[domain]/transcripts/player.tsx @@ -14,7 +14,7 @@ type PlayerProps = { Topic | null, React.Dispatch>, ]; - waveform: AudioWaveform; + waveform: AudioWaveform["data"]; media: HTMLMediaElement; mediaDuration: number; }; @@ -29,7 +29,6 @@ export default function Player(props: PlayerProps) { ); const [activeTopic, setActiveTopic] = props.useActiveTopic; const topicsRef = useRef(props.topics); - // Waveform setup useEffect(() => { if (waveformRef.current) { @@ -40,7 +39,7 @@ export default function Player(props: PlayerProps) { // This is not ideal, but it works for now. const _wavesurfer = WaveSurfer.create({ container: waveformRef.current, - peaks: props.waveform.data, + peaks: props.waveform, hideScrollbar: true, autoCenter: true, barWidth: 2, diff --git a/www/app/[domain]/transcripts/useMp3.ts b/www/app/[domain]/transcripts/useMp3.ts index 570a6a25..23249f94 100644 --- a/www/app/[domain]/transcripts/useMp3.ts +++ b/www/app/[domain]/transcripts/useMp3.ts @@ -1,24 +1,19 @@ import { useContext, useEffect, useState } from "react"; -import { useError } from "../../(errors)/errorContext"; import { DomainContext } from "../domainContext"; import getApi from "../../lib/getApi"; import { useFiefAccessTokenInfo } from "@fief/fief/build/esm/nextjs/react"; -import { shouldShowError } from "../../lib/errorUtils"; -type Mp3Response = { - url: string | null; +export type Mp3Response = { media: HTMLMediaElement | null; loading: boolean; - error: Error | null; + getNow: () => void; }; -const useMp3 = (protectedPath: boolean, id: string): Mp3Response => { - const [url, setUrl] = useState(null); +const useMp3 = (id: string, waiting?: boolean): Mp3Response => { const [media, setMedia] = useState(null); + const [later, setLater] = useState(waiting); const [loading, setLoading] = useState(false); - const [error, setErrorState] = useState(null); - const { setError } = useError(); - const api = getApi(protectedPath); + const api = getApi(true); const { api_url } = useContext(DomainContext); const accessTokenInfo = useFiefAccessTokenInfo(); const [serviceWorkerReady, setServiceWorkerReady] = useState(false); @@ -42,8 +37,8 @@ const useMp3 = (protectedPath: boolean, id: string): Mp3Response => { }); }, [navigator.serviceWorker, serviceWorkerReady, accessTokenInfo]); - const getMp3 = (id: string) => { - if (!id || !api) return; + useEffect(() => { + if (!id || !api || later) return; // createa a audio element and set the source setLoading(true); @@ -53,13 +48,13 @@ const useMp3 = (protectedPath: boolean, id: string): Mp3Response => { audioElement.preload = "auto"; setMedia(audioElement); setLoading(false); + }, [id, api, later]); + + const getNow = () => { + setLater(false); }; - useEffect(() => { - getMp3(id); - }, [id, api]); - - return { url, media, loading, error }; + return { media, loading, getNow }; }; export default useMp3; diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts index 3f3d20fc..f33a1347 100644 --- a/www/app/[domain]/transcripts/useWebSockets.ts +++ b/www/app/[domain]/transcripts/useWebSockets.ts @@ -11,7 +11,8 @@ export type UseWebSockets = { topics: Topic[]; finalSummary: FinalSummary; status: Status; - waveform: AudioWaveform | null; + waveform: AudioWaveform["data"] | null; + duration: number | null; }; export const useWebSockets = (transcriptId: string | null): UseWebSockets => { @@ -23,6 +24,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { const [isProcessing, setIsProcessing] = useState(false); const [topics, setTopics] = useState([]); const [waveform, setWaveForm] = useState(null); + const [duration, setDuration] = useState(null); const [finalSummary, setFinalSummary] = useState({ summary: "", }); @@ -351,6 +353,22 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { } break; + case "WAVEFORM": + console.debug( + "WAVEFORM event length:", + message.data.waveform.length, + ); + if (message.data) { + setWaveForm(message.data.waveform); + } + break; + case "DURATION": + console.debug("DURATION event:", message.data); + if (message.data) { + setDuration(message.data.duration); + } + break; + case "STATUS": console.log("STATUS event:", message.data); if (message.data.value === "error") { @@ -415,5 +433,6 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { title, status, waveform, + duration, }; }; From c40e0970b7d9dfe4d3df2cb159b7ffbecfece077 Mon Sep 17 00:00:00 2001 From: Sara Date: Mon, 20 Nov 2023 20:13:18 +0100 Subject: [PATCH 10/27] review fixes --- server/reflector/views/transcripts.py | 2 -- www/app/[domain]/transcripts/useWebSockets.ts | 17 +++++++---------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 77e4b0b7..6909b8ae 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -260,8 +260,6 @@ async def transcript_get_audio_waveform( if not transcript.audio_mp3_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") - # await run_in_threadpool(transcript.convert_audio_to_waveform) - return transcript.audio_waveform diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts index f33a1347..f289adbb 100644 --- a/www/app/[domain]/transcripts/useWebSockets.ts +++ b/www/app/[domain]/transcripts/useWebSockets.ts @@ -402,22 +402,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { console.debug("WebSocket connection closed"); switch (event.code) { case 1000: // Normal Closure: - case 1001: // Going Away: - case 1005: - break; default: setError( new Error(`WebSocket closed unexpectedly with code: ${event.code}`), "Disconnected", ); + console.log( + "Socket is closed. Reconnect will be attempted in 1 second.", + event.reason, + ); + setTimeout(function () { + ws = new WebSocket(url); + }, 1000); } - console.log( - "Socket is closed. Reconnect will be attempted in 1 second.", - event.reason, - ); - setTimeout(function () { - ws = new WebSocket(url); - }, 1000); }; return () => { From f38dad3ad400830f5c35c78990ecce4142ad78ab Mon Sep 17 00:00:00 2001 From: Sara Date: Wed, 22 Nov 2023 12:25:21 +0100 Subject: [PATCH 11/27] small fixes and start auth fix --- www/app/[domain]/layout.tsx | 143 ++++++++-------- .../transcripts/[transcriptId]/page.tsx | 162 +++++++++--------- www/app/[domain]/transcripts/useMp3.ts | 14 +- www/app/[domain]/transcripts/useWaveform.ts | 5 +- www/app/[domain]/transcripts/useWebSockets.ts | 1 + www/app/lib/edgeConfig.ts | 4 +- www/app/lib/errorUtils.ts | 5 +- www/app/lib/fief.ts | 2 +- 8 files changed, 172 insertions(+), 164 deletions(-) diff --git a/www/app/[domain]/layout.tsx b/www/app/[domain]/layout.tsx index dbe5ed11..3e881ac3 100644 --- a/www/app/[domain]/layout.tsx +++ b/www/app/[domain]/layout.tsx @@ -11,6 +11,7 @@ import About from "../(aboutAndPrivacy)/about"; import Privacy from "../(aboutAndPrivacy)/privacy"; import { DomainContextProvider } from "./domainContext"; import { getConfig } from "../lib/edgeConfig"; +import { ErrorBoundary } from "@sentry/nextjs"; const poppins = Poppins({ subsets: ["latin"], weight: ["200", "400", "600"] }); @@ -76,80 +77,82 @@ export default async function RootLayout({ children, params }: LayoutProps) { - - -
-
- {/* Logo on the left */} - - Reflector -
-

- Reflector -

-

- Capture the signal, not the noise -

-
- -
- {/* Text link on the right */} + "something went really wrong"

}> + + +
+
+ {/* Logo on the left */} - Create + Reflector +
+

+ Reflector +

+

+ Capture the signal, not the noise +

+
- {browse ? ( - <> -  Â·  - - Browse - - - ) : ( - <> - )} -  Â·  - - {privacy ? ( - <> -  Â·  - - - ) : ( - <> - )} - {requireLogin ? ( - <> -  Â·  - - - ) : ( - <> - )} -
-
+
+ {/* Text link on the right */} + + Create + + {browse ? ( + <> +  Â·  + + Browse + + + ) : ( + <> + )} +  Â·  + + {privacy ? ( + <> +  Â·  + + + ) : ( + <> + )} + {requireLogin ? ( + <> +  Â·  + + + ) : ( + <> + )} +
+ - {children} -
-
+ {children} +
+ + diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index 734f2609..23634b95 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -14,6 +14,7 @@ import QRCode from "react-qr-code"; import TranscriptTitle from "../transcriptTitle"; import Player from "../player"; import WaveformLoading from "../waveformLoading"; +import { useRouter } from "next/navigation"; type TranscriptDetails = { params: { @@ -21,10 +22,11 @@ type TranscriptDetails = { }; }; -const protectedPath = true; +const protectedPath = false; export default function TranscriptDetails(details: TranscriptDetails) { const transcriptId = details.params.transcriptId; + const router = useRouter(); const transcript = useTranscript(protectedPath, transcriptId); const topics = useTopics(protectedPath, transcriptId); @@ -32,15 +34,6 @@ export default function TranscriptDetails(details: TranscriptDetails) { const useActiveTopic = useState(null); const mp3 = useMp3(transcriptId); - if (transcript?.error || topics?.error) { - return ( - - ); - } - useEffect(() => { const statusToRedirect = ["idle", "recording", "processing"]; if (statusToRedirect.includes(transcript.response?.status)) { @@ -48,8 +41,8 @@ export default function TranscriptDetails(details: TranscriptDetails) { // Shallow redirection does not work on NextJS 13 // https://github.com/vercel/next.js/discussions/48110 // https://github.com/vercel/next.js/discussions/49540 - // router.push(newUrl, undefined, { shallow: true }); - history.replaceState({}, "", newUrl); + router.push(newUrl, undefined); + // history.replaceState({}, "", newUrl); } }, [transcript.response?.status]); @@ -60,85 +53,92 @@ export default function TranscriptDetails(details: TranscriptDetails) { .replace(/ +/g, " ") .trim() || ""; + if (transcript.error || topics?.error) { + return ( + + ); + } + + if (transcript?.loading || topics?.loading) { + return ; + } + return ( <> - {transcript?.loading || topics?.loading ? ( - - ) : ( - <> -
- {transcript?.response?.title && ( - + {transcript?.response?.title && ( + + )} + {waveform.waveform && mp3.media ? ( + + ) : waveform.error ? ( +
"error loading this recording"
+ ) : ( + + )} +
+
+ + +
+
+ {transcript.response.longSummary ? ( + - )} - {waveform.waveform && mp3.media ? ( - - ) : waveform.error ? ( -
"error loading this recording"
) : ( - - )} -
-
- - -
-
- {transcript.response.longSummary ? ( - +
+ {transcript.response.status == "processing" ? ( +

Loading Transcript

) : ( -
- {transcript.response.status == "processing" ? ( -

Loading Transcript

- ) : ( -

- There was an error generating the final summary, please - come back later -

- )} -
+

+ There was an error generating the final summary, please come + back later +

)} -
+
+ )} + -
-
- -
-
- -
-
+
+
+
-
- - )} +
+ +
+ +
+
); } diff --git a/www/app/[domain]/transcripts/useMp3.ts b/www/app/[domain]/transcripts/useMp3.ts index 23249f94..58e0209d 100644 --- a/www/app/[domain]/transcripts/useMp3.ts +++ b/www/app/[domain]/transcripts/useMp3.ts @@ -16,26 +16,30 @@ const useMp3 = (id: string, waiting?: boolean): Mp3Response => { const api = getApi(true); const { api_url } = useContext(DomainContext); const accessTokenInfo = useFiefAccessTokenInfo(); - const [serviceWorkerReady, setServiceWorkerReady] = useState(false); + const [serviceWorker, setServiceWorker] = + useState(null); useEffect(() => { if ("serviceWorker" in navigator) { - navigator.serviceWorker.register("/service-worker.js").then(() => { - setServiceWorkerReady(true); + navigator.serviceWorker.register("/service-worker.js").then((worker) => { + setServiceWorker(worker); }); } + return () => { + serviceWorker?.unregister(); + }; }, []); useEffect(() => { if (!navigator.serviceWorker) return; if (!navigator.serviceWorker.controller) return; - if (!serviceWorkerReady) return; + if (!serviceWorker) return; // Send the token to the service worker navigator.serviceWorker.controller.postMessage({ type: "SET_AUTH_TOKEN", token: accessTokenInfo?.access_token, }); - }, [navigator.serviceWorker, serviceWorkerReady, accessTokenInfo]); + }, [navigator.serviceWorker, !serviceWorker, accessTokenInfo]); useEffect(() => { if (!id || !api || later) return; diff --git a/www/app/[domain]/transcripts/useWaveform.ts b/www/app/[domain]/transcripts/useWaveform.ts index 4073b711..d2bd0fd6 100644 --- a/www/app/[domain]/transcripts/useWaveform.ts +++ b/www/app/[domain]/transcripts/useWaveform.ts @@ -1,8 +1,5 @@ import { useEffect, useState } from "react"; -import { - DefaultApi, - V1TranscriptGetAudioWaveformRequest, -} from "../../api/apis/DefaultApi"; +import { V1TranscriptGetAudioWaveformRequest } from "../../api/apis/DefaultApi"; import { AudioWaveform } from "../../api"; import { useError } from "../../(errors)/errorContext"; import getApi from "../../lib/getApi"; diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts index f289adbb..1e59781c 100644 --- a/www/app/[domain]/transcripts/useWebSockets.ts +++ b/www/app/[domain]/transcripts/useWebSockets.ts @@ -402,6 +402,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { console.debug("WebSocket connection closed"); switch (event.code) { case 1000: // Normal Closure: + case 1005: // Closure by client FF default: setError( new Error(`WebSocket closed unexpectedly with code: ${event.code}`), diff --git a/www/app/lib/edgeConfig.ts b/www/app/lib/edgeConfig.ts index 1140e555..5527121a 100644 --- a/www/app/lib/edgeConfig.ts +++ b/www/app/lib/edgeConfig.ts @@ -3,9 +3,9 @@ import { isDevelopment } from "./utils"; const localConfig = { features: { - requireLogin: false, + requireLogin: true, privacy: true, - browse: false, + browse: true, }, api_url: "http://127.0.0.1:1250", websocket_url: "ws://127.0.0.1:1250", diff --git a/www/app/lib/errorUtils.ts b/www/app/lib/errorUtils.ts index 81a39b5d..e9e5300d 100644 --- a/www/app/lib/errorUtils.ts +++ b/www/app/lib/errorUtils.ts @@ -1,5 +1,8 @@ function shouldShowError(error: Error | null | undefined) { - if (error?.name == "ResponseError" && error["response"].status == 404) + if ( + error?.name == "ResponseError" && + (error["response"].status == 404 || error["response"].status == 403) + ) return false; if (error?.name == "FetchError") return false; return true; diff --git a/www/app/lib/fief.ts b/www/app/lib/fief.ts index 02db67f5..176aa847 100644 --- a/www/app/lib/fief.ts +++ b/www/app/lib/fief.ts @@ -67,7 +67,7 @@ export const getFiefAuthMiddleware = async (url) => { parameters: {}, }, { - matcher: "/transcripts/((?!new).*)", + matcher: "/transcripts/((?!new))", parameters: {}, }, { From f14e6f5a7f9e137ff9d43cf87e25c3ecd32688d4 Mon Sep 17 00:00:00 2001 From: Sara Date: Wed, 22 Nov 2023 13:20:11 +0100 Subject: [PATCH 12/27] fix auth --- www/app/(auth)/fiefWrapper.tsx | 15 +++++++++++---- www/app/[domain]/layout.tsx | 5 ++++- .../[domain]/transcripts/[transcriptId]/page.tsx | 11 +++-------- .../transcripts/[transcriptId]/record/page.tsx | 6 +++--- www/app/[domain]/transcripts/createTranscript.ts | 2 +- www/app/[domain]/transcripts/finalSummary.tsx | 3 +-- www/app/[domain]/transcripts/shareLink.tsx | 3 +-- www/app/[domain]/transcripts/transcriptTitle.tsx | 3 +-- www/app/[domain]/transcripts/useMp3.ts | 2 +- www/app/[domain]/transcripts/useTopics.ts | 4 ++-- www/app/[domain]/transcripts/useTranscript.ts | 3 +-- www/app/[domain]/transcripts/useTranscriptList.ts | 2 +- www/app/[domain]/transcripts/useWaveform.ts | 4 ++-- www/app/[domain]/transcripts/useWebRTC.ts | 3 +-- www/app/lib/fief.ts | 4 ---- www/app/lib/getApi.ts | 8 +++++--- 16 files changed, 38 insertions(+), 40 deletions(-) diff --git a/www/app/(auth)/fiefWrapper.tsx b/www/app/(auth)/fiefWrapper.tsx index 187fef7c..bb38f5ee 100644 --- a/www/app/(auth)/fiefWrapper.tsx +++ b/www/app/(auth)/fiefWrapper.tsx @@ -1,11 +1,18 @@ "use client"; import { FiefAuthProvider } from "@fief/fief/nextjs/react"; +import { createContext } from "react"; -export default function FiefWrapper({ children }) { +export const CookieContext = createContext<{ hasAuthCookie: boolean }>({ + hasAuthCookie: false, +}); + +export default function FiefWrapper({ children, hasAuthCookie }) { return ( - - {children} - + + + {children} + + ); } diff --git a/www/app/[domain]/layout.tsx b/www/app/[domain]/layout.tsx index 3e881ac3..73cc4841 100644 --- a/www/app/[domain]/layout.tsx +++ b/www/app/[domain]/layout.tsx @@ -12,6 +12,8 @@ import Privacy from "../(aboutAndPrivacy)/privacy"; import { DomainContextProvider } from "./domainContext"; import { getConfig } from "../lib/edgeConfig"; import { ErrorBoundary } from "@sentry/nextjs"; +import { cookies } from "next/dist/client/components/headers"; +import { SESSION_COOKIE_NAME } from "../lib/fief"; const poppins = Poppins({ subsets: ["latin"], weight: ["200", "400", "600"] }); @@ -71,11 +73,12 @@ type LayoutProps = { export default async function RootLayout({ children, params }: LayoutProps) { const config = await getConfig(params.domain); const { requireLogin, privacy, browse } = config.features; + const hasAuthCookie = !!cookies().get(SESSION_COOKIE_NAME); return ( - + "something went really wrong"

}> diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index 23634b95..472c573b 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -22,15 +22,13 @@ type TranscriptDetails = { }; }; -const protectedPath = false; - export default function TranscriptDetails(details: TranscriptDetails) { const transcriptId = details.params.transcriptId; const router = useRouter(); - const transcript = useTranscript(protectedPath, transcriptId); - const topics = useTopics(protectedPath, transcriptId); - const waveform = useWaveform(protectedPath, transcriptId); + const transcript = useTranscript(transcriptId); + const topics = useTopics(transcriptId); + const waveform = useWaveform(transcriptId); const useActiveTopic = useState(null); const mp3 = useMp3(transcriptId); @@ -71,7 +69,6 @@ export default function TranscriptDetails(details: TranscriptDetails) {
{transcript?.response?.title && ( @@ -101,7 +98,6 @@ export default function TranscriptDetails(details: TranscriptDetails) {
{transcript.response.longSummary ? (
{ } }, []); - const transcript = useTranscript(true, details.params.transcriptId); - const webRTC = useWebRTC(stream, details.params.transcriptId, true); + const transcript = useTranscript(details.params.transcriptId); + const webRTC = useWebRTC(stream, details.params.transcriptId); const webSockets = useWebSockets(details.params.transcriptId); const { audioDevices, getAudioStream } = useAudioDevice(); diff --git a/www/app/[domain]/transcripts/createTranscript.ts b/www/app/[domain]/transcripts/createTranscript.ts index 0d96b8db..9ad1abe0 100644 --- a/www/app/[domain]/transcripts/createTranscript.ts +++ b/www/app/[domain]/transcripts/createTranscript.ts @@ -19,7 +19,7 @@ const useCreateTranscript = (): CreateTranscript => { const [loading, setLoading] = useState(false); const [error, setErrorState] = useState(null); const { setError } = useError(); - const api = getApi(true); + const api = getApi(); const create = (params: V1TranscriptsCreateRequest["createTranscript"]) => { if (loading || !api) return; diff --git a/www/app/[domain]/transcripts/finalSummary.tsx b/www/app/[domain]/transcripts/finalSummary.tsx index e0d0f1c9..47c757bf 100644 --- a/www/app/[domain]/transcripts/finalSummary.tsx +++ b/www/app/[domain]/transcripts/finalSummary.tsx @@ -5,7 +5,6 @@ import "../../styles/markdown.css"; import getApi from "../../lib/getApi"; type FinalSummaryProps = { - protectedPath: boolean; summary: string; fullTranscript: string; transcriptId: string; @@ -18,7 +17,7 @@ export default function FinalSummary(props: FinalSummaryProps) { const [isEditMode, setIsEditMode] = useState(false); const [preEditSummary, setPreEditSummary] = useState(props.summary); const [editedSummary, setEditedSummary] = useState(props.summary); - const api = getApi(props.protectedPath); + const api = getApi(); const updateSummary = async (newSummary: string, transcriptId: string) => { if (!api) return; diff --git a/www/app/[domain]/transcripts/shareLink.tsx b/www/app/[domain]/transcripts/shareLink.tsx index 82ef52c9..e6449bd3 100644 --- a/www/app/[domain]/transcripts/shareLink.tsx +++ b/www/app/[domain]/transcripts/shareLink.tsx @@ -8,7 +8,6 @@ import "../../styles/button.css"; import "../../styles/form.scss"; type ShareLinkProps = { - protectedPath: boolean; transcriptId: string; userId: string | null; shareMode: string; @@ -21,8 +20,8 @@ const ShareLink = (props: ShareLinkProps) => { const requireLogin = featureEnabled("requireLogin"); const [isOwner, setIsOwner] = useState(false); const [shareMode, setShareMode] = useState(props.shareMode); - const api = getApi(props.protectedPath); const userinfo = useFiefUserinfo(); + const api = getApi(); useEffect(() => { setCurrentUrl(window.location.href); diff --git a/www/app/[domain]/transcripts/transcriptTitle.tsx b/www/app/[domain]/transcripts/transcriptTitle.tsx index d2f901fa..afc29e51 100644 --- a/www/app/[domain]/transcripts/transcriptTitle.tsx +++ b/www/app/[domain]/transcripts/transcriptTitle.tsx @@ -2,7 +2,6 @@ import { useState } from "react"; import getApi from "../../lib/getApi"; type TranscriptTitle = { - protectedPath: boolean; title: string; transcriptId: string; }; @@ -11,7 +10,7 @@ const TranscriptTitle = (props: TranscriptTitle) => { const [displayedTitle, setDisplayedTitle] = useState(props.title); const [preEditTitle, setPreEditTitle] = useState(props.title); const [isEditing, setIsEditing] = useState(false); - const api = getApi(props.protectedPath); + const api = getApi(); const updateTitle = async (newTitle: string, transcriptId: string) => { if (!api) return; diff --git a/www/app/[domain]/transcripts/useMp3.ts b/www/app/[domain]/transcripts/useMp3.ts index 58e0209d..363a4190 100644 --- a/www/app/[domain]/transcripts/useMp3.ts +++ b/www/app/[domain]/transcripts/useMp3.ts @@ -13,7 +13,7 @@ const useMp3 = (id: string, waiting?: boolean): Mp3Response => { const [media, setMedia] = useState(null); const [later, setLater] = useState(waiting); const [loading, setLoading] = useState(false); - const api = getApi(true); + const api = getApi(); const { api_url } = useContext(DomainContext); const accessTokenInfo = useFiefAccessTokenInfo(); const [serviceWorker, setServiceWorker] = diff --git a/www/app/[domain]/transcripts/useTopics.ts b/www/app/[domain]/transcripts/useTopics.ts index 01053019..de4097b3 100644 --- a/www/app/[domain]/transcripts/useTopics.ts +++ b/www/app/[domain]/transcripts/useTopics.ts @@ -14,12 +14,12 @@ type TranscriptTopics = { error: Error | null; }; -const useTopics = (protectedPath, id: string): TranscriptTopics => { +const useTopics = (id: string): TranscriptTopics => { const [topics, setTopics] = useState(null); const [loading, setLoading] = useState(false); const [error, setErrorState] = useState(null); const { setError } = useError(); - const api = getApi(protectedPath); + const api = getApi(); useEffect(() => { if (!id || !api) return; diff --git a/www/app/[domain]/transcripts/useTranscript.ts b/www/app/[domain]/transcripts/useTranscript.ts index 987e57f3..91700d7a 100644 --- a/www/app/[domain]/transcripts/useTranscript.ts +++ b/www/app/[domain]/transcripts/useTranscript.ts @@ -24,14 +24,13 @@ type SuccessTranscript = { }; const useTranscript = ( - protectedPath: boolean, id: string | null, ): ErrorTranscript | LoadingTranscript | SuccessTranscript => { const [response, setResponse] = useState(null); const [loading, setLoading] = useState(true); const [error, setErrorState] = useState(null); const { setError } = useError(); - const api = getApi(protectedPath); + const api = getApi(); useEffect(() => { if (!id || !api) return; diff --git a/www/app/[domain]/transcripts/useTranscriptList.ts b/www/app/[domain]/transcripts/useTranscriptList.ts index cc8f4701..7b5abb37 100644 --- a/www/app/[domain]/transcripts/useTranscriptList.ts +++ b/www/app/[domain]/transcripts/useTranscriptList.ts @@ -15,7 +15,7 @@ const useTranscriptList = (page: number): TranscriptList => { const [loading, setLoading] = useState(true); const [error, setErrorState] = useState(null); const { setError } = useError(); - const api = getApi(true); + const api = getApi(); useEffect(() => { if (!api) return; diff --git a/www/app/[domain]/transcripts/useWaveform.ts b/www/app/[domain]/transcripts/useWaveform.ts index d2bd0fd6..f80ad78c 100644 --- a/www/app/[domain]/transcripts/useWaveform.ts +++ b/www/app/[domain]/transcripts/useWaveform.ts @@ -11,12 +11,12 @@ type AudioWaveFormResponse = { error: Error | null; }; -const useWaveform = (protectedPath, id: string): AudioWaveFormResponse => { +const useWaveform = (id: string): AudioWaveFormResponse => { const [waveform, setWaveform] = useState(null); const [loading, setLoading] = useState(true); const [error, setErrorState] = useState(null); const { setError } = useError(); - const api = getApi(protectedPath); + const api = getApi(); useEffect(() => { if (!id || !api) return; diff --git a/www/app/[domain]/transcripts/useWebRTC.ts b/www/app/[domain]/transcripts/useWebRTC.ts index f4421e4d..edd3bef0 100644 --- a/www/app/[domain]/transcripts/useWebRTC.ts +++ b/www/app/[domain]/transcripts/useWebRTC.ts @@ -10,11 +10,10 @@ import getApi from "../../lib/getApi"; const useWebRTC = ( stream: MediaStream | null, transcriptId: string | null, - protectedPath, ): Peer => { const [peer, setPeer] = useState(null); const { setError } = useError(); - const api = getApi(protectedPath); + const api = getApi(); useEffect(() => { if (!stream || !transcriptId) { diff --git a/www/app/lib/fief.ts b/www/app/lib/fief.ts index 176aa847..3af5c30f 100644 --- a/www/app/lib/fief.ts +++ b/www/app/lib/fief.ts @@ -66,10 +66,6 @@ export const getFiefAuthMiddleware = async (url) => { matcher: "/transcripts", parameters: {}, }, - { - matcher: "/transcripts/((?!new))", - parameters: {}, - }, { matcher: "/browse", parameters: {}, diff --git a/www/app/lib/getApi.ts b/www/app/lib/getApi.ts index 7392cc90..e1ece2a9 100644 --- a/www/app/lib/getApi.ts +++ b/www/app/lib/getApi.ts @@ -4,17 +4,19 @@ import { DefaultApi } from "../api/apis/DefaultApi"; import { useFiefAccessTokenInfo } from "@fief/fief/nextjs/react"; import { useContext, useEffect, useState } from "react"; import { DomainContext, featureEnabled } from "../[domain]/domainContext"; +import { CookieContext } from "../(auth)/fiefWrapper"; -export default function getApi(protectedPath: boolean): DefaultApi | undefined { +export default function getApi(): DefaultApi | undefined { const accessTokenInfo = useFiefAccessTokenInfo(); const api_url = useContext(DomainContext).api_url; const requireLogin = featureEnabled("requireLogin"); const [api, setApi] = useState(); + const { hasAuthCookie } = useContext(CookieContext); if (!api_url) throw new Error("no API URL"); useEffect(() => { - if (protectedPath && requireLogin && !accessTokenInfo) { + if (hasAuthCookie && requireLogin && !accessTokenInfo) { return; } @@ -25,7 +27,7 @@ export default function getApi(protectedPath: boolean): DefaultApi | undefined { : undefined, }); setApi(new DefaultApi(apiConfiguration)); - }, [!accessTokenInfo, protectedPath]); + }, [!accessTokenInfo, hasAuthCookie]); return api; } From 4226428f582821383e7fb1af06cfbadb919c1f7e Mon Sep 17 00:00:00 2001 From: Sara Date: Wed, 22 Nov 2023 13:53:12 +0100 Subject: [PATCH 13/27] minor styling changes --- .../transcripts/[transcriptId]/page.tsx | 2 ++ www/app/[domain]/transcripts/shareLink.tsx | 22 +++++++++++++++---- .../[domain]/transcripts/waveformLoading.tsx | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index 472c573b..54eca9f9 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -15,6 +15,8 @@ import TranscriptTitle from "../transcriptTitle"; import Player from "../player"; import WaveformLoading from "../waveformLoading"; import { useRouter } from "next/navigation"; +import { faSpinner } from "@fortawesome/free-solid-svg-icons"; +import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; type TranscriptDetails = { params: { diff --git a/www/app/[domain]/transcripts/shareLink.tsx b/www/app/[domain]/transcripts/shareLink.tsx index e6449bd3..dd66d6cb 100644 --- a/www/app/[domain]/transcripts/shareLink.tsx +++ b/www/app/[domain]/transcripts/shareLink.tsx @@ -6,6 +6,8 @@ import SelectSearch from "react-select-search"; import "react-select-search/style.css"; import "../../styles/button.css"; import "../../styles/form.scss"; +import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; +import { faSpinner } from "@fortawesome/free-solid-svg-icons"; type ShareLinkProps = { transcriptId: string; @@ -20,6 +22,7 @@ const ShareLink = (props: ShareLinkProps) => { const requireLogin = featureEnabled("requireLogin"); const [isOwner, setIsOwner] = useState(false); const [shareMode, setShareMode] = useState(props.shareMode); + const [shareLoading, setShareLoading] = useState(false); const userinfo = useFiefUserinfo(); const api = getApi(); @@ -46,6 +49,7 @@ const ShareLink = (props: ShareLinkProps) => { const updateShareMode = async (selectedShareMode: string) => { if (!api) return; + setShareLoading(true); const updatedTranscript = await api.v1TranscriptUpdate({ transcriptId: props.transcriptId, updateTranscript: { @@ -53,6 +57,7 @@ const ShareLink = (props: ShareLinkProps) => { }, }); setShareMode(updatedTranscript.shareMode); + setShareLoading(false); }; const privacyEnabled = featureEnabled("privacy"); @@ -62,7 +67,7 @@ const ShareLink = (props: ShareLinkProps) => { style={{ background: "rgba(96, 165, 250, 0.2)" }} > {requireLogin && ( -

+

{shareMode === "private" && (

This transcript is private and can only be accessed by you.

)} @@ -76,7 +81,7 @@ const ShareLink = (props: ShareLinkProps) => { )} {isOwner && api && ( -

+

{ ]} value={shareMode} onChange={updateShareMode} + closeOnSelect={true} /> -

+ {shareLoading && ( +
+ +
+ )} +
)} -

+
)} {!requireLogin && ( <> diff --git a/www/app/[domain]/transcripts/waveformLoading.tsx b/www/app/[domain]/transcripts/waveformLoading.tsx index 68e0c80f..56540927 100644 --- a/www/app/[domain]/transcripts/waveformLoading.tsx +++ b/www/app/[domain]/transcripts/waveformLoading.tsx @@ -5,7 +5,7 @@ export default () => (
); From aecc3a0c3bc7b4c6daf9474f8ef95a21726c22fc Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 15 Nov 2023 21:24:21 +0100 Subject: [PATCH 14/27] server: first attempts to split post pipeline as single celery tasks --- server/reflector/db/transcripts.py | 19 ++ .../reflector/pipelines/main_live_pipeline.py | 239 +++++++++++++++--- server/reflector/storage/base.py | 11 +- server/reflector/storage/storage_aws.py | 20 +- 4 files changed, 241 insertions(+), 48 deletions(-) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index f0dbc277..c0e8984b 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -106,6 +106,7 @@ class Transcript(BaseModel): events: list[TranscriptEvent] = [] source_language: str = "en" target_language: str = "en" + audio_location: str = "local" def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: ev = TranscriptEvent(event=event, data=data.model_dump()) @@ -140,6 +141,10 @@ class Transcript(BaseModel): def audio_waveform_filename(self): return self.data_path / "audio.json" + @property + def storage_audio_path(self): + return f"{self.id}/audio.mp3" + @property def audio_waveform(self): try: @@ -283,5 +288,19 @@ class TranscriptController: transcript.upsert_topic(topic) await self.update(transcript, {"topics": transcript.topics_dump()}) + async def move_mp3_to_storage(self, transcript: Transcript): + """ + Move mp3 file to storage + """ + from reflector.storage import Storage + + storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE) + await storage.put_file( + transcript.storage_audio_path, + self.audio_mp3_filename.read_bytes(), + ) + + await self.update(transcript, {"audio_location": "storage"}) + transcripts_controller = TranscriptController() diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 3a9d1868..e2f305c4 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -12,6 +12,7 @@ It is directly linked to our data model. """ import asyncio +import functools from contextlib import asynccontextmanager from datetime import timedelta from pathlib import Path @@ -55,6 +56,22 @@ from reflector.processors.types import ( from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.settings import settings from reflector.ws_manager import WebsocketManager, get_ws_manager +from structlog import Logger + + +def asynctask(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + coro = f(*args, **kwargs) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + return loop.run_until_complete(coro) + return asyncio.run(coro) + + return wrapper def broadcast_to_sockets(func): @@ -75,6 +92,22 @@ def broadcast_to_sockets(func): return wrapper +def get_transcript(func): + """ + Decorator to fetch the transcript from the database from the first argument + """ + + async def wrapper(self, **kwargs): + transcript_id = kwargs.pop("transcript_id") + transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id) + if not transcript: + raise Exception("Transcript {transcript_id} not found") + tlogger = logger.bind(transcript_id=transcript.id) + return await func(self, transcript=transcript, logger=tlogger, **kwargs) + + return wrapper + + class StrValue(BaseModel): value: str @@ -99,6 +132,19 @@ class PipelineMainBase(PipelineRunner): raise Exception("Transcript not found") return result + def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]: + return [ + TitleSummaryWithIdProcessorType( + id=topic.id, + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + duration=topic.duration, + transcript=TranscriptProcessorType(words=topic.words), + ) + for topic in transcript.topics + ] + @asynccontextmanager async def transaction(self): async with self._lock: @@ -299,10 +345,7 @@ class PipelineMainLive(PipelineMainBase): pipeline.set_pref("audio:source_language", transcript.source_language) pipeline.set_pref("audio:target_language", transcript.target_language) pipeline.logger.bind(transcript_id=transcript.id) - pipeline.logger.info( - "Pipeline main live created", - transcript_id=self.transcript_id, - ) + pipeline.logger.info("Pipeline main live created") return pipeline @@ -310,55 +353,28 @@ class PipelineMainLive(PipelineMainBase): # when the pipeline ends, connect to the post pipeline logger.info("Pipeline main live ended", transcript_id=self.transcript_id) logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id) - task_pipeline_main_post.delay(transcript_id=self.transcript_id) + pipeline_post(transcript_id=self.transcript_id) class PipelineMainDiarization(PipelineMainBase): """ - Diarization is a long time process, so we do it in a separate pipeline - When done, adjust the short and final summary + Diarize the audio and update topics """ async def create(self) -> Pipeline: # create a context for the whole rtc transaction # add a customised logger to the context self.prepare() - processors = [] - if settings.DIARIZATION_ENABLED: - processors += [ - AudioDiarizationAutoProcessor(callback=self.on_topic), - ] - - processors += [ - BroadcastProcessor( - processors=[ - TranscriptFinalLongSummaryProcessor.as_threaded( - callback=self.on_long_summary - ), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=self.on_short_summary - ), - ] - ), - ] - pipeline = Pipeline(*processors) + pipeline = Pipeline( + AudioDiarizationAutoProcessor(callback=self.on_topic), + ) pipeline.options = self # now let's start the pipeline by pushing information to the # first processor diarization processor # XXX translation is lost when converting our data model to the processor model transcript = await self.get_transcript() - topics = [ - TitleSummaryWithIdProcessorType( - id=topic.id, - title=topic.title, - summary=topic.summary, - timestamp=topic.timestamp, - duration=topic.duration, - transcript=TranscriptProcessorType(words=topic.words), - ) - for topic in transcript.topics - ] + topics = self.get_transcript_topics(transcript) # we need to create an url to be used for diarization # we can't use the audio_mp3_filename because it's not accessible @@ -386,15 +402,49 @@ class PipelineMainDiarization(PipelineMainBase): # as tempting to use pipeline.push, prefer to use the runner # to let the start just do one job. pipeline.logger.bind(transcript_id=transcript.id) - pipeline.logger.info( - "Pipeline main post created", transcript_id=self.transcript_id - ) + pipeline.logger.info("Diarization pipeline created") self.push(audio_diarization_input) self.flush() return pipeline +class PipelineMainSummaries(PipelineMainBase): + """ + Generate summaries from the topics + """ + + async def create(self) -> Pipeline: + self.prepare() + pipeline = Pipeline( + BroadcastProcessor( + processors=[ + TranscriptFinalLongSummaryProcessor.as_threaded( + callback=self.on_long_summary + ), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=self.on_short_summary + ), + ] + ), + ) + pipeline.options = self + + # get transcript + transcript = await self.get_transcript() + pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info("Summaries pipeline created") + + # push topics + topics = await self.get_transcript_topics(transcript) + for topic in topics: + self.push(topic) + + self.flush() + + return pipeline + + @shared_task def task_pipeline_main_post(transcript_id: str): logger.info( @@ -403,3 +453,112 @@ def task_pipeline_main_post(transcript_id: str): ) runner = PipelineMainDiarization(transcript_id=transcript_id) runner.start_sync() + + +@get_transcript +async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger): + logger.info("Starting convert to mp3") + + # If the audio wav is not available, just skip + wav_filename = transcript.audio_wav_filename + if not wav_filename.exists(): + logger.warning("Wav file not found, may be already converted") + return + + # Convert to mp3 + mp3_filename = transcript.audio_mp3_filename + + import av + + input_container = av.open(wav_filename) + output_container = av.open(mp3_filename, "w") + input_audio_stream = input_container.streams.audio[0] + output_audio_stream = output_container.add_stream("mp3") + output_audio_stream.codec_context.set_parameters( + input_audio_stream.codec_context.parameters + ) + for packet in input_container.demux(input_audio_stream): + for frame in packet.decode(): + output_container.mux(frame) + input_container.close() + output_container.close() + + logger.info("Convert to mp3 done") + + +@get_transcript +async def pipeline_upload_mp3(transcript: Transcript, logger: Logger): + logger.info("Starting upload mp3") + + # If the audio mp3 is not available, just skip + mp3_filename = transcript.audio_mp3_filename + if not mp3_filename.exists(): + logger.warning("Mp3 file not found, may be already uploaded") + return + + # Upload to external storage and delete the file + await transcripts_controller.move_to_storage(transcript) + await transcripts_controller.unlink_mp3(transcript) + + logger.info("Upload mp3 done") + + +@get_transcript +@asynctask +async def pipeline_diarization(transcript: Transcript, logger: Logger): + logger.info("Starting diarization") + runner = PipelineMainDiarization(transcript_id=transcript.id) + await runner.start() + logger.info("Diarization done") + + +@get_transcript +@asynctask +async def pipeline_summaries(transcript: Transcript, logger: Logger): + logger.info("Starting summaries") + runner = PipelineMainSummaries(transcript_id=transcript.id) + await runner.start() + logger.info("Summaries done") + + +# =================================================================== +# Celery tasks that can be called from the API +# =================================================================== + + +@shared_task +@asynctask +async def task_pipeline_convert_to_mp3(transcript_id: str): + await pipeline_convert_to_mp3(transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_upload_mp3(transcript_id: str): + await pipeline_upload_mp3(transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_diarization(transcript_id: str): + await pipeline_diarization(transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_summaries(transcript_id: str): + await pipeline_summaries(transcript_id) + + +def pipeline_post(transcript_id: str): + """ + Run the post pipeline + """ + chain_mp3_and_diarize = ( + task_pipeline_convert_to_mp3.si(transcript_id=transcript_id) + | task_pipeline_upload_mp3.si(transcript_id=transcript_id) + | task_pipeline_diarization.si(transcript_id=transcript_id) + ) + chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id) + chain = chain_mp3_and_diarize | chain_summary + chain.delay() diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py index 5cdafdbf..7c44ff4d 100644 --- a/server/reflector/storage/base.py +++ b/server/reflector/storage/base.py @@ -1,6 +1,7 @@ +import importlib + from pydantic import BaseModel from reflector.settings import settings -import importlib class FileResult(BaseModel): @@ -17,14 +18,14 @@ class Storage: cls._registry[name] = kclass @classmethod - def get_instance(cls, name, settings_prefix=""): + def get_instance(cls, name: str, settings_prefix: str = "", folder: str = ""): if name not in cls._registry: module_name = f"reflector.storage.storage_{name}" importlib.import_module(module_name) # gather specific configuration for the processor # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy` - config = {} + config = {"folder": folder} name_upper = name.upper() config_prefix = f"{settings_prefix}{name_upper}_" for key, value in settings: @@ -34,6 +35,10 @@ class Storage: return cls._registry[name](**config) + def __init__(self): + self.folder = "" + super().__init__() + async def put_file(self, filename: str, data: bytes) -> FileResult: return await self._put_file(filename, data) diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py index 09a9c383..5ab02903 100644 --- a/server/reflector/storage/storage_aws.py +++ b/server/reflector/storage/storage_aws.py @@ -1,6 +1,6 @@ import aioboto3 -from reflector.storage.base import Storage, FileResult from reflector.logger import logger +from reflector.storage.base import FileResult, Storage class AwsStorage(Storage): @@ -22,9 +22,14 @@ class AwsStorage(Storage): super().__init__() self.aws_bucket_name = aws_bucket_name - self.aws_folder = "" + folder = "" if "/" in aws_bucket_name: - self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1) + self.aws_bucket_name, folder = aws_bucket_name.split("/", 1) + if folder: + if not self.folder: + self.folder = folder + else: + self.folder = f"{self.folder}/{folder}" self.session = aioboto3.Session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, @@ -34,7 +39,7 @@ class AwsStorage(Storage): async def _put_file(self, filename: str, data: bytes) -> FileResult: bucket = self.aws_bucket_name - folder = self.aws_folder + folder = self.folder logger.info(f"Uploading {filename} to S3 {bucket}/{folder}") s3filename = f"{folder}/{filename}" if folder else filename async with self.session.client("s3") as client: @@ -44,6 +49,11 @@ class AwsStorage(Storage): Body=data, ) + async def get_file_url(self, filename: str) -> FileResult: + bucket = self.aws_bucket_name + folder = self.folder + s3filename = f"{folder}/{filename}" if folder else filename + async with self.session.client("s3") as client: presigned_url = await client.generate_presigned_url( "get_object", Params={"Bucket": bucket, "Key": s3filename}, @@ -57,7 +67,7 @@ class AwsStorage(Storage): async def _delete_file(self, filename: str): bucket = self.aws_bucket_name - folder = self.aws_folder + folder = self.folder logger.info(f"Deleting {filename} from S3 {bucket}/{folder}") s3filename = f"{folder}/{filename}" if folder else filename async with self.session.client("s3") as client: From 88f443e2c25cd8bf9158b1c83bb8c76a3813d843 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 16 Nov 2023 14:32:18 +0100 Subject: [PATCH 15/27] server: revert change on storage folder --- server/reflector/storage/base.py | 14 ++++++++------ server/reflector/storage/storage_aws.py | 22 +++++++--------------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py index 7c44ff4d..a457ddf8 100644 --- a/server/reflector/storage/base.py +++ b/server/reflector/storage/base.py @@ -18,14 +18,14 @@ class Storage: cls._registry[name] = kclass @classmethod - def get_instance(cls, name: str, settings_prefix: str = "", folder: str = ""): + def get_instance(cls, name: str, settings_prefix: str = ""): if name not in cls._registry: module_name = f"reflector.storage.storage_{name}" importlib.import_module(module_name) # gather specific configuration for the processor # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy` - config = {"folder": folder} + config = {} name_upper = name.upper() config_prefix = f"{settings_prefix}{name_upper}_" for key, value in settings: @@ -35,10 +35,6 @@ class Storage: return cls._registry[name](**config) - def __init__(self): - self.folder = "" - super().__init__() - async def put_file(self, filename: str, data: bytes) -> FileResult: return await self._put_file(filename, data) @@ -50,3 +46,9 @@ class Storage: async def _delete_file(self, filename: str): raise NotImplementedError + + async def get_file_url(self, filename: str) -> str: + return await self._get_file_url(filename) + + async def _get_file_url(self, filename: str) -> str: + raise NotImplementedError diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py index 5ab02903..d2313293 100644 --- a/server/reflector/storage/storage_aws.py +++ b/server/reflector/storage/storage_aws.py @@ -22,14 +22,9 @@ class AwsStorage(Storage): super().__init__() self.aws_bucket_name = aws_bucket_name - folder = "" + self.aws_folder = "" if "/" in aws_bucket_name: - self.aws_bucket_name, folder = aws_bucket_name.split("/", 1) - if folder: - if not self.folder: - self.folder = folder - else: - self.folder = f"{self.folder}/{folder}" + self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1) self.session = aioboto3.Session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, @@ -39,7 +34,7 @@ class AwsStorage(Storage): async def _put_file(self, filename: str, data: bytes) -> FileResult: bucket = self.aws_bucket_name - folder = self.folder + folder = self.aws_folder logger.info(f"Uploading {filename} to S3 {bucket}/{folder}") s3filename = f"{folder}/{filename}" if folder else filename async with self.session.client("s3") as client: @@ -49,9 +44,9 @@ class AwsStorage(Storage): Body=data, ) - async def get_file_url(self, filename: str) -> FileResult: + async def _get_file_url(self, filename: str) -> FileResult: bucket = self.aws_bucket_name - folder = self.folder + folder = self.aws_folder s3filename = f"{folder}/{filename}" if folder else filename async with self.session.client("s3") as client: presigned_url = await client.generate_presigned_url( @@ -60,14 +55,11 @@ class AwsStorage(Storage): ExpiresIn=3600, ) - return FileResult( - filename=filename, - url=presigned_url, - ) + return presigned_url async def _delete_file(self, filename: str): bucket = self.aws_bucket_name - folder = self.folder + folder = self.aws_folder logger.info(f"Deleting {filename} from S3 {bucket}/{folder}") s3filename = f"{folder}/{filename}" if folder else filename async with self.session.client("s3") as client: From 06b29d9bd4a1d1b7f6265756be57c5602888673a Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 16 Nov 2023 14:34:33 +0100 Subject: [PATCH 16/27] server: add audio_location and move to external storage if possible --- .../versions/f819277e5169_audio_location.py | 43 ++++ server/reflector/db/transcripts.py | 67 ++++- .../reflector/pipelines/main_live_pipeline.py | 234 ++++++++++-------- server/reflector/pipelines/runner.py | 3 +- server/reflector/settings.py | 5 +- 5 files changed, 238 insertions(+), 114 deletions(-) create mode 100644 server/migrations/versions/f819277e5169_audio_location.py diff --git a/server/migrations/versions/f819277e5169_audio_location.py b/server/migrations/versions/f819277e5169_audio_location.py new file mode 100644 index 00000000..576b02bd --- /dev/null +++ b/server/migrations/versions/f819277e5169_audio_location.py @@ -0,0 +1,43 @@ +"""audio_location + +Revision ID: f819277e5169 +Revises: 4814901632bc +Create Date: 2023-11-16 10:29:09.351664 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "f819277e5169" +down_revision: Union[str, None] = "4814901632bc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "transcript", + sa.Column( + "audio_location", sa.String(), server_default="local", nullable=False + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "transcript", + sa.Column( + "share_mode", + sa.VARCHAR(), + server_default=sa.text("'private'"), + nullable=False, + ), + ) + # ### end Alembic commands ### diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index c0e8984b..44a6d56b 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field from reflector.db import database, metadata from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings +from reflector.storage import Storage transcripts = sqlalchemy.Table( "transcript", @@ -27,20 +28,33 @@ transcripts = sqlalchemy.Table( sqlalchemy.Column("events", sqlalchemy.JSON), sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), + sqlalchemy.Column( + "audio_location", + sqlalchemy.String, + nullable=False, + server_default="local", + ), # with user attached, optional sqlalchemy.Column("user_id", sqlalchemy.String), ) -def generate_uuid4(): +def generate_uuid4() -> str: return str(uuid4()) -def generate_transcript_name(): +def generate_transcript_name() -> str: now = datetime.utcnow() return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" +def get_storage() -> Storage: + return Storage.get_instance( + name=settings.TRANSCRIPT_STORAGE_BACKEND, + settings_prefix="TRANSCRIPT_STORAGE_", + ) + + class AudioWaveform(BaseModel): data: list[float] @@ -133,6 +147,10 @@ class Transcript(BaseModel): def data_path(self): return Path(settings.DATA_DIR) / self.id + @property + def audio_wav_filename(self): + return self.data_path / "audio.wav" + @property def audio_mp3_filename(self): return self.data_path / "audio.mp3" @@ -157,6 +175,40 @@ class Transcript(BaseModel): return AudioWaveform(data=data) + async def get_audio_url(self) -> str: + if self.audio_location == "local": + return self._generate_local_audio_link() + elif self.audio_location == "storage": + return await self._generate_storage_audio_link() + raise Exception(f"Unknown audio location {self.audio_location}") + + async def _generate_storage_audio_link(self) -> str: + return await get_storage().get_file_url(self.storage_audio_path) + + def _generate_local_audio_link(self) -> str: + # we need to create an url to be used for diarization + # we can't use the audio_mp3_filename because it's not accessible + # from the diarization processor + from datetime import timedelta + + from reflector.app import app + from reflector.views.transcripts import create_access_token + + path = app.url_path_for( + "transcript_get_audio_mp3", + transcript_id=self.id, + ) + url = f"{settings.BASE_URL}{path}" + if self.user_id: + # we pass token only if the user_id is set + # otherwise, the audio is public + token = create_access_token( + {"sub": self.user_id}, + expires_delta=timedelta(minutes=15), + ) + url += f"?token={token}" + return url + class TranscriptController: async def get_all( @@ -292,15 +344,18 @@ class TranscriptController: """ Move mp3 file to storage """ - from reflector.storage import Storage - storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE) - await storage.put_file( + # store the audio on external storage + await get_storage().put_file( transcript.storage_audio_path, - self.audio_mp3_filename.read_bytes(), + transcript.audio_mp3_filename.read_bytes(), ) + # indicate on the transcript that the audio is now on storage await self.update(transcript, {"audio_location": "storage"}) + # unlink the local file + transcript.audio_mp3_filename.unlink(missing_ok=True) + transcripts_controller = TranscriptController() diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index e2f305c4..83b57949 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -14,12 +14,9 @@ It is directly linked to our data model. import asyncio import functools from contextlib import asynccontextmanager -from datetime import timedelta -from pathlib import Path -from celery import shared_task +from celery import chord, group, shared_task from pydantic import BaseModel -from reflector.app import app from reflector.db.transcripts import ( Transcript, TranscriptDuration, @@ -56,7 +53,7 @@ from reflector.processors.types import ( from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.settings import settings from reflector.ws_manager import WebsocketManager, get_ws_manager -from structlog import Logger +from structlog import BoundLogger as Logger def asynctask(f): @@ -97,13 +94,17 @@ def get_transcript(func): Decorator to fetch the transcript from the database from the first argument """ - async def wrapper(self, **kwargs): + async def wrapper(**kwargs): transcript_id = kwargs.pop("transcript_id") transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id) if not transcript: raise Exception("Transcript {transcript_id} not found") tlogger = logger.bind(transcript_id=transcript.id) - return await func(self, transcript=transcript, logger=tlogger, **kwargs) + try: + return await func(transcript=transcript, logger=tlogger, **kwargs) + except Exception as exc: + tlogger.error("Pipeline error", exc_info=exc) + raise return wrapper @@ -162,7 +163,7 @@ class PipelineMainBase(PipelineRunner): "flush": "processing", "error": "error", } - elif isinstance(self, PipelineMainDiarization): + elif isinstance(self, PipelineMainFinalSummaries): status_mapping = { "push": "processing", "flush": "processing", @@ -170,7 +171,8 @@ class PipelineMainBase(PipelineRunner): "ended": "ended", } else: - raise Exception(f"Runner {self.__class__} is missing status mapping") + # intermediate pipeline don't update status + return # mutate to model status status = status_mapping.get(status) @@ -308,9 +310,10 @@ class PipelineMainBase(PipelineRunner): class PipelineMainLive(PipelineMainBase): - audio_filename: Path | None = None - source_language: str = "en" - target_language: str = "en" + """ + Main pipeline for live streaming, attach to RTC connection + Any long post process should be done in the post pipeline + """ async def create(self) -> Pipeline: # create a context for the whole rtc transaction @@ -320,7 +323,7 @@ class PipelineMainLive(PipelineMainBase): processors = [ AudioFileWriterProcessor( - path=transcript.audio_mp3_filename, + path=transcript.audio_wav_filename, on_duration=self.on_duration, ), AudioChunkerProcessor(), @@ -329,15 +332,11 @@ class PipelineMainLive(PipelineMainBase): TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), - BroadcastProcessor( - processors=[ - TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), - AudioWaveformProcessor.as_threaded( - audio_path=transcript.audio_mp3_filename, - waveform_path=transcript.audio_waveform_filename, - on_waveform=self.on_waveform, - ), - ] + # XXX move as a task + AudioWaveformProcessor.as_threaded( + audio_path=transcript.audio_mp3_filename, + waveform_path=transcript.audio_waveform_filename, + on_waveform=self.on_waveform, ), ] pipeline = Pipeline(*processors) @@ -374,28 +373,16 @@ class PipelineMainDiarization(PipelineMainBase): # first processor diarization processor # XXX translation is lost when converting our data model to the processor model transcript = await self.get_transcript() + + # diarization works only if the file is uploaded to an external storage + if transcript.audio_location == "local": + pipeline.logger.info("Audio is local, skipping diarization") + return + topics = self.get_transcript_topics(transcript) - - # we need to create an url to be used for diarization - # we can't use the audio_mp3_filename because it's not accessible - # from the diarization processor - from reflector.views.transcripts import create_access_token - - path = app.url_path_for( - "transcript_get_audio_mp3", - transcript_id=transcript.id, - ) - url = f"{settings.BASE_URL}{path}" - if transcript.user_id: - # we pass token only if the user_id is set - # otherwise, the audio is public - token = create_access_token( - {"sub": transcript.user_id}, - expires_delta=timedelta(minutes=15), - ) - url += f"?token={token}" + audio_url = await transcript.get_audio_url() audio_diarization_input = AudioDiarizationInput( - audio_url=url, + audio_url=audio_url, topics=topics, ) @@ -409,14 +396,60 @@ class PipelineMainDiarization(PipelineMainBase): return pipeline -class PipelineMainSummaries(PipelineMainBase): +class PipelineMainFromTopics(PipelineMainBase): + """ + Pseudo class for generating a pipeline from topics + """ + + def get_processors(self) -> list: + raise NotImplementedError + + async def create(self) -> Pipeline: + self.prepare() + processors = self.get_processors() + pipeline = Pipeline(*processors) + pipeline.options = self + + # get transcript + transcript = await self.get_transcript() + pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info(f"{self.__class__.__name__} pipeline created") + + # push topics + topics = self.get_transcript_topics(transcript) + for topic in topics: + self.push(topic) + + self.flush() + + return pipeline + + +class PipelineMainTitleAndShortSummary(PipelineMainFromTopics): + """ + Generate title from the topics + """ + + def get_processors(self) -> list: + return [ + BroadcastProcessor( + processors=[ + TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=self.on_short_summary + ), + ] + ) + ] + + +class PipelineMainFinalSummaries(PipelineMainFromTopics): """ Generate summaries from the topics """ - async def create(self) -> Pipeline: - self.prepare() - pipeline = Pipeline( + def get_processors(self) -> list: + return [ BroadcastProcessor( processors=[ TranscriptFinalLongSummaryProcessor.as_threaded( @@ -427,32 +460,7 @@ class PipelineMainSummaries(PipelineMainBase): ), ] ), - ) - pipeline.options = self - - # get transcript - transcript = await self.get_transcript() - pipeline.logger.bind(transcript_id=transcript.id) - pipeline.logger.info("Summaries pipeline created") - - # push topics - topics = await self.get_transcript_topics(transcript) - for topic in topics: - self.push(topic) - - self.flush() - - return pipeline - - -@shared_task -def task_pipeline_main_post(transcript_id: str): - logger.info( - "Starting main post pipeline", - transcript_id=transcript_id, - ) - runner = PipelineMainDiarization(transcript_id=transcript_id) - runner.start_sync() + ] @get_transcript @@ -470,24 +478,26 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger): import av - input_container = av.open(wav_filename) - output_container = av.open(mp3_filename, "w") - input_audio_stream = input_container.streams.audio[0] - output_audio_stream = output_container.add_stream("mp3") - output_audio_stream.codec_context.set_parameters( - input_audio_stream.codec_context.parameters - ) - for packet in input_container.demux(input_audio_stream): - for frame in packet.decode(): - output_container.mux(frame) - input_container.close() - output_container.close() + with av.open(wav_filename.as_posix()) as in_container: + in_stream = in_container.streams.audio[0] + with av.open(mp3_filename.as_posix(), "w") as out_container: + out_stream = out_container.add_stream("mp3") + for frame in in_container.decode(in_stream): + for packet in out_stream.encode(frame): + out_container.mux(packet) + + # Delete the wav file + transcript.audio_wav_filename.unlink(missing_ok=True) logger.info("Convert to mp3 done") @get_transcript async def pipeline_upload_mp3(transcript: Transcript, logger: Logger): + if not settings.TRANSCRIPT_STORAGE_BACKEND: + logger.info("No storage backend configured, skipping mp3 upload") + return + logger.info("Starting upload mp3") # If the audio mp3 is not available, just skip @@ -497,27 +507,32 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger): return # Upload to external storage and delete the file - await transcripts_controller.move_to_storage(transcript) - await transcripts_controller.unlink_mp3(transcript) + await transcripts_controller.move_mp3_to_storage(transcript) logger.info("Upload mp3 done") @get_transcript -@asynctask async def pipeline_diarization(transcript: Transcript, logger: Logger): logger.info("Starting diarization") runner = PipelineMainDiarization(transcript_id=transcript.id) - await runner.start() + await runner.run() logger.info("Diarization done") @get_transcript -@asynctask +async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger): + logger.info("Starting title and short summary") + runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id) + await runner.run() + logger.info("Title and short summary done") + + +@get_transcript async def pipeline_summaries(transcript: Transcript, logger: Logger): logger.info("Starting summaries") - runner = PipelineMainSummaries(transcript_id=transcript.id) - await runner.start() + runner = PipelineMainFinalSummaries(transcript_id=transcript.id) + await runner.run() logger.info("Summaries done") @@ -528,29 +543,35 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger): @shared_task @asynctask -async def task_pipeline_convert_to_mp3(transcript_id: str): - await pipeline_convert_to_mp3(transcript_id) +async def task_pipeline_convert_to_mp3(*, transcript_id: str): + await pipeline_convert_to_mp3(transcript_id=transcript_id) @shared_task @asynctask -async def task_pipeline_upload_mp3(transcript_id: str): - await pipeline_upload_mp3(transcript_id) +async def task_pipeline_upload_mp3(*, transcript_id: str): + await pipeline_upload_mp3(transcript_id=transcript_id) @shared_task @asynctask -async def task_pipeline_diarization(transcript_id: str): - await pipeline_diarization(transcript_id) +async def task_pipeline_diarization(*, transcript_id: str): + await pipeline_diarization(transcript_id=transcript_id) @shared_task @asynctask -async def task_pipeline_summaries(transcript_id: str): - await pipeline_summaries(transcript_id) +async def task_pipeline_title_and_short_summary(*, transcript_id: str): + await pipeline_title_and_short_summary(transcript_id=transcript_id) -def pipeline_post(transcript_id: str): +@shared_task +@asynctask +async def task_pipeline_final_summaries(*, transcript_id: str): + await pipeline_summaries(transcript_id=transcript_id) + + +def pipeline_post(*, transcript_id: str): """ Run the post pipeline """ @@ -559,6 +580,15 @@ def pipeline_post(transcript_id: str): | task_pipeline_upload_mp3.si(transcript_id=transcript_id) | task_pipeline_diarization.si(transcript_id=transcript_id) ) - chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id) - chain = chain_mp3_and_diarize | chain_summary + chain_title_preview = task_pipeline_title_and_short_summary.si( + transcript_id=transcript_id + ) + chain_final_summaries = task_pipeline_final_summaries.si( + transcript_id=transcript_id + ) + + chain = chord( + group(chain_mp3_and_diarize, chain_title_preview), + chain_final_summaries, + ) chain.delay() diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index a1e137a7..4105d51f 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -119,8 +119,7 @@ class PipelineRunner(BaseModel): self._logger.exception("Runner error") await self._set_status("error") self._ev_done.set() - if self.on_ended: - await self.on_ended() + raise async def cmd_push(self, data): if self._is_first_push: diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 65412310..2c68c4e5 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -54,7 +54,7 @@ class Settings(BaseSettings): TRANSCRIPT_MODAL_API_KEY: str | None = None # Audio transcription storage - TRANSCRIPT_STORAGE_BACKEND: str = "aws" + TRANSCRIPT_STORAGE_BACKEND: str | None = None # Storage configuration for AWS TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket" @@ -62,9 +62,6 @@ class Settings(BaseSettings): TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None - # Transcript MP3 storage - TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws" - # LLM # available backend: openai, modal, oobabooga LLM_BACKEND: str = "oobabooga" From 5ffa931822ff5b7c5205f942ca2332f52b9ef892 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 16 Nov 2023 14:45:40 +0100 Subject: [PATCH 17/27] server: update backend tests results (rpc does not work with chords) --- .../migrations/versions/f819277e5169_audio_location.py | 10 +--------- server/tests/conftest.py | 8 +++++++- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/server/migrations/versions/f819277e5169_audio_location.py b/server/migrations/versions/f819277e5169_audio_location.py index 576b02bd..061abec4 100644 --- a/server/migrations/versions/f819277e5169_audio_location.py +++ b/server/migrations/versions/f819277e5169_audio_location.py @@ -31,13 +31,5 @@ def upgrade() -> None: def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.add_column( - "transcript", - sa.Column( - "share_mode", - sa.VARCHAR(), - server_default=sa.text("'private'"), - nullable=False, - ), - ) + op.drop_column("transcript", "audio_location") # ### end Alembic commands ### diff --git a/server/tests/conftest.py b/server/tests/conftest.py index aafca9fd..aaf42884 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -133,4 +133,10 @@ def celery_enable_logging(): @pytest.fixture(scope="session") def celery_config(): - return {"broker_url": "memory://", "result_backend": "rpc"} + import tempfile + + with tempfile.NamedTemporaryFile() as fd: + yield { + "broker_url": "memory://", + "result_backend": "db+sqlite://" + fd.name, + } From 99b973f36f5fbbe1c193dd00dc2233fc82eddb4d Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 17 Nov 2023 14:27:53 +0100 Subject: [PATCH 18/27] server: fix tests --- server/reflector/pipelines/runner.py | 8 ++++++ server/tests/conftest.py | 36 +++++++++++++++++++++---- server/tests/test_transcripts_rtc_ws.py | 4 +++ 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index 4105d51f..708a4265 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -106,6 +106,14 @@ class PipelineRunner(BaseModel): if not self.pipeline: self.pipeline = await self.create() + if not self.pipeline: + # no pipeline created in create, just finish it then. + await self._set_status("ended") + self._ev_done.set() + if self.on_ended: + await self.on_ended() + return + # start the loop await self._set_status("started") while not self._ev_done.is_set(): diff --git a/server/tests/conftest.py b/server/tests/conftest.py index aaf42884..532ebff9 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,4 +1,5 @@ from unittest.mock import patch +from tempfile import NamedTemporaryFile import pytest @@ -7,7 +8,6 @@ import pytest @pytest.mark.asyncio async def setup_database(): from reflector.settings import settings - from tempfile import NamedTemporaryFile with NamedTemporaryFile() as f: settings.DATABASE_URL = f"sqlite:///{f.name}" @@ -103,6 +103,25 @@ async def dummy_llm(): yield +@pytest.fixture +async def dummy_storage(): + from reflector.storage.base import Storage + + class DummyStorage(Storage): + async def _put_file(self, *args, **kwargs): + pass + + async def _delete_file(self, *args, **kwargs): + pass + + async def _get_file_url(self, *args, **kwargs): + return "http://fake_server/audio.mp3" + + with patch("reflector.storage.base.Storage.get_instance") as mock_storage: + mock_storage.return_value = DummyStorage() + yield + + @pytest.fixture def nltk(): with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk: @@ -133,10 +152,17 @@ def celery_enable_logging(): @pytest.fixture(scope="session") def celery_config(): - import tempfile - - with tempfile.NamedTemporaryFile() as fd: + with NamedTemporaryFile() as f: yield { "broker_url": "memory://", - "result_backend": "db+sqlite://" + fd.name, + "result_backend": f"db+sqlite:///{f.name}", } + + +@pytest.fixture(scope="session") +def fake_mp3_upload(): + with patch( + "reflector.db.transcripts.TranscriptController.move_mp3_to_storage" + ) as mock_move: + mock_move.return_value = True + yield diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index b33b1db5..8502a0d9 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket( dummy_transcript, dummy_processors, dummy_diarization, + dummy_storage, + fake_mp3_upload, ensure_casing, appserver, sentence_tokenize, @@ -220,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr( dummy_transcript, dummy_processors, dummy_diarization, + dummy_storage, + fake_mp3_upload, ensure_casing, appserver, sentence_tokenize, From 794d08c3a88d55a2ac6ea5faecd117697d838612 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 21 Nov 2023 14:46:16 +0100 Subject: [PATCH 19/27] server: redirect to storage url --- server/reflector/views/transcripts.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 6909b8ae..7496b26c 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -233,6 +233,12 @@ async def transcript_get_audio_mp3( if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") + if transcript.audio_location == "storage": + url = transcript.get_audio_url() + from fastapi.responses import RedirectResponse + + return RedirectResponse(url=url, status_code=status.HTTP_302_FOUND) + if not transcript.audio_mp3_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") From 0e5c0f66d91fa61f7bfee283a89f96f13dc57fef Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 21 Nov 2023 15:40:49 +0100 Subject: [PATCH 20/27] server: move waveform out of the live pipeline --- .../reflector/pipelines/main_live_pipeline.py | 46 +++++++++++++++---- server/reflector/views/transcripts.py | 22 +++++++-- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 83b57949..b182f421 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -332,12 +332,6 @@ class PipelineMainLive(PipelineMainBase): TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), - # XXX move as a task - AudioWaveformProcessor.as_threaded( - audio_path=transcript.audio_mp3_filename, - waveform_path=transcript.audio_waveform_filename, - on_waveform=self.on_waveform, - ), ] pipeline = Pipeline(*processors) pipeline.options = self @@ -406,12 +400,14 @@ class PipelineMainFromTopics(PipelineMainBase): async def create(self) -> Pipeline: self.prepare() + + # get transcript + self._transcript = transcript = await self.get_transcript() + + # create pipeline processors = self.get_processors() pipeline = Pipeline(*processors) pipeline.options = self - - # get transcript - transcript = await self.get_transcript() pipeline.logger.bind(transcript_id=transcript.id) pipeline.logger.info(f"{self.__class__.__name__} pipeline created") @@ -463,6 +459,29 @@ class PipelineMainFinalSummaries(PipelineMainFromTopics): ] +class PipelineMainWaveform(PipelineMainFromTopics): + """ + Generate waveform + """ + + def get_processors(self) -> list: + return [ + AudioWaveformProcessor.as_threaded( + audio_path=self._transcript.audio_wav_filename, + waveform_path=self._transcript.audio_waveform_filename, + on_waveform=self.on_waveform, + ), + ] + + +@get_transcript +async def pipeline_waveform(transcript: Transcript, logger: Logger): + logger.info("Starting waveform") + runner = PipelineMainWaveform(transcript_id=transcript.id) + await runner.run() + logger.info("Waveform done") + + @get_transcript async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger): logger.info("Starting convert to mp3") @@ -541,6 +560,12 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger): # =================================================================== +@shared_task +@asynctask +async def task_pipeline_waveform(*, transcript_id: str): + await pipeline_waveform(transcript_id=transcript_id) + + @shared_task @asynctask async def task_pipeline_convert_to_mp3(*, transcript_id: str): @@ -576,7 +601,8 @@ def pipeline_post(*, transcript_id: str): Run the post pipeline """ chain_mp3_and_diarize = ( - task_pipeline_convert_to_mp3.si(transcript_id=transcript_id) + task_pipeline_waveform.si(transcript_id=transcript_id) + | task_pipeline_convert_to_mp3.si(transcript_id=transcript_id) | task_pipeline_upload_mp3.si(transcript_id=transcript_id) | task_pipeline_diarization.si(transcript_id=transcript_id) ) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 7496b26c..125aa311 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,12 +1,14 @@ from datetime import datetime, timedelta from typing import Annotated, Optional +import httpx import reflector.auth as auth from fastapi import ( APIRouter, Depends, HTTPException, Request, + Response, WebSocket, WebSocketDisconnect, status, @@ -234,10 +236,22 @@ async def transcript_get_audio_mp3( raise HTTPException(status_code=404, detail="Transcript not found") if transcript.audio_location == "storage": - url = transcript.get_audio_url() - from fastapi.responses import RedirectResponse + # proxy S3 file, to prevent issue with CORS + url = await transcript.get_audio_url() + headers = {} - return RedirectResponse(url=url, status_code=status.HTTP_302_FOUND) + copy_headers = ["range", "accept-encoding"] + for header in copy_headers: + if header in request.headers: + headers[header] = request.headers[header] + + async with httpx.AsyncClient() as client: + resp = await client.request(request.method, url, headers=headers) + return Response( + content=resp.content, + status_code=resp.status_code, + headers=resp.headers, + ) if not transcript.audio_mp3_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") @@ -263,7 +277,7 @@ async def transcript_get_audio_waveform( if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - if not transcript.audio_mp3_filename.exists(): + if not transcript.audio_waveform_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") return transcript.audio_waveform From f8407874f77172c3aafa379f5089e80b00533064 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 23 Nov 2023 12:41:39 +0100 Subject: [PATCH 21/27] server: fixes share_mode script --- .../versions/0fea6d96b096_add_share_mode.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/server/migrations/versions/0fea6d96b096_add_share_mode.py b/server/migrations/versions/0fea6d96b096_add_share_mode.py index 52a72d48..48746c3b 100644 --- a/server/migrations/versions/0fea6d96b096_add_share_mode.py +++ b/server/migrations/versions/0fea6d96b096_add_share_mode.py @@ -1,7 +1,7 @@ """add share_mode Revision ID: 0fea6d96b096 -Revises: 38a927dcb099 +Revises: f819277e5169 Create Date: 2023-11-07 11:12:21.614198 """ @@ -12,19 +12,22 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = '0fea6d96b096' -down_revision: Union[str, None] = '38a927dcb099' +revision: str = "0fea6d96b096" +down_revision: Union[str, None] = "f819277e5169" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.add_column('transcript', sa.Column('share_mode', sa.String(), server_default='private', nullable=False)) + op.add_column( + "transcript", + sa.Column("share_mode", sa.String(), server_default="private", nullable=False), + ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('transcript', 'share_mode') + op.drop_column("transcript", "share_mode") # ### end Alembic commands ### From 3ebb21923bffa319ba9ec0c771de3623ccbeb324 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 29 Nov 2023 20:34:43 +0100 Subject: [PATCH 22/27] server: enhance diarization algorithm --- .../reflector/processors/audio_diarization.py | 161 +++++++++++++++++- .../tests/test_processor_audio_diarization.py | 140 +++++++++++++++ 2 files changed, 294 insertions(+), 7 deletions(-) create mode 100644 server/tests/test_processor_audio_diarization.py diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py index 82c6a553..69eab5b7 100644 --- a/server/reflector/processors/audio_diarization.py +++ b/server/reflector/processors/audio_diarization.py @@ -1,5 +1,5 @@ from reflector.processors.base import Processor -from reflector.processors.types import AudioDiarizationInput, TitleSummary +from reflector.processors.types import AudioDiarizationInput, TitleSummary, Word class AudioDiarizationProcessor(Processor): @@ -19,12 +19,12 @@ class AudioDiarizationProcessor(Processor): # topics is a list[BaseModel] with an attribute words # words is a list[BaseModel] with text, start and speaker attribute - # mutate in place - for topic in data.topics: - for word in topic.transcript.words: - for d in diarization: - if d["start"] <= word.start <= d["end"]: - word.speaker = d["speaker"] + # create a view of words based on topics + # the current algorithm is using words index, we cannot use a generator + words = list(self.iter_words_from_topics(data.topics)) + + # assign speaker to words (mutate the words list) + self.assign_speaker(words, diarization) # emit them for topic in data.topics: @@ -32,3 +32,150 @@ class AudioDiarizationProcessor(Processor): async def _diarize(self, data: AudioDiarizationInput): raise NotImplementedError + + def assign_speaker(self, words: list[Word], diarization: list[dict]): + self._diarization_remove_overlap(diarization) + self._diarization_remove_segment_without_words(words, diarization) + self._diarization_merge_same_speaker(words, diarization) + self._diarization_assign_speaker(words, diarization) + + def iter_words_from_topics(self, topics: TitleSummary): + for topic in topics: + for word in topic.transcript.words: + yield word + + def is_word_continuation(self, word_prev, word): + """ + Return True if the word is a continuation of the previous word + by checking if the previous word is ending with a punctuation + or if the current word is starting with a capital letter + """ + # is word_prev ending with a punctuation ? + if word_prev.text and word_prev.text[-1] in ".?!": + return False + elif word.text and word.text[0].isupper(): + return False + return True + + def _diarization_remove_overlap(self, diarization: list[dict]): + """ + Remove overlap in diarization results + + When using a diarization algorithm, it's possible to have overlapping segments + This function remove the overlap by keeping the longest segment + + Warning: this function mutate the diarization list + """ + # remove overlap by keeping the longest segment + diarization_idx = 0 + while diarization_idx < len(diarization) - 1: + d = diarization[diarization_idx] + dnext = diarization[diarization_idx + 1] + if d["end"] > dnext["start"]: + # remove the shortest segment + if d["end"] - d["start"] > dnext["end"] - dnext["start"]: + # remove next segment + diarization.pop(diarization_idx + 1) + else: + # remove current segment + diarization.pop(diarization_idx) + else: + diarization_idx += 1 + + def _diarization_remove_segment_without_words( + self, words: list[Word], diarization: list[dict] + ): + """ + Remove diarization segments without words + + Warning: this function mutate the diarization list + """ + # count the number of words for each diarization segment + diarization_count = [] + for d in diarization: + start = d["start"] + end = d["end"] + count = 0 + for word in words: + if start <= word.start < end: + count += 1 + elif start < word.end <= end: + count += 1 + diarization_count.append(count) + + # remove diarization segments with no words + diarization_idx = 0 + while diarization_idx < len(diarization): + if diarization_count[diarization_idx] == 0: + diarization.pop(diarization_idx) + diarization_count.pop(diarization_idx) + else: + diarization_idx += 1 + + def _diarization_merge_same_speaker( + self, words: list[Word], diarization: list[dict] + ): + """ + Merge diarization contigous segments with the same speaker + + Warning: this function mutate the diarization list + """ + # merge segment with same speaker + diarization_idx = 0 + while diarization_idx < len(diarization) - 1: + d = diarization[diarization_idx] + dnext = diarization[diarization_idx + 1] + if d["speaker"] == dnext["speaker"]: + diarization[diarization_idx]["end"] = dnext["end"] + diarization.pop(diarization_idx + 1) + else: + diarization_idx += 1 + + def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]): + """ + Assign speaker to words based on diarization + + Warning: this function mutate the words list + """ + + word_idx = 0 + last_speaker = None + for d in diarization: + start = d["start"] + end = d["end"] + speaker = d["speaker"] + + # diarization may start after the first set of words + # in this case, we assign the last speaker + for word in words[word_idx:]: + if word.start < start: + # speaker change, but what make sense for assigning the word ? + # If it's a new sentence, assign with the new speaker + # If it's a continuation, assign with the last speaker + is_continuation = False + if word_idx > 0 and word_idx < len(words) - 1: + is_continuation = self.is_word_continuation( + *words[word_idx - 1 : word_idx + 1] + ) + if is_continuation: + word.speaker = last_speaker + else: + word.speaker = speaker + last_speaker = speaker + word_idx += 1 + else: + break + + # now continue to assign speaker until the word starts after the end + for word in words[word_idx:]: + if start <= word.start < end: + last_speaker = speaker + word.speaker = speaker + word_idx += 1 + elif word.start > end: + break + + # no more diarization available, + # assign last speaker to all words without speaker + for word in words[word_idx:]: + word.speaker = last_speaker diff --git a/server/tests/test_processor_audio_diarization.py b/server/tests/test_processor_audio_diarization.py new file mode 100644 index 00000000..00935a49 --- /dev/null +++ b/server/tests/test_processor_audio_diarization.py @@ -0,0 +1,140 @@ +import pytest +from unittest import mock + + +@pytest.mark.parametrize( + "name,diarization,expected", + [ + [ + "no overlap", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 1.0, "end": 2.0, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "same speaker", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 1.0, "end": 2.0, "speaker": "A"}, + ], + ["A", "A", "A", "A"], + ], + [ + # first segment is removed because it overlap + # with the second segment, and it is smaller + "overlap at 0.5s", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 0.5, "end": 2.0, "speaker": "B"}, + ], + ["B", "B", "B", "B"], + ], + [ + "junk segment at 0.5s for 0.2s", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 0.5, "end": 0.7, "speaker": "B"}, + {"start": 1, "end": 2.0, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "start without diarization", + [ + {"start": 0.5, "end": 1.0, "speaker": "A"}, + {"start": 1.0, "end": 2.0, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "end missing diarization", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 1.0, "end": 1.5, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "continuation of next speaker", + [ + {"start": 0.0, "end": 0.9, "speaker": "A"}, + {"start": 1.5, "end": 2.0, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "continuation of previous speaker", + [ + {"start": 0.0, "end": 0.5, "speaker": "A"}, + {"start": 1.0, "end": 2.0, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "segment without words", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 1.0, "end": 2.0, "speaker": "B"}, + {"start": 2.0, "end": 3.0, "speaker": "X"}, + ], + ["A", "A", "B", "B"], + ], + ], +) +@pytest.mark.asyncio +async def test_processors_audio_diarization(event_loop, name, diarization, expected): + from reflector.processors.audio_diarization import AudioDiarizationProcessor + from reflector.processors.types import ( + TitleSummaryWithId, + Transcript, + Word, + AudioDiarizationInput, + ) + + # create fake topic + topics = [ + TitleSummaryWithId( + id="1", + title="Title1", + summary="Summary1", + timestamp=0.0, + duration=1.0, + transcript=Transcript( + words=[ + Word(text="Word1", start=0.0, end=0.5), + Word(text="word2.", start=0.5, end=1.0), + ] + ), + ), + TitleSummaryWithId( + id="2", + title="Title2", + summary="Summary2", + timestamp=0.0, + duration=1.0, + transcript=Transcript( + words=[ + Word(text="Word3", start=1.0, end=1.5), + Word(text="word4.", start=1.5, end=2.0), + ] + ), + ), + ] + + diarizer = AudioDiarizationProcessor() + with mock.patch.object(diarizer, "_diarize") as mock_diarize: + mock_diarize.return_value = diarization + + data = AudioDiarizationInput( + audio_url="https://example.com/audio.mp3", + topics=topics, + ) + await diarizer._push(data) + + # check that the speaker has been assigned to the words + assert topics[0].transcript.words[0].speaker == expected[0] + assert topics[0].transcript.words[1].speaker == expected[1] + assert topics[1].transcript.words[0].speaker == expected[2] + assert topics[1].transcript.words[1].speaker == expected[3] From eae01c1495e63278a6c40ef44153b767fddd3ca6 Mon Sep 17 00:00:00 2001 From: projects-g <63178974+projects-g@users.noreply.github.com> Date: Thu, 30 Nov 2023 22:00:06 +0530 Subject: [PATCH 23/27] Change diarization internal flow (#320) * change diarization internal flow --- server/gpu/modal/reflector_diarizer.py | 188 ++++++++++++++++++ .../processors/audio_diarization_modal.py | 2 +- 2 files changed, 189 insertions(+), 1 deletion(-) create mode 100644 server/gpu/modal/reflector_diarizer.py diff --git a/server/gpu/modal/reflector_diarizer.py b/server/gpu/modal/reflector_diarizer.py new file mode 100644 index 00000000..7c316548 --- /dev/null +++ b/server/gpu/modal/reflector_diarizer.py @@ -0,0 +1,188 @@ +""" +Reflector GPU backend - diarizer +=================================== +""" + +import os + +import modal.gpu +from modal import Image, Secret, Stub, asgi_app, method +from pydantic import BaseModel + +PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.0" +MODEL_DIR = "/root/diarization_models" + +stub = Stub(name="reflector-diarizer") + + +def migrate_cache_llm(): + """ + XXX The cache for model files in Transformers v4.22.0 has been updated. + Migrating your old cache. This is a one-time only operation. You can + interrupt this and resume the migration later on by calling + `transformers.utils.move_cache()`. + """ + from transformers.utils.hub import move_cache + + print("Moving LLM cache") + move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR) + print("LLM cache moved") + + +def download_pyannote_audio(): + from pyannote.audio import Pipeline + Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.0", + cache_dir=MODEL_DIR, + use_auth_token="***REMOVED***" + ) + + +diarizer_image = ( + Image.debian_slim(python_version="3.10.8") + .pip_install( + "pyannote.audio", + "requests", + "onnx", + "torchaudio", + "onnxruntime-gpu", + "torch==2.0.0", + "transformers==4.34.0", + "sentencepiece", + "protobuf", + "numpy", + "huggingface_hub", + "hf-transfer" + ) + .run_function(migrate_cache_llm) + .run_function(download_pyannote_audio) + .env( + { + "LD_LIBRARY_PATH": ( + "/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:" + "/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/" + ) + } + ) +) + + +@stub.cls( + gpu=modal.gpu.A100(memory=40), + timeout=60 * 10, + container_idle_timeout=60 * 3, + allow_concurrent_inputs=6, + image=diarizer_image, +) +class Diarizer: + def __enter__(self): + import torch + from pyannote.audio import Pipeline + + self.use_gpu = torch.cuda.is_available() + self.device = "cuda" if self.use_gpu else "cpu" + self.diarization_pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.0", + cache_dir=MODEL_DIR + ) + self.diarization_pipeline.to(torch.device(self.device)) + + @method() + def diarize( + self, + audio_data: str, + audio_suffix: str, + timestamp: float + ): + import tempfile + + import torchaudio + + with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp: + fp.write(audio_data) + + print("Diarizing audio") + waveform, sample_rate = torchaudio.load(fp.name) + diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate}) + + words = [] + for diarization_segment, _, speaker in diarization.itertracks(yield_label=True): + words.append( + { + "start": round(timestamp + diarization_segment.start, 3), + "end": round(timestamp + diarization_segment.end, 3), + "speaker": int(speaker[-2:]) + } + ) + print("Diarization complete") + return { + "diarization": words + } + +# ------------------------------------------------------------------- +# Web API +# ------------------------------------------------------------------- + + +@stub.function( + timeout=60 * 10, + container_idle_timeout=60 * 3, + allow_concurrent_inputs=40, + secrets=[ + Secret.from_name("reflector-gpu"), + ], + image=diarizer_image +) +@asgi_app() +def web(): + import requests + from fastapi import Depends, FastAPI, HTTPException, status + from fastapi.security import OAuth2PasswordBearer + + diarizerstub = Diarizer() + + app = FastAPI() + + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + def apikey_auth(apikey: str = Depends(oauth2_scheme)): + if apikey != os.environ["REFLECTOR_GPU_APIKEY"]: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers={"WWW-Authenticate": "Bearer"}, + ) + + def validate_audio_file(audio_file_url: str): + # Check if the audio file exists + response = requests.head(audio_file_url, allow_redirects=True) + if response.status_code == 404: + raise HTTPException( + status_code=response.status_code, + detail="The audio file does not exist." + ) + + class DiarizationResponse(BaseModel): + result: dict + + @app.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]) + def diarize( + audio_file_url: str, + timestamp: float = 0.0 + ) -> HTTPException | DiarizationResponse: + # Currently the uploaded files are in mp3 format + audio_suffix = "mp3" + + print("Downloading audio file") + response = requests.get(audio_file_url, allow_redirects=True) + print("Audio file downloaded successfully") + + func = diarizerstub.diarize.spawn( + audio_data=response.content, + audio_suffix=audio_suffix, + timestamp=timestamp + ) + result = func.get() + return result + + return app diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py index 53de2501..511b7f70 100644 --- a/server/reflector/processors/audio_diarization_modal.py +++ b/server/reflector/processors/audio_diarization_modal.py @@ -31,7 +31,7 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor): follow_redirects=True, ) response.raise_for_status() - return response.json()["text"] + return response.json()["diarization"] AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor) From 7ac6d2521737248f81aa0cd31483fbbe8216cedd Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 30 Nov 2023 17:30:08 +0100 Subject: [PATCH 24/27] server: add participant API Also break out views into different files for easier reading --- .../versions/125031f7cb78_participants.py | 30 +++ server/reflector/app.py | 10 + server/reflector/db/transcripts.py | 56 ++++- server/reflector/views/transcripts.py | 206 +----------------- server/reflector/views/transcripts_audio.py | 109 +++++++++ .../views/transcripts_participants.py | 142 ++++++++++++ server/reflector/views/transcripts_webrtc.py | 37 ++++ .../reflector/views/transcripts_websocket.py | 53 +++++ server/reflector/views/types.py | 5 + server/tests/test_transcripts_participants.py | 164 ++++++++++++++ 10 files changed, 610 insertions(+), 202 deletions(-) create mode 100644 server/migrations/versions/125031f7cb78_participants.py create mode 100644 server/reflector/views/transcripts_audio.py create mode 100644 server/reflector/views/transcripts_participants.py create mode 100644 server/reflector/views/transcripts_webrtc.py create mode 100644 server/reflector/views/transcripts_websocket.py create mode 100644 server/reflector/views/types.py create mode 100644 server/tests/test_transcripts_participants.py diff --git a/server/migrations/versions/125031f7cb78_participants.py b/server/migrations/versions/125031f7cb78_participants.py new file mode 100644 index 00000000..c345b083 --- /dev/null +++ b/server/migrations/versions/125031f7cb78_participants.py @@ -0,0 +1,30 @@ +"""participants + +Revision ID: 125031f7cb78 +Revises: 0fea6d96b096 +Create Date: 2023-11-30 15:56:03.341466 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '125031f7cb78' +down_revision: Union[str, None] = '0fea6d96b096' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('transcript', sa.Column('participants', sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('transcript', 'participants') + # ### end Alembic commands ### diff --git a/server/reflector/app.py b/server/reflector/app.py index 5bfffeca..8f45efd5 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -13,6 +13,12 @@ from reflector.metrics import metrics_init from reflector.settings import settings from reflector.views.rtc_offer import router as rtc_offer_router from reflector.views.transcripts import router as transcripts_router +from reflector.views.transcripts_audio import router as transcripts_audio_router +from reflector.views.transcripts_participants import ( + router as transcripts_participants_router, +) +from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router +from reflector.views.transcripts_websocket import router as transcripts_websocket_router from reflector.views.user import router as user_router try: @@ -60,6 +66,10 @@ metrics_init(app, instrumentator) # register views app.include_router(rtc_offer_router) app.include_router(transcripts_router, prefix="/v1") +app.include_router(transcripts_audio_router, prefix="/v1") +app.include_router(transcripts_participants_router, prefix="/v1") +app.include_router(transcripts_websocket_router, prefix="/v1") +app.include_router(transcripts_webrtc_router, prefix="/v1") app.include_router(user_router, prefix="/v1") add_pagination(app) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 0fba82ef..44688eaa 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -7,7 +7,7 @@ from uuid import uuid4 import sqlalchemy from fastapi import HTTPException -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from reflector.db import database, metadata from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings @@ -27,6 +27,7 @@ transcripts = sqlalchemy.Table( sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), sqlalchemy.Column("topics", sqlalchemy.JSON), sqlalchemy.Column("events", sqlalchemy.JSON), + sqlalchemy.Column("participants", sqlalchemy.JSON), sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), sqlalchemy.Column( @@ -112,6 +113,13 @@ class TranscriptEvent(BaseModel): data: dict +class TranscriptParticipant(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str = Field(default_factory=generate_uuid4) + speaker: int | None + name: str + + class Transcript(BaseModel): id: str = Field(default_factory=generate_uuid4) user_id: str | None = None @@ -125,6 +133,7 @@ class Transcript(BaseModel): long_summary: str | None = None topics: list[TranscriptTopic] = [] events: list[TranscriptEvent] = [] + participants: list[TranscriptParticipant] = [] source_language: str = "en" target_language: str = "en" share_mode: Literal["private", "semi-private", "public"] = "private" @@ -142,12 +151,34 @@ class Transcript(BaseModel): else: self.topics.append(topic) + def upsert_participant(self, participant: TranscriptParticipant): + index = next( + (i for i, p in enumerate(self.participants) if p.id == participant.id), + None, + ) + if index is not None: + self.participants[index] = participant + else: + self.participants.append(participant) + return participant + + def delete_participant(self, participant_id: str): + index = next( + (i for i, p in enumerate(self.participants) if p.id == participant_id), + None, + ) + if index is not None: + del self.participants[index] + def events_dump(self, mode="json"): return [event.model_dump(mode=mode) for event in self.events] def topics_dump(self, mode="json"): return [topic.model_dump(mode=mode) for topic in self.topics] + def participants_dump(self, mode="json"): + return [participant.model_dump(mode=mode) for participant in self.participants] + def unlink(self): self.data_path.unlink(missing_ok=True) @@ -410,5 +441,28 @@ class TranscriptController: # unlink the local file transcript.audio_mp3_filename.unlink(missing_ok=True) + async def upsert_participant( + self, + transcript: Transcript, + participant: TranscriptParticipant, + ) -> TranscriptParticipant: + """ + Add/update a participant to a transcript + """ + result = transcript.upsert_participant(participant) + await self.update(transcript, {"participants": transcript.participants_dump()}) + return result + + async def delete_participant( + self, + transcript: Transcript, + participant_id: str, + ): + """ + Delete a participant from a transcript + """ + transcript.delete_participant(participant_id) + await self.update(transcript, {"participants": transcript.participants_dump()}) + transcripts_controller = TranscriptController() diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 44b55629..9e62192b 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,33 +1,19 @@ from datetime import datetime, timedelta from typing import Annotated, Literal, Optional -import httpx import reflector.auth as auth -from fastapi import ( - APIRouter, - Depends, - HTTPException, - Request, - Response, - WebSocket, - WebSocketDisconnect, - status, -) +from fastapi import APIRouter, Depends, HTTPException from fastapi_pagination import Page from fastapi_pagination.ext.databases import paginate from jose import jwt from pydantic import BaseModel, Field from reflector.db.transcripts import ( - AudioWaveform, + TranscriptParticipant, TranscriptTopic, transcripts_controller, ) from reflector.processors.types import Transcript as ProcessorTranscript from reflector.settings import settings -from reflector.ws_manager import get_ws_manager - -from ._range_requests_response import range_requests_response -from .rtc_offer import RtcOffer, rtc_offer_base router = APIRouter() @@ -62,6 +48,7 @@ class GetTranscript(BaseModel): share_mode: str = Field("private") source_language: str | None target_language: str | None + participants: list[TranscriptParticipant] | None class CreateTranscript(BaseModel): @@ -77,6 +64,7 @@ class UpdateTranscript(BaseModel): short_summary: Optional[str] = Field(None) long_summary: Optional[str] = Field(None) share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None) + participants: Optional[list[TranscriptParticipant]] = Field(None) class DeletionStatus(BaseModel): @@ -192,19 +180,7 @@ async def transcript_update( transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - values = {} - if info.name is not None: - values["name"] = info.name - if info.locked is not None: - values["locked"] = info.locked - if info.long_summary is not None: - values["long_summary"] = info.long_summary - if info.short_summary is not None: - values["short_summary"] = info.short_summary - if info.title is not None: - values["title"] = info.title - if info.share_mode is not None: - values["share_mode"] = info.share_mode + values = info.dict(exclude_unset=True) await transcripts_controller.update(transcript, values) return transcript @@ -222,97 +198,6 @@ async def transcript_delete( return DeletionStatus(status="ok") -@router.get("/transcripts/{transcript_id}/audio/mp3") -@router.head("/transcripts/{transcript_id}/audio/mp3") -async def transcript_get_audio_mp3( - request: Request, - transcript_id: str, - user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], - token: str | None = None, -): - user_id = user["sub"] if user else None - if not user_id and token: - unauthorized_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) - user_id: str = payload.get("sub") - except jwt.JWTError: - raise unauthorized_exception - - transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id - ) - - if transcript.audio_location == "storage": - # proxy S3 file, to prevent issue with CORS - url = await transcript.get_audio_url() - headers = {} - - copy_headers = ["range", "accept-encoding"] - for header in copy_headers: - if header in request.headers: - headers[header] = request.headers[header] - - async with httpx.AsyncClient() as client: - resp = await client.request(request.method, url, headers=headers) - return Response( - content=resp.content, - status_code=resp.status_code, - headers=resp.headers, - ) - - if transcript.audio_location == "storage": - # proxy S3 file, to prevent issue with CORS - url = await transcript.get_audio_url() - headers = {} - - copy_headers = ["range", "accept-encoding"] - for header in copy_headers: - if header in request.headers: - headers[header] = request.headers[header] - - async with httpx.AsyncClient() as client: - resp = await client.request(request.method, url, headers=headers) - return Response( - content=resp.content, - status_code=resp.status_code, - headers=resp.headers, - ) - - if not transcript.audio_mp3_filename.exists(): - raise HTTPException(status_code=500, detail="Audio not found") - - truncated_id = str(transcript.id).split("-")[0] - filename = f"recording_{truncated_id}.mp3" - - return range_requests_response( - request, - transcript.audio_mp3_filename, - content_type="audio/mpeg", - content_disposition=f"attachment; filename={filename}", - ) - - -@router.get("/transcripts/{transcript_id}/audio/waveform") -async def transcript_get_audio_waveform( - transcript_id: str, - user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], -) -> AudioWaveform: - user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id - ) - - if not transcript.audio_waveform_filename.exists(): - raise HTTPException(status_code=404, detail="Audio not found") - - return transcript.audio_waveform - - @router.get( "/transcripts/{transcript_id}/topics", response_model=list[GetTranscriptTopic], @@ -330,84 +215,3 @@ async def transcript_get_topics( return [ GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics ] - - -# ============================================================== -# Websocket -# ============================================================== - - -@router.get("/transcripts/{transcript_id}/events") -async def transcript_get_websocket_events(transcript_id: str): - pass - - -@router.websocket("/transcripts/{transcript_id}/events") -async def transcript_events_websocket( - transcript_id: str, - websocket: WebSocket, - # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], -): - # user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") - - # connect to websocket manager - # use ts:transcript_id as room id - room_id = f"ts:{transcript_id}" - ws_manager = get_ws_manager() - await ws_manager.add_user_to_room(room_id, websocket) - - try: - # on first connection, send all events only to the current user - for event in transcript.events: - # for now, do not send TRANSCRIPT or STATUS options - theses are live event - # not necessary to be sent to the client; but keep the rest - name = event.event - if name in ("TRANSCRIPT", "STATUS"): - continue - await websocket.send_json(event.model_dump(mode="json")) - - # XXX if transcript is final (locked=True and status=ended) - # XXX send a final event to the client and close the connection - - # endless loop to wait for new events - # we do not have command system now, - while True: - await websocket.receive() - except (RuntimeError, WebSocketDisconnect): - await ws_manager.remove_user_from_room(room_id, websocket) - - -# ============================================================== -# Web RTC -# ============================================================== - - -@router.post("/transcripts/{transcript_id}/record/webrtc") -async def transcript_record_webrtc( - transcript_id: str, - params: RtcOffer, - request: Request, - user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], -): - user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id_for_http( - transcript_id, user_id=user_id - ) - - if transcript.locked: - raise HTTPException(status_code=400, detail="Transcript is locked") - - # create a pipeline runner - from reflector.pipelines.main_live_pipeline import PipelineMainLive - - pipeline_runner = PipelineMainLive(transcript_id=transcript_id) - - # FIXME do not allow multiple recording at the same time - return await rtc_offer_base( - params, - request, - pipeline_runner=pipeline_runner, - ) diff --git a/server/reflector/views/transcripts_audio.py b/server/reflector/views/transcripts_audio.py new file mode 100644 index 00000000..a174d992 --- /dev/null +++ b/server/reflector/views/transcripts_audio.py @@ -0,0 +1,109 @@ +""" +Transcripts audio related endpoints +=================================== + +""" +from typing import Annotated, Optional + +import httpx +import reflector.auth as auth +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from jose import jwt +from reflector.db.transcripts import AudioWaveform, transcripts_controller +from reflector.settings import settings +from reflector.views.transcripts import ALGORITHM + +from ._range_requests_response import range_requests_response + +router = APIRouter() + + +@router.get("/transcripts/{transcript_id}/audio/mp3") +@router.head("/transcripts/{transcript_id}/audio/mp3") +async def transcript_get_audio_mp3( + request: Request, + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + token: str | None = None, +): + user_id = user["sub"] if user else None + if not user_id and token: + unauthorized_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + user_id: str = payload.get("sub") + except jwt.JWTError: + raise unauthorized_exception + + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + if transcript.audio_location == "storage": + # proxy S3 file, to prevent issue with CORS + url = await transcript.get_audio_url() + headers = {} + + copy_headers = ["range", "accept-encoding"] + for header in copy_headers: + if header in request.headers: + headers[header] = request.headers[header] + + async with httpx.AsyncClient() as client: + resp = await client.request(request.method, url, headers=headers) + return Response( + content=resp.content, + status_code=resp.status_code, + headers=resp.headers, + ) + + if transcript.audio_location == "storage": + # proxy S3 file, to prevent issue with CORS + url = await transcript.get_audio_url() + headers = {} + + copy_headers = ["range", "accept-encoding"] + for header in copy_headers: + if header in request.headers: + headers[header] = request.headers[header] + + async with httpx.AsyncClient() as client: + resp = await client.request(request.method, url, headers=headers) + return Response( + content=resp.content, + status_code=resp.status_code, + headers=resp.headers, + ) + + if not transcript.audio_mp3_filename.exists(): + raise HTTPException(status_code=500, detail="Audio not found") + + truncated_id = str(transcript.id).split("-")[0] + filename = f"recording_{truncated_id}.mp3" + + return range_requests_response( + request, + transcript.audio_mp3_filename, + content_type="audio/mpeg", + content_disposition=f"attachment; filename={filename}", + ) + + +@router.get("/transcripts/{transcript_id}/audio/waveform") +async def transcript_get_audio_waveform( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> AudioWaveform: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + if not transcript.audio_waveform_filename.exists(): + raise HTTPException(status_code=404, detail="Audio not found") + + return transcript.audio_waveform diff --git a/server/reflector/views/transcripts_participants.py b/server/reflector/views/transcripts_participants.py new file mode 100644 index 00000000..318d6018 --- /dev/null +++ b/server/reflector/views/transcripts_participants.py @@ -0,0 +1,142 @@ +""" +Transcript participants API endpoints +===================================== + +""" +from typing import Annotated, Optional + +import reflector.auth as auth +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, ConfigDict, Field +from reflector.db.transcripts import TranscriptParticipant, transcripts_controller +from reflector.views.types import DeletionStatus + +router = APIRouter() + + +class Participant(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str + speaker: int | None + name: str + + +class CreateParticipant(BaseModel): + speaker: Optional[int] = Field(None) + name: str + + +class UpdateParticipant(BaseModel): + speaker: Optional[int] = Field(None) + name: Optional[str] = Field(None) + + +@router.get("/transcripts/{transcript_id}/participants") +async def transcript_get_participants( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> list[Participant]: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + return [ + Participant.model_validate(participant) + for participant in transcript.participants + ] + + +@router.post("/transcripts/{transcript_id}/participants") +async def transcript_add_participant( + transcript_id: str, + participant: CreateParticipant, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> Participant: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + # ensure the speaker is unique + for p in transcript.participants: + if p.speaker == participant.speaker: + raise HTTPException( + status_code=400, + detail="Speaker already assigned", + ) + + obj = await transcripts_controller.upsert_participant( + transcript, TranscriptParticipant(**participant.dict()) + ) + return Participant.model_validate(obj) + + +@router.get("/transcripts/{transcript_id}/participants/{participant_id}") +async def transcript_get_participant( + transcript_id: str, + participant_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> Participant: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + for p in transcript.participants: + if p.id == participant_id: + return Participant.model_validate(p) + + raise HTTPException(status_code=404, detail="Participant not found") + + +@router.patch("/transcripts/{transcript_id}/participants/{participant_id}") +async def transcript_update_participant( + transcript_id: str, + participant_id: str, + participant: UpdateParticipant, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> Participant: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + # ensure the speaker is unique + for p in transcript.participants: + if p.speaker == participant.speaker and p.id != participant_id: + raise HTTPException( + status_code=400, + detail="Speaker already assigned", + ) + + # find the participant + obj = None + for p in transcript.participants: + if p.id == participant_id: + obj = p + break + + if not obj: + raise HTTPException(status_code=404, detail="Participant not found") + + # update participant but just the fields that are set + fields = participant.dict(exclude_unset=True) + obj = obj.copy(update=fields) + + await transcripts_controller.upsert_participant(transcript, obj) + return Participant.model_validate(obj) + + +@router.delete("/transcripts/{transcript_id}/participants/{participant_id}") +async def transcript_delete_participant( + transcript_id: str, + participant_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> DeletionStatus: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + await transcripts_controller.delete_participant(transcript, participant_id) + return DeletionStatus(status="ok") diff --git a/server/reflector/views/transcripts_webrtc.py b/server/reflector/views/transcripts_webrtc.py new file mode 100644 index 00000000..af451411 --- /dev/null +++ b/server/reflector/views/transcripts_webrtc.py @@ -0,0 +1,37 @@ +from typing import Annotated, Optional + +import reflector.auth as auth +from fastapi import APIRouter, Depends, HTTPException, Request +from reflector.db.transcripts import transcripts_controller + +from .rtc_offer import RtcOffer, rtc_offer_base + +router = APIRouter() + + +@router.post("/transcripts/{transcript_id}/record/webrtc") +async def transcript_record_webrtc( + transcript_id: str, + params: RtcOffer, + request: Request, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + if transcript.locked: + raise HTTPException(status_code=400, detail="Transcript is locked") + + # create a pipeline runner + from reflector.pipelines.main_live_pipeline import PipelineMainLive + + pipeline_runner = PipelineMainLive(transcript_id=transcript_id) + + # FIXME do not allow multiple recording at the same time + return await rtc_offer_base( + params, + request, + pipeline_runner=pipeline_runner, + ) diff --git a/server/reflector/views/transcripts_websocket.py b/server/reflector/views/transcripts_websocket.py new file mode 100644 index 00000000..65571aab --- /dev/null +++ b/server/reflector/views/transcripts_websocket.py @@ -0,0 +1,53 @@ +""" +Transcripts websocket API +========================= + +""" +from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect +from reflector.db.transcripts import transcripts_controller +from reflector.ws_manager import get_ws_manager + +router = APIRouter() + + +@router.get("/transcripts/{transcript_id}/events") +async def transcript_get_websocket_events(transcript_id: str): + pass + + +@router.websocket("/transcripts/{transcript_id}/events") +async def transcript_events_websocket( + transcript_id: str, + websocket: WebSocket, + # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + # user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id(transcript_id) + if not transcript: + raise HTTPException(status_code=404, detail="Transcript not found") + + # connect to websocket manager + # use ts:transcript_id as room id + room_id = f"ts:{transcript_id}" + ws_manager = get_ws_manager() + await ws_manager.add_user_to_room(room_id, websocket) + + try: + # on first connection, send all events only to the current user + for event in transcript.events: + # for now, do not send TRANSCRIPT or STATUS options - theses are live event + # not necessary to be sent to the client; but keep the rest + name = event.event + if name in ("TRANSCRIPT", "STATUS"): + continue + await websocket.send_json(event.model_dump(mode="json")) + + # XXX if transcript is final (locked=True and status=ended) + # XXX send a final event to the client and close the connection + + # endless loop to wait for new events + # we do not have command system now, + while True: + await websocket.receive() + except (RuntimeError, WebSocketDisconnect): + await ws_manager.remove_user_from_room(room_id, websocket) diff --git a/server/reflector/views/types.py b/server/reflector/views/types.py new file mode 100644 index 00000000..70361131 --- /dev/null +++ b/server/reflector/views/types.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class DeletionStatus(BaseModel): + status: str diff --git a/server/tests/test_transcripts_participants.py b/server/tests/test_transcripts_participants.py new file mode 100644 index 00000000..b55b16a8 --- /dev/null +++ b/server/tests/test_transcripts_participants.py @@ -0,0 +1,164 @@ +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +async def test_transcript_participants(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + + # create a participant + transcript_id = response.json()["id"] + response = await ac.post( + f"/transcripts/{transcript_id}/participants", json={"name": "test"} + ) + assert response.status_code == 200 + assert response.json()["id"] is not None + assert response.json()["speaker"] is None + assert response.json()["name"] == "test" + + # create another one with a speaker + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test2", "speaker": 1}, + ) + assert response.status_code == 200 + assert response.json()["id"] is not None + assert response.json()["speaker"] == 1 + assert response.json()["name"] == "test2" + + # get all participants via transcript + response = await ac.get(f"/transcripts/{transcript_id}") + assert response.status_code == 200 + assert len(response.json()["participants"]) == 2 + + # get participants via participants endpoint + response = await ac.get(f"/transcripts/{transcript_id}/participants") + assert response.status_code == 200 + assert len(response.json()) == 2 + + +@pytest.mark.asyncio +async def test_transcript_participants_same_speaker(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + transcript_id = response.json()["id"] + + # create a participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test", "speaker": 1}, + ) + assert response.status_code == 200 + assert response.json()["speaker"] == 1 + + # create another one with the same speaker + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test2", "speaker": 1}, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_transcript_participants_update_name(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + transcript_id = response.json()["id"] + + # create a participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test", "speaker": 1}, + ) + assert response.status_code == 200 + assert response.json()["speaker"] == 1 + + # update the participant + participant_id = response.json()["id"] + response = await ac.patch( + f"/transcripts/{transcript_id}/participants/{participant_id}", + json={"name": "test2"}, + ) + assert response.status_code == 200 + assert response.json()["name"] == "test2" + + # verify the participant was updated + response = await ac.get( + f"/transcripts/{transcript_id}/participants/{participant_id}" + ) + assert response.status_code == 200 + assert response.json()["name"] == "test2" + + # verify the participant was updated in transcript + response = await ac.get(f"/transcripts/{transcript_id}") + assert response.status_code == 200 + assert len(response.json()["participants"]) == 1 + assert response.json()["participants"][0]["name"] == "test2" + + +@pytest.mark.asyncio +async def test_transcript_participants_update_speaker(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + transcript_id = response.json()["id"] + + # create a participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test", "speaker": 1}, + ) + assert response.status_code == 200 + participant1_id = response.json()["id"] + + # create another participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test2", "speaker": 2}, + ) + assert response.status_code == 200 + participant2_id = response.json()["id"] + + # update the participant, refused as speaker is already taken + response = await ac.patch( + f"/transcripts/{transcript_id}/participants/{participant2_id}", + json={"speaker": 1}, + ) + assert response.status_code == 400 + + # delete the participant 1 + response = await ac.delete( + f"/transcripts/{transcript_id}/participants/{participant1_id}" + ) + assert response.status_code == 200 + + # update the participant 2 again, should be accepted now + response = await ac.patch( + f"/transcripts/{transcript_id}/participants/{participant2_id}", + json={"speaker": 1}, + ) + assert response.status_code == 200 + + # ensure participant2 name is still there + response = await ac.get( + f"/transcripts/{transcript_id}/participants/{participant2_id}" + ) + assert response.status_code == 200 + assert response.json()["name"] == "test2" + assert response.json()["speaker"] == 1 From 8b1b71940f18ddf7a669845acd55b444dd5d005d Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 30 Nov 2023 19:25:09 +0100 Subject: [PATCH 25/27] hotfix/server: update diarization settings to increase timeout, reduce idle timeout on the minimum --- server/gpu/modal/reflector_diarizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/gpu/modal/reflector_diarizer.py b/server/gpu/modal/reflector_diarizer.py index 7c316548..b1989a11 100644 --- a/server/gpu/modal/reflector_diarizer.py +++ b/server/gpu/modal/reflector_diarizer.py @@ -69,9 +69,9 @@ diarizer_image = ( @stub.cls( gpu=modal.gpu.A100(memory=40), - timeout=60 * 10, - container_idle_timeout=60 * 3, - allow_concurrent_inputs=6, + timeout=60 * 30, + container_idle_timeout=60, + allow_concurrent_inputs=1, image=diarizer_image, ) class Diarizer: From f9771427e2f7c363ae775006d61e7cb3f506faf4 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 30 Nov 2023 19:43:19 +0100 Subject: [PATCH 26/27] server: add healthcheck for worker --- README.md | 14 +++++++++++--- server/reflector/settings.py | 3 +++ server/reflector/worker/app.py | 15 +++++++++++++++ server/reflector/worker/healthcheck.py | 18 ++++++++++++++++++ server/runserver.sh | 2 ++ 5 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 server/reflector/worker/healthcheck.py diff --git a/README.md b/README.md index cb75c76b..8830c79e 100644 --- a/README.md +++ b/README.md @@ -133,17 +133,25 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run ``` -### Start the project +### Start the API/Backend -Use: +Start the API server: ```bash poetry run python3 -m reflector.app ``` -And start the background worker +Start the background worker: +```bash celery -A reflector.worker.app worker --loglevel=info +``` + +For crontab (only healthcheck for now), start the celery beat: + +```bash +celery -A reflector.worker.app beat +``` #### Using docker diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 2c68c4e5..d0ddc91a 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -128,5 +128,8 @@ class Settings(BaseSettings): # Profiling PROFILING: bool = False + # Healthcheck + HEALTHCHECK_URL: str | None = None + settings = Settings() diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py index e1000364..689623ce 100644 --- a/server/reflector/worker/app.py +++ b/server/reflector/worker/app.py @@ -1,6 +1,8 @@ +import structlog from celery import Celery from reflector.settings import settings +logger = structlog.get_logger(__name__) app = Celery(__name__) app.conf.broker_url = settings.CELERY_BROKER_URL app.conf.result_backend = settings.CELERY_RESULT_BACKEND @@ -8,5 +10,18 @@ app.conf.broker_connection_retry_on_startup = True app.autodiscover_tasks( [ "reflector.pipelines.main_live_pipeline", + "reflector.worker.healthcheck", ] ) + +# crontab +app.conf.beat_schedule = {} + +if settings.HEALTHCHECK_URL: + app.conf.beat_schedule["healthcheck_ping"] = { + "task": "reflector.worker.healthcheck.healthcheck_ping", + "schedule": 60.0 * 10, + } + logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL) +else: + logger.warning("Healthcheck disabled, no url configured") diff --git a/server/reflector/worker/healthcheck.py b/server/reflector/worker/healthcheck.py new file mode 100644 index 00000000..e4ce6bc3 --- /dev/null +++ b/server/reflector/worker/healthcheck.py @@ -0,0 +1,18 @@ +import httpx +import structlog +from celery import shared_task +from reflector.settings import settings + +logger = structlog.get_logger(__name__) + + +@shared_task +def healthcheck_ping(): + url = settings.HEALTHCHECK_URL + if not url: + return + try: + print("pinging healthcheck url", url) + httpx.get(url, timeout=10) + except Exception as e: + logger.error("healthcheck_ping", error=str(e)) diff --git a/server/runserver.sh b/server/runserver.sh index b0c3f138..31cce123 100755 --- a/server/runserver.sh +++ b/server/runserver.sh @@ -9,6 +9,8 @@ if [ "${ENTRYPOINT}" = "server" ]; then python -m reflector.app elif [ "${ENTRYPOINT}" = "worker" ]; then celery -A reflector.worker.app worker --loglevel=info +elif [ "${ENTRYPOINT}" = "beat" ]; then + celery -A reflector.worker.app beat --loglevel=info else echo "Unknown command" fi From 84a1350df7eca523743e5e602ad73cecab058ccf Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 1 Dec 2023 18:18:09 +0100 Subject: [PATCH 27/27] hotfix/server: fix participants loading on old meetings --- server/reflector/db/transcripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 44688eaa..970393d5 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -133,7 +133,7 @@ class Transcript(BaseModel): long_summary: str | None = None topics: list[TranscriptTopic] = [] events: list[TranscriptEvent] = [] - participants: list[TranscriptParticipant] = [] + participants: list[TranscriptParticipant] | None = [] source_language: str = "en" target_language: str = "en" share_mode: Literal["private", "semi-private", "public"] = "private"