From 226b92c3474e05787643a77f2efde7abe132a346 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Tue, 7 Nov 2023 12:39:48 +0100
Subject: [PATCH 01/27] www/server: introduce share mode
---
.../versions/0fea6d96b096_add_share_mode.py | 30 +++++++
server/reflector/db/transcripts.py | 55 +++++++++++-
server/reflector/views/transcripts.py | 56 ++++++++-----
.../transcripts/[transcriptId]/page.tsx | 7 +-
www/app/[domain]/transcripts/shareLink.tsx | 84 ++++++++++++++++---
www/app/api/models/GetTranscript.ts | 17 ++++
www/app/api/models/UpdateTranscript.ts | 8 ++
www/app/styles/form.scss | 5 ++
8 files changed, 228 insertions(+), 34 deletions(-)
create mode 100644 server/migrations/versions/0fea6d96b096_add_share_mode.py
diff --git a/server/migrations/versions/0fea6d96b096_add_share_mode.py b/server/migrations/versions/0fea6d96b096_add_share_mode.py
new file mode 100644
index 00000000..52a72d48
--- /dev/null
+++ b/server/migrations/versions/0fea6d96b096_add_share_mode.py
@@ -0,0 +1,30 @@
+"""add share_mode
+
+Revision ID: 0fea6d96b096
+Revises: 38a927dcb099
+Create Date: 2023-11-07 11:12:21.614198
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '0fea6d96b096'
+down_revision: Union[str, None] = '38a927dcb099'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('transcript', sa.Column('share_mode', sa.String(), server_default='private', nullable=False))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('transcript', 'share_mode')
+ # ### end Alembic commands ###
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 6ac2e32a..4b91423a 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -2,10 +2,11 @@ import json
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
-from typing import Any
+from typing import Any, Literal
from uuid import uuid4
import sqlalchemy
+from fastapi import HTTPException
from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
@@ -30,6 +31,12 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
+ sqlalchemy.Column(
+ "share_mode",
+ sqlalchemy.String,
+ nullable=False,
+ server_default="private",
+ ),
)
@@ -99,6 +106,7 @@ class Transcript(BaseModel):
events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
+ share_mode: Literal["private", "semi-private", "public"] = "private"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -169,6 +177,7 @@ class TranscriptController:
order_by: str | None = None,
filter_empty: bool | None = False,
filter_recording: bool | None = False,
+ return_query: bool = False,
) -> list[Transcript]:
"""
Get all transcripts
@@ -195,6 +204,9 @@ class TranscriptController:
if filter_recording:
query = query.filter(transcripts.c.status != "recording")
+ if return_query:
+ return query
+
results = await database.fetch_all(query)
return results
@@ -210,6 +222,47 @@ class TranscriptController:
return None
return Transcript(**result)
+ async def get_by_id_for_http(
+ self,
+ transcript_id: str,
+ user_id: str | None,
+ ) -> Transcript:
+ """
+ Get a transcript by ID for HTTP request.
+
+ If not found, it will raise a 404 error.
+ If the user is not allowed to access the transcript, it will raise a 403 error.
+
+ This method checks the share mode of the transcript and the user_id
+ to determine if the user can access the transcript.
+ """
+ query = transcripts.select().where(transcripts.c.id == transcript_id)
+ result = await database.fetch_one(query)
+ if not result:
+ raise HTTPException(status_code=404, detail="Transcript not found")
+
+ # if the transcript is anonymous, share mode is not checked
+ transcript = Transcript(**result)
+ if transcript.user_id is None:
+ return transcript
+
+ if transcript.share_mode == "private":
+ # in private mode, only the owner can access the transcript
+ if transcript.user_id == user_id:
+ return transcript
+
+ elif transcript.share_mode == "semi-private":
+ # in semi-private mode, only the owner and the users with the link
+ # can access the transcript
+ if user_id is not None:
+ return transcript
+
+ elif transcript.share_mode == "public":
+ # in public mode, everyone can access the transcript
+ return transcript
+
+ raise HTTPException(status_code=403, detail="Transcript access denied")
+
async def add(
self,
name: str,
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index e3668ecb..5f1d7831 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta
-from typing import Annotated, Optional
+from typing import Annotated, Literal, Optional
import reflector.auth as auth
from fastapi import (
@@ -11,7 +11,8 @@ from fastapi import (
WebSocketDisconnect,
status,
)
-from fastapi_pagination import Page, paginate
+from fastapi_pagination import Page
+from fastapi_pagination.ext.databases import paginate
from jose import jwt
from pydantic import BaseModel, Field
from reflector.db.transcripts import (
@@ -48,6 +49,7 @@ def create_access_token(data: dict, expires_delta: timedelta):
class GetTranscript(BaseModel):
id: str
+ user_id: str | None
name: str
status: str
locked: bool
@@ -56,6 +58,7 @@ class GetTranscript(BaseModel):
short_summary: str | None
long_summary: str | None
created_at: datetime
+ share_mode: str = Field("private")
source_language: str | None
target_language: str | None
@@ -72,6 +75,7 @@ class UpdateTranscript(BaseModel):
title: Optional[str] = Field(None)
short_summary: Optional[str] = Field(None)
long_summary: Optional[str] = Field(None)
+ share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None)
class DeletionStatus(BaseModel):
@@ -82,12 +86,19 @@ class DeletionStatus(BaseModel):
async def transcripts_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
+ from reflector.db import database
+
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None
- return paginate(
- await transcripts_controller.get_all(user_id=user_id, order_by="-created_at")
+ return await paginate(
+ database,
+ await transcripts_controller.get_all(
+ user_id=user_id,
+ order_by="-created_at",
+ return_query=True,
+ ),
)
@@ -165,10 +176,9 @@ async def transcript_get(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
- return transcript
+ return await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
@@ -192,6 +202,8 @@ async def transcript_update(
values["short_summary"] = info.short_summary
if info.title is not None:
values["title"] = info.title
+ if info.share_mode is not None:
+ values["share_mode"] = info.share_mode
await transcripts_controller.update(transcript, values)
return transcript
@@ -229,12 +241,12 @@ async def transcript_get_audio_mp3(
except jwt.JWTError:
raise unauthorized_exception
- transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
if not transcript.audio_mp3_filename.exists():
- raise HTTPException(status_code=404, detail="Audio not found")
+ raise HTTPException(status_code=500, detail="Audio not found")
truncated_id = str(transcript.id).split("-")[0]
filename = f"recording_{truncated_id}.mp3"
@@ -253,12 +265,12 @@ async def transcript_get_audio_waveform(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
if not transcript.audio_mp3_filename.exists():
- raise HTTPException(status_code=404, detail="Audio not found")
+ raise HTTPException(status_code=500, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_waveform)
@@ -274,9 +286,9 @@ async def transcript_get_topics(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
# convert to GetTranscriptTopic
return [
@@ -345,9 +357,9 @@ async def transcript_record_webrtc(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
index 9f9348c8..7b57ff2a 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
@@ -99,7 +99,12 @@ export default function TranscriptDetails(details: TranscriptDetails) {
/>
-
+
diff --git a/www/app/[domain]/transcripts/shareLink.tsx b/www/app/[domain]/transcripts/shareLink.tsx
index 49163a5b..44c57053 100644
--- a/www/app/[domain]/transcripts/shareLink.tsx
+++ b/www/app/[domain]/transcripts/shareLink.tsx
@@ -1,15 +1,37 @@
import React, { useState, useRef, useEffect, use } from "react";
import { featureEnabled } from "../domainContext";
+import getApi from "../../lib/getApi";
+import { useFiefUserinfo } from "@fief/fief/nextjs/react";
+import SelectSearch from "react-select-search";
+import "react-select-search/style.css";
+import "../../styles/button.css";
+import "../../styles/form.scss";
-const ShareLink = () => {
+type ShareLinkProps = {
+ protectedPath: boolean;
+ transcriptId: string;
+ userId: string | null;
+ shareMode: string;
+};
+
+const ShareLink = (props: ShareLinkProps) => {
const [isCopied, setIsCopied] = useState(false);
const inputRef = useRef(null);
const [currentUrl, setCurrentUrl] = useState("");
+ const requireLogin = featureEnabled("requireLogin");
+ const [isOwner, setIsOwner] = useState(false);
+ const [shareMode, setShareMode] = useState(props.shareMode);
+ const api = getApi(props.protectedPath);
+ const userinfo = useFiefUserinfo();
useEffect(() => {
setCurrentUrl(window.location.href);
}, []);
+ useEffect(() => {
+ setIsOwner(!!(requireLogin && userinfo?.sub === props.userId));
+ }, [userinfo, props.userId]);
+
const handleCopyClick = () => {
if (inputRef.current) {
let text_to_copy = inputRef.current.value;
@@ -23,6 +45,16 @@ const ShareLink = () => {
}
};
+ const updateShareMode = async (selectedShareMode: string) => {
+ if (!api) return;
+ const updatedTranscript = await api.v1TranscriptUpdate({
+ transcriptId: props.transcriptId,
+ updateTranscript: {
+ shareMode: selectedShareMode,
+ },
+ });
+ setShareMode(updatedTranscript.shareMode);
+ };
const privacyEnabled = featureEnabled("privacy");
return (
@@ -30,18 +62,50 @@ const ShareLink = () => {
className="p-2 md:p-4 rounded"
style={{ background: "rgba(96, 165, 250, 0.2)" }}
>
- {privacyEnabled ? (
+ {requireLogin && (
- You can share this link with others. Anyone with the link will have
- access to the page, including the full audio recording, for the next 7
- days.
-
- ) : (
-
- You can share this link with others. Anyone with the link will have
- access to the page, including the full audio recording.
+ {shareMode === "private" && (
+
This transcript is only accessible by you.
+ )}
+ {shareMode === "semi-private" && (
+ This transcript is accessible by any authenticated users.
+ )}
+ {shareMode === "public" && (
+ This transcript is accessible by anyone.
+ )}
+
+ {isOwner && api && (
+
+
+
+ )}
)}
+ {!requireLogin && (
+ <>
+ {privacyEnabled ? (
+
+ You can share this link with others. Anyone with the link will
+ have access to the page, including the full audio recording, for
+ the next 7 days.
+
+ ) : (
+
+ You can share this link with others. Anyone with the link will
+ have access to the page, including the full audio recording.
+
+ )}
+ >
+ )}
Date: Tue, 7 Nov 2023 18:41:51 +0100
Subject: [PATCH 02/27] www: edit from andreas feedback
---
www/app/[domain]/transcripts/shareLink.tsx | 20 +++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/www/app/[domain]/transcripts/shareLink.tsx b/www/app/[domain]/transcripts/shareLink.tsx
index 44c57053..82ef52c9 100644
--- a/www/app/[domain]/transcripts/shareLink.tsx
+++ b/www/app/[domain]/transcripts/shareLink.tsx
@@ -65,13 +65,15 @@ const ShareLink = (props: ShareLinkProps) => {
{requireLogin && (
{shareMode === "private" && (
-
This transcript is only accessible by you.
+
This transcript is private and can only be accessed by you.
)}
{shareMode === "semi-private" && (
-
This transcript is accessible by any authenticated users.
+
+ This transcript is secure. Only authenticated users can access it.
+
)}
{shareMode === "public" && (
-
This transcript is accessible by anyone.
+
This transcript is public. Everyone can access it.
)}
{isOwner && api && (
@@ -80,7 +82,7 @@ const ShareLink = (props: ShareLinkProps) => {
className="select-search--top select-search"
options={[
{ name: "Private", value: "private" },
- { name: "Semi-private", value: "semi-private" },
+ { name: "Secure", value: "semi-private" },
{ name: "Public", value: "public" },
]}
value={shareMode}
@@ -94,14 +96,14 @@ const ShareLink = (props: ShareLinkProps) => {
<>
{privacyEnabled ? (
- You can share this link with others. Anyone with the link will
- have access to the page, including the full audio recording, for
- the next 7 days.
+ Share this link to grant others access to this page. The link
+ includes the full audio recording and is valid for the next 7
+ days.
) : (
- You can share this link with others. Anyone with the link will
- have access to the page, including the full audio recording.
+ Share this link to allow others to view this page and listen to
+ the full audio recording.
)}
>
From 86b3b3c0e4f61d98d1b4a0dbbb343286db503371 Mon Sep 17 00:00:00 2001
From: Sara
Date: Mon, 13 Nov 2023 17:23:27 +0100
Subject: [PATCH 03/27] split recorder player
---
.../transcripts/[transcriptId]/page.tsx | 10 +-
www/app/[domain]/transcripts/player.tsx | 167 ++++++++++++++++++
www/app/[domain]/transcripts/recorder.tsx | 109 +-----------
3 files changed, 174 insertions(+), 112 deletions(-)
create mode 100644 www/app/[domain]/transcripts/player.tsx
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
index 56201c3c..bebfe261 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
@@ -5,7 +5,6 @@ import useTopics from "../useTopics";
import useWaveform from "../useWaveform";
import useMp3 from "../useMp3";
import { TopicList } from "../topicList";
-import Recorder from "../recorder";
import { Topic } from "../webSocketTypes";
import React, { useState } from "react";
import "../../../styles/button.css";
@@ -13,6 +12,7 @@ import FinalSummary from "../finalSummary";
import ShareLink from "../shareLink";
import QRCode from "react-qr-code";
import TranscriptTitle from "../transcriptTitle";
+import Player from "../player";
type TranscriptDetails = {
params: {
@@ -62,14 +62,12 @@ export default function TranscriptDetails(details: TranscriptDetails) {
/>
)}
{!waveform?.loading && (
-
)}
diff --git a/www/app/[domain]/transcripts/player.tsx b/www/app/[domain]/transcripts/player.tsx
new file mode 100644
index 00000000..6143836e
--- /dev/null
+++ b/www/app/[domain]/transcripts/player.tsx
@@ -0,0 +1,167 @@
+import React, { useRef, useEffect, useState } from "react";
+
+import WaveSurfer from "wavesurfer.js";
+import CustomRegionsPlugin from "../../lib/custom-plugins/regions";
+
+import { formatTime } from "../../lib/time";
+import { Topic } from "./webSocketTypes";
+import { AudioWaveform } from "../../api";
+import { waveSurferStyles } from "../../styles/recorder";
+
+type PlayerProps = {
+ topics: Topic[];
+ useActiveTopic: [
+ Topic | null,
+ React.Dispatch>,
+ ];
+ waveform: AudioWaveform;
+ media: HTMLMediaElement;
+ mediaDuration: number;
+};
+
+export default function Player(props: PlayerProps) {
+ const waveformRef = useRef(null);
+ const [wavesurfer, setWavesurfer] = useState(null);
+ const [isPlaying, setIsPlaying] = useState(false);
+ const [currentTime, setCurrentTime] = useState(0);
+ const [waveRegions, setWaveRegions] = useState(
+ null,
+ );
+ const [activeTopic, setActiveTopic] = props.useActiveTopic;
+ const topicsRef = useRef(props.topics);
+
+ // Waveform setup
+ useEffect(() => {
+ if (waveformRef.current) {
+ // XXX duration is required to prevent recomputing peaks from audio
+ // However, the current waveform returns only the peaks, and no duration
+ // And the backend does not save duration properly.
+ // So at the moment, we deduct the duration from the topics.
+ // This is not ideal, but it works for now.
+ const _wavesurfer = WaveSurfer.create({
+ container: waveformRef.current,
+ peaks: props.waveform.data,
+ hideScrollbar: true,
+ autoCenter: true,
+ barWidth: 2,
+ height: "auto",
+ duration: props.mediaDuration,
+
+ ...waveSurferStyles.player,
+ });
+
+ // styling
+ const wsWrapper = _wavesurfer.getWrapper();
+ wsWrapper.style.cursor = waveSurferStyles.playerStyle.cursor;
+ wsWrapper.style.backgroundColor =
+ waveSurferStyles.playerStyle.backgroundColor;
+ wsWrapper.style.borderRadius = waveSurferStyles.playerStyle.borderRadius;
+
+ _wavesurfer.on("play", () => {
+ setIsPlaying(true);
+ });
+ _wavesurfer.on("pause", () => {
+ setIsPlaying(false);
+ });
+ _wavesurfer.on("timeupdate", setCurrentTime);
+
+ setWaveRegions(_wavesurfer.registerPlugin(CustomRegionsPlugin.create()));
+
+ _wavesurfer.toggleInteraction(true);
+
+ _wavesurfer.setMediaElement(props.media);
+
+ setWavesurfer(_wavesurfer);
+
+ return () => {
+ _wavesurfer.destroy();
+ setIsPlaying(false);
+ setCurrentTime(0);
+ };
+ }
+ }, []);
+
+ useEffect(() => {
+ if (!wavesurfer) return;
+ if (!props.media) return;
+ wavesurfer.setMediaElement(props.media);
+ }, [props.media, wavesurfer]);
+
+ useEffect(() => {
+ topicsRef.current = props.topics;
+ renderMarkers();
+ }, [props.topics, waveRegions]);
+
+ const renderMarkers = () => {
+ if (!waveRegions) return;
+
+ waveRegions.clearRegions();
+
+ for (let topic of topicsRef.current) {
+ const content = document.createElement("div");
+ content.setAttribute("style", waveSurferStyles.marker);
+ content.onmouseover = () => {
+ content.style.backgroundColor =
+ waveSurferStyles.markerHover.backgroundColor;
+ content.style.zIndex = "999";
+ content.style.width = "300px";
+ };
+ content.onmouseout = () => {
+ content.setAttribute("style", waveSurferStyles.marker);
+ };
+ content.textContent = topic.title;
+
+ const region = waveRegions.addRegion({
+ start: topic.timestamp,
+ content,
+ color: "f00",
+ drag: false,
+ });
+ region.on("click", (e) => {
+ e.stopPropagation();
+ setActiveTopic(topic);
+ wavesurfer?.setTime(region.start);
+ });
+ }
+ };
+
+ useEffect(() => {
+ if (activeTopic) {
+ wavesurfer?.setTime(activeTopic.timestamp);
+ }
+ }, [activeTopic]);
+
+ const handlePlayClick = () => {
+ wavesurfer?.playPause();
+ };
+
+ const timeLabel = () => {
+ if (props.mediaDuration)
+ return `${formatTime(currentTime)}/${formatTime(props.mediaDuration)}`;
+ return "";
+ };
+
+ return (
+
+
+
+
+ {isPlaying ? "Pause" : "Play"}
+
+
+ );
+}
diff --git a/www/app/[domain]/transcripts/recorder.tsx b/www/app/[domain]/transcripts/recorder.tsx
index 8db32ff7..5b4420c4 100644
--- a/www/app/[domain]/transcripts/recorder.tsx
+++ b/www/app/[domain]/transcripts/recorder.tsx
@@ -6,11 +6,8 @@ import CustomRegionsPlugin from "../../lib/custom-plugins/regions";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { faMicrophone } from "@fortawesome/free-solid-svg-icons";
-import { faDownload } from "@fortawesome/free-solid-svg-icons";
import { formatTime } from "../../lib/time";
-import { Topic } from "./webSocketTypes";
-import { AudioWaveform } from "../../api";
import AudioInputsDropdown from "./audioInputsDropdown";
import { Option } from "react-dropdown";
import { waveSurferStyles } from "../../styles/recorder";
@@ -19,17 +16,8 @@ import { useError } from "../../(errors)/errorContext";
type RecorderProps = {
setStream?: React.Dispatch>;
onStop?: () => void;
- topics: Topic[];
getAudioStream?: (deviceId) => Promise;
audioDevices?: Option[];
- useActiveTopic: [
- Topic | null,
- React.Dispatch>,
- ];
- waveform?: AudioWaveform | null;
- isPastMeeting: boolean;
- transcriptId?: string | null;
- media?: HTMLMediaElement | null;
mediaDuration?: number | null;
};
@@ -38,7 +26,7 @@ export default function Recorder(props: RecorderProps) {
const [wavesurfer, setWavesurfer] = useState(null);
const [record, setRecord] = useState(null);
const [isRecording, setIsRecording] = useState(false);
- const [hasRecorded, setHasRecorded] = useState(props.isPastMeeting);
+ const [hasRecorded, setHasRecorded] = useState(false);
const [isPlaying, setIsPlaying] = useState(false);
const [currentTime, setCurrentTime] = useState(0);
const [timeInterval, setTimeInterval] = useState(null);
@@ -48,8 +36,6 @@ export default function Recorder(props: RecorderProps) {
);
const [deviceId, setDeviceId] = useState(null);
const [recordStarted, setRecordStarted] = useState(false);
- const [activeTopic, setActiveTopic] = props.useActiveTopic;
- const topicsRef = useRef(props.topics);
const [showDevices, setShowDevices] = useState(false);
const { setError } = useError();
@@ -73,8 +59,6 @@ export default function Recorder(props: RecorderProps) {
if (!record.isRecording()) return;
handleRecClick();
break;
- case "^":
- throw new Error("Unhandled Exception thrown by '^' shortcut");
case "(":
location.href = "/login";
break;
@@ -104,14 +88,8 @@ export default function Recorder(props: RecorderProps) {
// Waveform setup
useEffect(() => {
if (waveformRef.current) {
- // XXX duration is required to prevent recomputing peaks from audio
- // However, the current waveform returns only the peaks, and no duration
- // And the backend does not save duration properly.
- // So at the moment, we deduct the duration from the topics.
- // This is not ideal, but it works for now.
const _wavesurfer = WaveSurfer.create({
container: waveformRef.current,
- peaks: props.waveform?.data,
hideScrollbar: true,
autoCenter: true,
barWidth: 2,
@@ -121,10 +99,8 @@ export default function Recorder(props: RecorderProps) {
...waveSurferStyles.player,
});
- if (!props.transcriptId) {
- const _wshack: any = _wavesurfer;
- _wshack.renderer.renderSingleCanvas = () => {};
- }
+ const _wshack: any = _wavesurfer;
+ _wshack.renderer.renderSingleCanvas = () => {};
// styling
const wsWrapper = _wavesurfer.getWrapper();
@@ -144,12 +120,6 @@ export default function Recorder(props: RecorderProps) {
setRecord(_wavesurfer.registerPlugin(RecordPlugin.create()));
setWaveRegions(_wavesurfer.registerPlugin(CustomRegionsPlugin.create()));
- if (props.isPastMeeting) _wavesurfer.toggleInteraction(true);
-
- if (props.media) {
- _wavesurfer.setMediaElement(props.media);
- }
-
setWavesurfer(_wavesurfer);
return () => {
@@ -161,58 +131,6 @@ export default function Recorder(props: RecorderProps) {
}
}, []);
- useEffect(() => {
- if (!wavesurfer) return;
- if (!props.media) return;
- wavesurfer.setMediaElement(props.media);
- }, [props.media, wavesurfer]);
-
- useEffect(() => {
- topicsRef.current = props.topics;
- if (!isRecording) renderMarkers();
- }, [props.topics, waveRegions]);
-
- const renderMarkers = () => {
- if (!waveRegions) return;
-
- waveRegions.clearRegions();
-
- for (let topic of topicsRef.current) {
- const content = document.createElement("div");
- content.setAttribute("style", waveSurferStyles.marker);
- content.onmouseover = () => {
- content.style.backgroundColor =
- waveSurferStyles.markerHover.backgroundColor;
- content.style.zIndex = "999";
- content.style.width = "300px";
- };
- content.onmouseout = () => {
- content.setAttribute("style", waveSurferStyles.marker);
- };
- content.textContent = topic.title;
-
- const region = waveRegions.addRegion({
- start: topic.timestamp,
- content,
- color: "f00",
- drag: false,
- });
- region.on("click", (e) => {
- e.stopPropagation();
- setActiveTopic(topic);
- wavesurfer?.setTime(region.start);
- });
- }
- };
-
- useEffect(() => {
- if (!record) return;
-
- return record.on("stopRecording", () => {
- renderMarkers();
- });
- }, [record]);
-
useEffect(() => {
if (isRecording) {
const interval = window.setInterval(() => {
@@ -229,12 +147,6 @@ export default function Recorder(props: RecorderProps) {
}
}, [isRecording]);
- useEffect(() => {
- if (activeTopic) {
- wavesurfer?.setTime(activeTopic.timestamp);
- }
- }, [activeTopic]);
-
const handleRecClick = async () => {
if (!record) return console.log("no record");
@@ -320,7 +232,6 @@ export default function Recorder(props: RecorderProps) {
if (!record) return;
if (!destinationStream) return;
if (props.setStream) props.setStream(destinationStream);
- waveRegions?.clearRegions();
if (destinationStream) {
record.startRecording(destinationStream);
setIsRecording(true);
@@ -379,23 +290,9 @@ export default function Recorder(props: RecorderProps) {
} text-white ml-2 md:ml:4 md:h-[78px] md:min-w-[100px] text-lg`}
id="play-btn"
onClick={handlePlayClick}
- disabled={isRecording}
>
{isPlaying ? "Pause" : "Play"}
-
- {props.transcriptId && (
-
-
-
- )}
>
)}
{!hasRecorded && (
From 14ebfa53a8c2f9abad2342ebe7d3ab4a018e52f5 Mon Sep 17 00:00:00 2001
From: Sara
Date: Mon, 13 Nov 2023 17:33:12 +0100
Subject: [PATCH 04/27] wip loading and redirects
---
.../transcripts/[transcriptId]/page.tsx | 38 +++++++++++++++----
.../[transcriptId]/record/page.tsx | 26 ++++++++-----
www/app/[domain]/transcripts/useTranscript.ts | 27 ++++++++++---
www/app/[domain]/transcripts/useWebSockets.ts | 30 +++++++++------
4 files changed, 87 insertions(+), 34 deletions(-)
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
index bebfe261..add5a824 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
@@ -6,7 +6,7 @@ import useWaveform from "../useWaveform";
import useMp3 from "../useMp3";
import { TopicList } from "../topicList";
import { Topic } from "../webSocketTypes";
-import React, { useState } from "react";
+import React, { useEffect, useState } from "react";
import "../../../styles/button.css";
import FinalSummary from "../finalSummary";
import ShareLink from "../shareLink";
@@ -31,7 +31,7 @@ export default function TranscriptDetails(details: TranscriptDetails) {
const useActiveTopic = useState(null);
const mp3 = useMp3(protectedPath, transcriptId);
- if (transcript?.error /** || topics?.error || waveform?.error **/) {
+ if (transcript?.error || topics?.error) {
return (
{
+ const statusToRedirect = ["idle", "recording", "processing"];
+ if (statusToRedirect.includes(transcript.response?.status)) {
+ const newUrl = "/transcripts/" + details.params.transcriptId + "/record";
+ // Shallow redirection does not work on NextJS 13
+ // https://github.com/vercel/next.js/discussions/48110
+ // https://github.com/vercel/next.js/discussions/49540
+ // router.push(newUrl, undefined, { shallow: true });
+ history.replaceState({}, "", newUrl);
+ }
+ }, [transcript.response?.status]);
+
const fullTranscript =
topics.topics
?.map((topic) => topic.transcript)
.join("\n\n")
.replace(/ +/g, " ")
.trim() || "";
+ console.log("calf full transcript");
return (
<>
- {!transcriptId || transcript?.loading || topics?.loading ? (
+ {transcript?.loading || topics?.loading ? (
) : (
<>
@@ -61,7 +74,7 @@ export default function TranscriptDetails(details: TranscriptDetails) {
transcriptId={transcript.response.id}
/>
)}
- {!waveform?.loading && (
+ {waveform.waveform && mp3.media ? (
+ ) : mp3.error || waveform.error ? (
+ "error loading this recording"
+ ) : (
+ "Loading Recording"
)}
+
- {transcript?.response?.longSummary && (
+ {transcript.response.longSummary ? (
+ ) : transcript.response.status == "processing" ? (
+ "Loading Transcript"
+ ) : (
+ "error final summary"
)}
diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
index 41a2d053..10297f5c 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
@@ -8,12 +8,12 @@ import { useWebSockets } from "../../useWebSockets";
import useAudioDevice from "../../useAudioDevice";
import "../../../../styles/button.css";
import { Topic } from "../../webSocketTypes";
-import getApi from "../../../../lib/getApi";
import LiveTrancription from "../../liveTranscription";
import DisconnectedIndicator from "../../disconnectedIndicator";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { faGear } from "@fortawesome/free-solid-svg-icons";
import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock";
+import { useRouter } from "next/navigation";
type TranscriptDetails = {
params: {
@@ -45,21 +45,31 @@ const TranscriptRecord = (details: TranscriptDetails) => {
const [hasRecorded, setHasRecorded] = useState(false);
const [transcriptStarted, setTranscriptStarted] = useState(false);
+ const router = useRouter();
+
useEffect(() => {
if (!transcriptStarted && webSockets.transcriptText.length !== 0)
setTranscriptStarted(true);
}, [webSockets.transcriptText]);
useEffect(() => {
- if (transcript?.response?.longSummary) {
- const newUrl = `/transcripts/${transcript.response.id}`;
+ const statusToRedirect = ["ended", "error"];
+ console.log(webSockets.status, "hey");
+ console.log(transcript.response, "ho");
+
+ //TODO if has no topic and is error, get back to new
+ if (
+ statusToRedirect.includes(transcript.response?.status) ||
+ statusToRedirect.includes(webSockets.status.value)
+ ) {
+ const newUrl = "/transcripts/" + details.params.transcriptId;
// Shallow redirection does not work on NextJS 13
// https://github.com/vercel/next.js/discussions/48110
// https://github.com/vercel/next.js/discussions/49540
- // router.push(newUrl, undefined, { shallow: true });
- history.replaceState({}, "", newUrl);
+ router.push(newUrl, undefined);
+ // history.replaceState({}, "", newUrl);
}
- });
+ }, [webSockets.status.value, transcript.response?.status]);
useEffect(() => {
lockWakeState();
@@ -77,10 +87,7 @@ const TranscriptRecord = (details: TranscriptDetails) => {
setHasRecorded(true);
webRTC?.send(JSON.stringify({ cmd: "STOP" }));
}}
- topics={webSockets.topics}
getAudioStream={getAudioStream}
- useActiveTopic={useActiveTopic}
- isPastMeeting={false}
audioDevices={audioDevices}
/>
@@ -128,6 +135,7 @@ const TranscriptRecord = (details: TranscriptDetails) => {
couple of minutes. Please do not navigate away from the page
during this time.
+ {/* TODO If login required remove last sentence */}
)}
diff --git a/www/app/[domain]/transcripts/useTranscript.ts b/www/app/[domain]/transcripts/useTranscript.ts
index af60cd3b..987e57f3 100644
--- a/www/app/[domain]/transcripts/useTranscript.ts
+++ b/www/app/[domain]/transcripts/useTranscript.ts
@@ -5,16 +5,28 @@ import { useError } from "../../(errors)/errorContext";
import getApi from "../../lib/getApi";
import { shouldShowError } from "../../lib/errorUtils";
-type Transcript = {
- response: GetTranscript | null;
- loading: boolean;
- error: Error | null;
+type ErrorTranscript = {
+ error: Error;
+ loading: false;
+ response: any;
+};
+
+type LoadingTranscript = {
+ response: any;
+ loading: true;
+ error: false;
+};
+
+type SuccessTranscript = {
+ response: GetTranscript;
+ loading: false;
+ error: null;
};
const useTranscript = (
protectedPath: boolean,
id: string | null,
-): Transcript => {
+): ErrorTranscript | LoadingTranscript | SuccessTranscript => {
const [response, setResponse] = useState
(null);
const [loading, setLoading] = useState(true);
const [error, setErrorState] = useState(null);
@@ -46,7 +58,10 @@ const useTranscript = (
});
}, [id, !api]);
- return { response, loading, error };
+ return { response, loading, error } as
+ | ErrorTranscript
+ | LoadingTranscript
+ | SuccessTranscript;
};
export default useTranscript;
diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts
index bcf6b163..cb74f86d 100644
--- a/www/app/[domain]/transcripts/useWebSockets.ts
+++ b/www/app/[domain]/transcripts/useWebSockets.ts
@@ -4,9 +4,10 @@ import { useError } from "../../(errors)/errorContext";
import { useRouter } from "next/navigation";
import { DomainContext } from "../domainContext";
-type UseWebSockets = {
+export type UseWebSockets = {
transcriptText: string;
translateText: string;
+ title: string;
topics: Topic[];
finalSummary: FinalSummary;
status: Status;
@@ -15,6 +16,7 @@ type UseWebSockets = {
export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
const [transcriptText, setTranscriptText] = useState("");
const [translateText, setTranslateText] = useState("");
+ const [title, setTitle] = useState("");
const [textQueue, setTextQueue] = useState([]);
const [translationQueue, setTranslationQueue] = useState([]);
const [isProcessing, setIsProcessing] = useState(false);
@@ -24,7 +26,6 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
});
const [status, setStatus] = useState({ value: "initial" });
const { setError } = useError();
- const router = useRouter();
const { websocket_url } = useContext(DomainContext);
@@ -294,7 +295,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
if (!transcriptId) return;
const url = `${websocket_url}/v1/transcripts/${transcriptId}/events`;
- const ws = new WebSocket(url);
+ let ws = new WebSocket(url);
ws.onopen = () => {
console.debug("WebSocket connection opened");
@@ -343,24 +344,23 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
case "FINAL_TITLE":
console.debug("FINAL_TITLE event:", message.data);
+ if (message.data) {
+ setTitle(message.data.title);
+ }
break;
case "STATUS":
console.log("STATUS event:", message.data);
- if (message.data.value === "ended") {
- const newUrl = "/transcripts/" + transcriptId;
- router.push(newUrl);
- console.debug("FINAL_LONG_SUMMARY event:", message.data);
- }
if (message.data.value === "error") {
- const newUrl = "/transcripts/" + transcriptId;
- router.push(newUrl);
setError(
Error("Websocket error status"),
"There was an error processing this meeting.",
);
}
setStatus(message.data);
+ if (message.data.value === "ended") {
+ ws.close();
+ }
break;
default:
@@ -388,8 +388,16 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
default:
setError(
new Error(`WebSocket closed unexpectedly with code: ${event.code}`),
+ "Disconnected",
);
}
+ console.log(
+ "Socket is closed. Reconnect will be attempted in 1 second.",
+ event.reason,
+ );
+ setTimeout(function () {
+ ws = new WebSocket(url);
+ }, 1000);
};
return () => {
@@ -397,5 +405,5 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
};
}, [transcriptId]);
- return { transcriptText, translateText, topics, finalSummary, status };
+ return { transcriptText, translateText, topics, finalSummary, title, status };
};
From e98f1bf4bc7bed06fddecd22537ea819761b83b7 Mon Sep 17 00:00:00 2001
From: Sara
Date: Mon, 13 Nov 2023 18:33:24 +0100
Subject: [PATCH 05/27] loading and redirecting front-end
---
.../transcripts/[transcriptId]/page.tsx | 18 +++++--
.../[transcriptId]/record/page.tsx | 50 +++++++++++++------
www/app/[domain]/transcripts/finalSummary.tsx | 2 +-
www/app/[domain]/transcripts/recorder.tsx | 13 +++--
www/app/[domain]/transcripts/useWebSockets.ts | 14 +++++-
.../[domain]/transcripts/waveformLoading.tsx | 11 ++++
6 files changed, 77 insertions(+), 31 deletions(-)
create mode 100644 www/app/[domain]/transcripts/waveformLoading.tsx
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
index add5a824..079b41e8 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
@@ -13,6 +13,7 @@ import ShareLink from "../shareLink";
import QRCode from "react-qr-code";
import TranscriptTitle from "../transcriptTitle";
import Player from "../player";
+import WaveformLoading from "../waveformLoading";
type TranscriptDetails = {
params: {
@@ -83,9 +84,9 @@ export default function TranscriptDetails(details: TranscriptDetails) {
mediaDuration={transcript.response.duration}
/>
) : mp3.error || waveform.error ? (
- "error loading this recording"
+ "error loading this recording"
) : (
- "Loading Recording"
+
)}
@@ -104,10 +105,17 @@ export default function TranscriptDetails(details: TranscriptDetails) {
summary={transcript.response.longSummary}
transcriptId={transcript.response.id}
/>
- ) : transcript.response.status == "processing" ? (
- "Loading Transcript"
) : (
- "error final summary"
+
+ {transcript.response.status == "processing" ? (
+
Loading Transcript
+ ) : (
+
+ There was an error generating the final summary, please
+ come back later
+
+ )}
+
)}
diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
index 10297f5c..36f4bbe9 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
@@ -11,9 +11,12 @@ import { Topic } from "../../webSocketTypes";
import LiveTrancription from "../../liveTranscription";
import DisconnectedIndicator from "../../disconnectedIndicator";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
-import { faGear } from "@fortawesome/free-solid-svg-icons";
+import { faGear, faSpinner } from "@fortawesome/free-solid-svg-icons";
import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock";
import { useRouter } from "next/navigation";
+import Player from "../../player";
+import useMp3 from "../../useMp3";
+import WaveformLoading from "../../waveformLoading";
type TranscriptDetails = {
params: {
@@ -42,8 +45,10 @@ const TranscriptRecord = (details: TranscriptDetails) => {
const { audioDevices, getAudioStream } = useAudioDevice();
- const [hasRecorded, setHasRecorded] = useState(false);
+ const [recordedTime, setRecordedTime] = useState(0);
+ const [startTime, setStartTime] = useState(0);
const [transcriptStarted, setTranscriptStarted] = useState(false);
+ const mp3 = useMp3(true, "");
const router = useRouter();
@@ -54,8 +59,6 @@ const TranscriptRecord = (details: TranscriptDetails) => {
useEffect(() => {
const statusToRedirect = ["ended", "error"];
- console.log(webSockets.status, "hey");
- console.log(transcript.response, "ho");
//TODO if has no topic and is error, get back to new
if (
@@ -80,16 +83,31 @@ const TranscriptRecord = (details: TranscriptDetails) => {
return (
<>
-
{
- setStream(null);
- setHasRecorded(true);
- webRTC?.send(JSON.stringify({ cmd: "STOP" }));
- }}
- getAudioStream={getAudioStream}
- audioDevices={audioDevices}
- />
+ {webSockets.waveform && mp3.media ? (
+
+ ) : recordedTime ? (
+
+ ) : (
+ {
+ setStream(null);
+ setRecordedTime(Date.now() - startTime);
+ webRTC?.send(JSON.stringify({ cmd: "STOP" }));
+ }}
+ onRecord={() => {
+ setStartTime(Date.now());
+ }}
+ getAudioStream={getAudioStream}
+ audioDevices={audioDevices}
+ />
+ )}
{
- {!hasRecorded ? (
+ {!recordedTime ? (
<>
{transcriptStarted && (
Transcription
@@ -135,7 +153,7 @@ const TranscriptRecord = (details: TranscriptDetails) => {
couple of minutes. Please do not navigate away from the page
during this time.
- {/* TODO If login required remove last sentence */}
+ {/* NTH If login required remove last sentence */}
)}
diff --git a/www/app/[domain]/transcripts/finalSummary.tsx b/www/app/[domain]/transcripts/finalSummary.tsx
index 463f6100..e0d0f1c9 100644
--- a/www/app/[domain]/transcripts/finalSummary.tsx
+++ b/www/app/[domain]/transcripts/finalSummary.tsx
@@ -87,7 +87,7 @@ export default function FinalSummary(props: FinalSummaryProps) {
diff --git a/www/app/[domain]/transcripts/recorder.tsx b/www/app/[domain]/transcripts/recorder.tsx
index 5b4420c4..e7c016a7 100644
--- a/www/app/[domain]/transcripts/recorder.tsx
+++ b/www/app/[domain]/transcripts/recorder.tsx
@@ -14,11 +14,11 @@ import { waveSurferStyles } from "../../styles/recorder";
import { useError } from "../../(errors)/errorContext";
type RecorderProps = {
- setStream?: React.Dispatch
>;
- onStop?: () => void;
- getAudioStream?: (deviceId) => Promise;
- audioDevices?: Option[];
- mediaDuration?: number | null;
+ setStream: React.Dispatch>;
+ onStop: () => void;
+ onRecord?: () => void;
+ getAudioStream: (deviceId) => Promise;
+ audioDevices: Option[];
};
export default function Recorder(props: RecorderProps) {
@@ -94,7 +94,6 @@ export default function Recorder(props: RecorderProps) {
autoCenter: true,
barWidth: 2,
height: "auto",
- duration: props.mediaDuration || 1,
...waveSurferStyles.player,
});
@@ -161,10 +160,10 @@ export default function Recorder(props: RecorderProps) {
setScreenMediaStream(null);
setDestinationStream(null);
} else {
+ if (props.onRecord) props.onRecord();
const stream = await getCurrentStream();
if (props.setStream) props.setStream(stream);
- waveRegions?.clearRegions();
if (stream) {
await record.startRecording(stream);
setIsRecording(true);
diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts
index cb74f86d..3f3d20fc 100644
--- a/www/app/[domain]/transcripts/useWebSockets.ts
+++ b/www/app/[domain]/transcripts/useWebSockets.ts
@@ -1,8 +1,8 @@
import { useContext, useEffect, useState } from "react";
import { Topic, FinalSummary, Status } from "./webSocketTypes";
import { useError } from "../../(errors)/errorContext";
-import { useRouter } from "next/navigation";
import { DomainContext } from "../domainContext";
+import { AudioWaveform } from "../../api";
export type UseWebSockets = {
transcriptText: string;
@@ -11,6 +11,7 @@ export type UseWebSockets = {
topics: Topic[];
finalSummary: FinalSummary;
status: Status;
+ waveform: AudioWaveform | null;
};
export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
@@ -21,6 +22,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
const [translationQueue, setTranslationQueue] = useState([]);
const [isProcessing, setIsProcessing] = useState(false);
const [topics, setTopics] = useState([]);
+ const [waveform, setWaveForm] = useState(null);
const [finalSummary, setFinalSummary] = useState({
summary: "",
});
@@ -405,5 +407,13 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
};
}, [transcriptId]);
- return { transcriptText, translateText, topics, finalSummary, title, status };
+ return {
+ transcriptText,
+ translateText,
+ topics,
+ finalSummary,
+ title,
+ status,
+ waveform,
+ };
};
diff --git a/www/app/[domain]/transcripts/waveformLoading.tsx b/www/app/[domain]/transcripts/waveformLoading.tsx
new file mode 100644
index 00000000..68e0c80f
--- /dev/null
+++ b/www/app/[domain]/transcripts/waveformLoading.tsx
@@ -0,0 +1,11 @@
+import { faSpinner } from "@fortawesome/free-solid-svg-icons";
+import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
+
+export default () => (
+
+
+
+);
From 1fc261a66999df014820f07836d6300977977fe4 Mon Sep 17 00:00:00 2001
From: Sara
Date: Wed, 15 Nov 2023 20:30:00 +0100
Subject: [PATCH 06/27] try to move waveform to pipeline
---
server/reflector/db/transcripts.py | 25 +++++---------
.../reflector/pipelines/main_live_pipeline.py | 26 +++++++++++++--
.../processors/audio_waveform_processor.py | 33 +++++++++++++++++++
server/reflector/views/transcripts.py | 3 +-
server/tests/test_transcripts_rtc_ws.py | 4 +++
www/app/lib/edgeConfig.ts | 4 +--
6 files changed, 72 insertions(+), 23 deletions(-)
create mode 100644 server/reflector/processors/audio_waveform_processor.py
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 6ac2e32a..f0dbc277 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -10,7 +10,6 @@ from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
-from reflector.utils.audio_waveform import get_audio_waveform
transcripts = sqlalchemy.Table(
"transcript",
@@ -79,6 +78,14 @@ class TranscriptFinalTitle(BaseModel):
title: str
+class TranscriptDuration(BaseModel):
+ duration: float
+
+
+class TranscriptWaveform(BaseModel):
+ waveform: list[float]
+
+
class TranscriptEvent(BaseModel):
event: str
data: dict
@@ -118,22 +125,6 @@ class Transcript(BaseModel):
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
- def convert_audio_to_waveform(self, segments_count=256):
- fn = self.audio_waveform_filename
- if fn.exists():
- return
- waveform = get_audio_waveform(
- path=self.audio_mp3_filename, segments_count=segments_count
- )
- try:
- with open(fn, "w") as fd:
- json.dump(waveform, fd)
- except Exception:
- # remove file if anything happen during the write
- fn.unlink(missing_ok=True)
- raise
- return waveform
-
def unlink(self):
self.data_path.unlink(missing_ok=True)
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 8c78c48f..b0576b92 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -21,11 +21,13 @@ from pydantic import BaseModel
from reflector.app import app
from reflector.db.transcripts import (
Transcript,
+ TranscriptDuration,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
TranscriptFinalTitle,
TranscriptText,
TranscriptTopic,
+ TranscriptWaveform,
transcripts_controller,
)
from reflector.logger import logger
@@ -45,6 +47,7 @@ from reflector.processors import (
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
+from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import (
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
@@ -230,15 +233,29 @@ class PipelineMainBase(PipelineRunner):
data=final_short_summary,
)
- async def on_duration(self, duration: float):
+ async def on_duration(self, data):
async with self.transaction():
+ duration = TranscriptDuration(duration=data)
+
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
- "duration": duration,
+ "duration": duration.duration,
},
)
+ return await transcripts_controller.append_event(
+ transcript=transcript, event="DURATION", data=duration
+ )
+
+ async def on_waveform(self, data):
+ waveform = TranscriptWaveform(waveform=data)
+
+ transcript = await self.get_transcript()
+
+ return await transcripts_controller.append_event(
+ transcript=transcript, event="WAVEFORM", data=waveform
+ )
class PipelineMainLive(PipelineMainBase):
@@ -266,6 +283,11 @@ class PipelineMainLive(PipelineMainBase):
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
+ AudioWaveformProcessor(
+ audio_path=transcript.audio_mp3_filename,
+ waveform_path=transcript.audio_waveform_filename,
+ on_waveform=self.on_waveform,
+ ),
]
),
]
diff --git a/server/reflector/processors/audio_waveform_processor.py b/server/reflector/processors/audio_waveform_processor.py
new file mode 100644
index 00000000..acce904a
--- /dev/null
+++ b/server/reflector/processors/audio_waveform_processor.py
@@ -0,0 +1,33 @@
+import json
+from pathlib import Path
+
+from reflector.processors.base import Processor
+from reflector.processors.types import TitleSummary
+from reflector.utils.audio_waveform import get_audio_waveform
+
+
+class AudioWaveformProcessor(Processor):
+ """
+ Write the waveform for the final audio
+ """
+
+ INPUT_TYPE = TitleSummary
+
+ def __init__(self, audio_path: Path | str, waveform_path: str, **kwargs):
+ super().__init__(**kwargs)
+ if isinstance(audio_path, str):
+ audio_path = Path(audio_path)
+ if audio_path.suffix not in (".mp3", ".wav"):
+ raise ValueError("Only mp3 and wav files are supported")
+ self.audio_path = audio_path
+ self.waveform_path = waveform_path
+
+ async def _push(self, _data):
+ self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
+ self.logger.info("Waveform Processing Started")
+ waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
+
+ with open(self.waveform_path, "w") as fd:
+ json.dump(waveform, fd)
+ self.logger.info("Waveform Processing Finished")
+ await self.emit(waveform, name="waveform")
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 5de9ced3..77e4b0b7 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -22,7 +22,6 @@ from reflector.db.transcripts import (
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings
from reflector.ws_manager import get_ws_manager
-from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import RtcOffer, rtc_offer_base
@@ -261,7 +260,7 @@ async def transcript_get_audio_waveform(
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
- await run_in_threadpool(transcript.convert_audio_to_waveform)
+ # await run_in_threadpool(transcript.convert_audio_to_waveform)
return transcript.audio_waveform
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index cf2ea304..65660a5e 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -182,6 +182,10 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
+ assert "WAVEFORM" in eventnames
+ ev = events[eventnames.index("FINAL_TITLE")]
+ assert ev["data"]["title"] == "LLM TITLE"
+
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses.index("recording") < statuses.index("processing")
diff --git a/www/app/lib/edgeConfig.ts b/www/app/lib/edgeConfig.ts
index 5527121a..1140e555 100644
--- a/www/app/lib/edgeConfig.ts
+++ b/www/app/lib/edgeConfig.ts
@@ -3,9 +3,9 @@ import { isDevelopment } from "./utils";
const localConfig = {
features: {
- requireLogin: true,
+ requireLogin: false,
privacy: true,
- browse: true,
+ browse: false,
},
api_url: "http://127.0.0.1:1250",
websocket_url: "ws://127.0.0.1:1250",
From 8b8e92ceac7a121bd16909ea2d8401a889ae44cb Mon Sep 17 00:00:00 2001
From: Sara
Date: Wed, 15 Nov 2023 20:34:45 +0100
Subject: [PATCH 07/27] replace history instead of pushing
---
www/app/[domain]/transcripts/[transcriptId]/record/page.tsx | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
index 36f4bbe9..147f8827 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
@@ -69,9 +69,9 @@ const TranscriptRecord = (details: TranscriptDetails) => {
// Shallow redirection does not work on NextJS 13
// https://github.com/vercel/next.js/discussions/48110
// https://github.com/vercel/next.js/discussions/49540
- router.push(newUrl, undefined);
+ router.replace(newUrl);
// history.replaceState({}, "", newUrl);
- }
+ } // history.replaceState({}, "", newUrl);
}, [webSockets.status.value, transcript.response?.status]);
useEffect(() => {
From a846e38fbdeb77aad23282a8068b08e4cd6b3921 Mon Sep 17 00:00:00 2001
From: Sara
Date: Fri, 17 Nov 2023 13:38:32 +0100
Subject: [PATCH 08/27] fix waveform in pipeline
---
.../reflector/pipelines/main_live_pipeline.py | 15 +++++++++------
.../processors/audio_waveform_processor.py | 5 ++++-
server/tests/test_transcripts_audio_download.py | 12 ------------
server/tests/test_transcripts_rtc_ws.py | 17 ++++++++++++-----
4 files changed, 25 insertions(+), 24 deletions(-)
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index b0576b92..fece6da5 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -233,6 +233,7 @@ class PipelineMainBase(PipelineRunner):
data=final_short_summary,
)
+ @broadcast_to_sockets
async def on_duration(self, data):
async with self.transaction():
duration = TranscriptDuration(duration=data)
@@ -248,14 +249,16 @@ class PipelineMainBase(PipelineRunner):
transcript=transcript, event="DURATION", data=duration
)
+ @broadcast_to_sockets
async def on_waveform(self, data):
- waveform = TranscriptWaveform(waveform=data)
+ async with self.transaction():
+ waveform = TranscriptWaveform(waveform=data)
- transcript = await self.get_transcript()
+ transcript = await self.get_transcript()
- return await transcripts_controller.append_event(
- transcript=transcript, event="WAVEFORM", data=waveform
- )
+ return await transcripts_controller.append_event(
+ transcript=transcript, event="WAVEFORM", data=waveform
+ )
class PipelineMainLive(PipelineMainBase):
@@ -283,7 +286,7 @@ class PipelineMainLive(PipelineMainBase):
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
- AudioWaveformProcessor(
+ AudioWaveformProcessor.as_threaded(
audio_path=transcript.audio_mp3_filename,
waveform_path=transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
diff --git a/server/reflector/processors/audio_waveform_processor.py b/server/reflector/processors/audio_waveform_processor.py
index acce904a..f1a24ffd 100644
--- a/server/reflector/processors/audio_waveform_processor.py
+++ b/server/reflector/processors/audio_waveform_processor.py
@@ -22,7 +22,7 @@ class AudioWaveformProcessor(Processor):
self.audio_path = audio_path
self.waveform_path = waveform_path
- async def _push(self, _data):
+ async def _flush(self):
self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info("Waveform Processing Started")
waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
@@ -31,3 +31,6 @@ class AudioWaveformProcessor(Processor):
json.dump(waveform, fd)
self.logger.info("Waveform Processing Finished")
await self.emit(waveform, name="waveform")
+
+ async def _push(_self, _data):
+ return
diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py
index 69ae5f65..28f83fff 100644
--- a/server/tests/test_transcripts_audio_download.py
+++ b/server/tests/test_transcripts_audio_download.py
@@ -118,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek(
assert response.status_code == 206
assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 100-")
-
-
-@pytest.mark.asyncio
-async def test_transcript_audio_download_waveform(fake_transcript):
- from reflector.app import app
-
- ac = AsyncClient(app=app, base_url="http://test/v1")
- response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
- assert response.status_code == 200
- assert response.headers["content-type"] == "application/json"
- assert isinstance(response.json()["data"], list)
- assert len(response.json()["data"]) >= 255
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 65660a5e..b33b1db5 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -183,8 +183,14 @@ async def test_transcript_rtc_and_websocket(
assert ev["data"]["title"] == "LLM TITLE"
assert "WAVEFORM" in eventnames
- ev = events[eventnames.index("FINAL_TITLE")]
- assert ev["data"]["title"] == "LLM TITLE"
+ ev = events[eventnames.index("WAVEFORM")]
+ assert isinstance(ev["data"]["waveform"], list)
+ assert len(ev["data"]["waveform"]) >= 250
+ waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
+ assert waveform_resp.status_code == 200
+ assert waveform_resp.headers["content-type"] == "application/json"
+ assert isinstance(waveform_resp.json()["data"], list)
+ assert len(waveform_resp.json()["data"]) >= 250
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
@@ -197,11 +203,12 @@ async def test_transcript_rtc_and_websocket(
# check on the latest response that the audio duration is > 0
assert resp.json()["duration"] > 0
+ assert "DURATION" in eventnames
# check that audio/mp3 is available
- resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
- assert resp.status_code == 200
- assert resp.headers["Content-Type"] == "audio/mpeg"
+ audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
+ assert audio_resp.status_code == 200
+ assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@pytest.mark.usefixtures("celery_session_app")
From a816e00769faba40098acd36085469a23089f614 Mon Sep 17 00:00:00 2001
From: Sara
Date: Fri, 17 Nov 2023 15:16:53 +0100
Subject: [PATCH 09/27] get waveform from socket
---
.../transcripts/[transcriptId]/page.tsx | 7 ++---
.../[transcriptId]/record/page.tsx | 16 ++++++----
www/app/[domain]/transcripts/player.tsx | 5 ++--
www/app/[domain]/transcripts/useMp3.ts | 29 ++++++++-----------
www/app/[domain]/transcripts/useWebSockets.ts | 21 +++++++++++++-
5 files changed, 48 insertions(+), 30 deletions(-)
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
index 079b41e8..b0986e94 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
@@ -30,7 +30,7 @@ export default function TranscriptDetails(details: TranscriptDetails) {
const topics = useTopics(protectedPath, transcriptId);
const waveform = useWaveform(protectedPath, transcriptId);
const useActiveTopic = useState(null);
- const mp3 = useMp3(protectedPath, transcriptId);
+ const mp3 = useMp3(transcriptId);
if (transcript?.error || topics?.error) {
return (
@@ -59,7 +59,6 @@ export default function TranscriptDetails(details: TranscriptDetails) {
.join("\n\n")
.replace(/ +/g, " ")
.trim() || "";
- console.log("calf full transcript");
return (
<>
@@ -79,11 +78,11 @@ export default function TranscriptDetails(details: TranscriptDetails) {
- ) : mp3.error || waveform.error ? (
+ ) : waveform.error ? (
"error loading this recording"
) : (
diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
index 147f8827..2c5b73e0 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
@@ -11,11 +11,11 @@ import { Topic } from "../../webSocketTypes";
import LiveTrancription from "../../liveTranscription";
import DisconnectedIndicator from "../../disconnectedIndicator";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
-import { faGear, faSpinner } from "@fortawesome/free-solid-svg-icons";
+import { faGear } from "@fortawesome/free-solid-svg-icons";
import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock";
import { useRouter } from "next/navigation";
import Player from "../../player";
-import useMp3 from "../../useMp3";
+import useMp3, { Mp3Response } from "../../useMp3";
import WaveformLoading from "../../waveformLoading";
type TranscriptDetails = {
@@ -48,7 +48,7 @@ const TranscriptRecord = (details: TranscriptDetails) => {
const [recordedTime, setRecordedTime] = useState(0);
const [startTime, setStartTime] = useState(0);
const [transcriptStarted, setTranscriptStarted] = useState(false);
- const mp3 = useMp3(true, "");
+ let mp3 = useMp3(details.params.transcriptId, true);
const router = useRouter();
@@ -74,6 +74,12 @@ const TranscriptRecord = (details: TranscriptDetails) => {
} // history.replaceState({}, "", newUrl);
}, [webSockets.status.value, transcript.response?.status]);
+ useEffect(() => {
+ if (webSockets.duration) {
+ mp3.getNow();
+ }
+ }, [webSockets.duration]);
+
useEffect(() => {
lockWakeState();
return () => {
@@ -83,13 +89,13 @@ const TranscriptRecord = (details: TranscriptDetails) => {
return (
<>
- {webSockets.waveform && mp3.media ? (
+ {webSockets.waveform && webSockets.duration && mp3?.media ? (
) : recordedTime ? (
diff --git a/www/app/[domain]/transcripts/player.tsx b/www/app/[domain]/transcripts/player.tsx
index 6143836e..02151a68 100644
--- a/www/app/[domain]/transcripts/player.tsx
+++ b/www/app/[domain]/transcripts/player.tsx
@@ -14,7 +14,7 @@ type PlayerProps = {
Topic | null,
React.Dispatch>,
];
- waveform: AudioWaveform;
+ waveform: AudioWaveform["data"];
media: HTMLMediaElement;
mediaDuration: number;
};
@@ -29,7 +29,6 @@ export default function Player(props: PlayerProps) {
);
const [activeTopic, setActiveTopic] = props.useActiveTopic;
const topicsRef = useRef(props.topics);
-
// Waveform setup
useEffect(() => {
if (waveformRef.current) {
@@ -40,7 +39,7 @@ export default function Player(props: PlayerProps) {
// This is not ideal, but it works for now.
const _wavesurfer = WaveSurfer.create({
container: waveformRef.current,
- peaks: props.waveform.data,
+ peaks: props.waveform,
hideScrollbar: true,
autoCenter: true,
barWidth: 2,
diff --git a/www/app/[domain]/transcripts/useMp3.ts b/www/app/[domain]/transcripts/useMp3.ts
index 570a6a25..23249f94 100644
--- a/www/app/[domain]/transcripts/useMp3.ts
+++ b/www/app/[domain]/transcripts/useMp3.ts
@@ -1,24 +1,19 @@
import { useContext, useEffect, useState } from "react";
-import { useError } from "../../(errors)/errorContext";
import { DomainContext } from "../domainContext";
import getApi from "../../lib/getApi";
import { useFiefAccessTokenInfo } from "@fief/fief/build/esm/nextjs/react";
-import { shouldShowError } from "../../lib/errorUtils";
-type Mp3Response = {
- url: string | null;
+export type Mp3Response = {
media: HTMLMediaElement | null;
loading: boolean;
- error: Error | null;
+ getNow: () => void;
};
-const useMp3 = (protectedPath: boolean, id: string): Mp3Response => {
- const [url, setUrl] = useState(null);
+const useMp3 = (id: string, waiting?: boolean): Mp3Response => {
const [media, setMedia] = useState(null);
+ const [later, setLater] = useState(waiting);
const [loading, setLoading] = useState(false);
- const [error, setErrorState] = useState(null);
- const { setError } = useError();
- const api = getApi(protectedPath);
+ const api = getApi(true);
const { api_url } = useContext(DomainContext);
const accessTokenInfo = useFiefAccessTokenInfo();
const [serviceWorkerReady, setServiceWorkerReady] = useState(false);
@@ -42,8 +37,8 @@ const useMp3 = (protectedPath: boolean, id: string): Mp3Response => {
});
}, [navigator.serviceWorker, serviceWorkerReady, accessTokenInfo]);
- const getMp3 = (id: string) => {
- if (!id || !api) return;
+ useEffect(() => {
+ if (!id || !api || later) return;
// createa a audio element and set the source
setLoading(true);
@@ -53,13 +48,13 @@ const useMp3 = (protectedPath: boolean, id: string): Mp3Response => {
audioElement.preload = "auto";
setMedia(audioElement);
setLoading(false);
+ }, [id, api, later]);
+
+ const getNow = () => {
+ setLater(false);
};
- useEffect(() => {
- getMp3(id);
- }, [id, api]);
-
- return { url, media, loading, error };
+ return { media, loading, getNow };
};
export default useMp3;
diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts
index 3f3d20fc..f33a1347 100644
--- a/www/app/[domain]/transcripts/useWebSockets.ts
+++ b/www/app/[domain]/transcripts/useWebSockets.ts
@@ -11,7 +11,8 @@ export type UseWebSockets = {
topics: Topic[];
finalSummary: FinalSummary;
status: Status;
- waveform: AudioWaveform | null;
+ waveform: AudioWaveform["data"] | null;
+ duration: number | null;
};
export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
@@ -23,6 +24,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
const [isProcessing, setIsProcessing] = useState(false);
const [topics, setTopics] = useState([]);
const [waveform, setWaveForm] = useState(null);
+ const [duration, setDuration] = useState(null);
const [finalSummary, setFinalSummary] = useState({
summary: "",
});
@@ -351,6 +353,22 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
}
break;
+ case "WAVEFORM":
+ console.debug(
+ "WAVEFORM event length:",
+ message.data.waveform.length,
+ );
+ if (message.data) {
+ setWaveForm(message.data.waveform);
+ }
+ break;
+ case "DURATION":
+ console.debug("DURATION event:", message.data);
+ if (message.data) {
+ setDuration(message.data.duration);
+ }
+ break;
+
case "STATUS":
console.log("STATUS event:", message.data);
if (message.data.value === "error") {
@@ -415,5 +433,6 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
title,
status,
waveform,
+ duration,
};
};
From c40e0970b7d9dfe4d3df2cb159b7ffbecfece077 Mon Sep 17 00:00:00 2001
From: Sara
Date: Mon, 20 Nov 2023 20:13:18 +0100
Subject: [PATCH 10/27] review fixes
---
server/reflector/views/transcripts.py | 2 --
www/app/[domain]/transcripts/useWebSockets.ts | 17 +++++++----------
2 files changed, 7 insertions(+), 12 deletions(-)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 77e4b0b7..6909b8ae 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -260,8 +260,6 @@ async def transcript_get_audio_waveform(
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
- # await run_in_threadpool(transcript.convert_audio_to_waveform)
-
return transcript.audio_waveform
diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts
index f33a1347..f289adbb 100644
--- a/www/app/[domain]/transcripts/useWebSockets.ts
+++ b/www/app/[domain]/transcripts/useWebSockets.ts
@@ -402,22 +402,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
console.debug("WebSocket connection closed");
switch (event.code) {
case 1000: // Normal Closure:
- case 1001: // Going Away:
- case 1005:
- break;
default:
setError(
new Error(`WebSocket closed unexpectedly with code: ${event.code}`),
"Disconnected",
);
+ console.log(
+ "Socket is closed. Reconnect will be attempted in 1 second.",
+ event.reason,
+ );
+ setTimeout(function () {
+ ws = new WebSocket(url);
+ }, 1000);
}
- console.log(
- "Socket is closed. Reconnect will be attempted in 1 second.",
- event.reason,
- );
- setTimeout(function () {
- ws = new WebSocket(url);
- }, 1000);
};
return () => {
From f38dad3ad400830f5c35c78990ecce4142ad78ab Mon Sep 17 00:00:00 2001
From: Sara
Date: Wed, 22 Nov 2023 12:25:21 +0100
Subject: [PATCH 11/27] small fixes and start auth fix
---
www/app/[domain]/layout.tsx | 143 ++++++++--------
.../transcripts/[transcriptId]/page.tsx | 162 +++++++++---------
www/app/[domain]/transcripts/useMp3.ts | 14 +-
www/app/[domain]/transcripts/useWaveform.ts | 5 +-
www/app/[domain]/transcripts/useWebSockets.ts | 1 +
www/app/lib/edgeConfig.ts | 4 +-
www/app/lib/errorUtils.ts | 5 +-
www/app/lib/fief.ts | 2 +-
8 files changed, 172 insertions(+), 164 deletions(-)
diff --git a/www/app/[domain]/layout.tsx b/www/app/[domain]/layout.tsx
index dbe5ed11..3e881ac3 100644
--- a/www/app/[domain]/layout.tsx
+++ b/www/app/[domain]/layout.tsx
@@ -11,6 +11,7 @@ import About from "../(aboutAndPrivacy)/about";
import Privacy from "../(aboutAndPrivacy)/privacy";
import { DomainContextProvider } from "./domainContext";
import { getConfig } from "../lib/edgeConfig";
+import { ErrorBoundary } from "@sentry/nextjs";
const poppins = Poppins({ subsets: ["latin"], weight: ["200", "400", "600"] });
@@ -76,80 +77,82 @@ export default async function RootLayout({ children, params }: LayoutProps) {
-
-
-
+
+
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
index 734f2609..23634b95 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
@@ -14,6 +14,7 @@ import QRCode from "react-qr-code";
import TranscriptTitle from "../transcriptTitle";
import Player from "../player";
import WaveformLoading from "../waveformLoading";
+import { useRouter } from "next/navigation";
type TranscriptDetails = {
params: {
@@ -21,10 +22,11 @@ type TranscriptDetails = {
};
};
-const protectedPath = true;
+const protectedPath = false;
export default function TranscriptDetails(details: TranscriptDetails) {
const transcriptId = details.params.transcriptId;
+ const router = useRouter();
const transcript = useTranscript(protectedPath, transcriptId);
const topics = useTopics(protectedPath, transcriptId);
@@ -32,15 +34,6 @@ export default function TranscriptDetails(details: TranscriptDetails) {
const useActiveTopic = useState(null);
const mp3 = useMp3(transcriptId);
- if (transcript?.error || topics?.error) {
- return (
-
- );
- }
-
useEffect(() => {
const statusToRedirect = ["idle", "recording", "processing"];
if (statusToRedirect.includes(transcript.response?.status)) {
@@ -48,8 +41,8 @@ export default function TranscriptDetails(details: TranscriptDetails) {
// Shallow redirection does not work on NextJS 13
// https://github.com/vercel/next.js/discussions/48110
// https://github.com/vercel/next.js/discussions/49540
- // router.push(newUrl, undefined, { shallow: true });
- history.replaceState({}, "", newUrl);
+ router.push(newUrl, undefined);
+ // history.replaceState({}, "", newUrl);
}
}, [transcript.response?.status]);
@@ -60,85 +53,92 @@ export default function TranscriptDetails(details: TranscriptDetails) {
.replace(/ +/g, " ")
.trim() || "";
+ if (transcript.error || topics?.error) {
+ return (
+
+ );
+ }
+
+ if (transcript?.loading || topics?.loading) {
+ return ;
+ }
+
return (
<>
- {transcript?.loading || topics?.loading ? (
-
- ) : (
- <>
-
- {transcript?.response?.title && (
-
+ {transcript?.response?.title && (
+
+ )}
+ {waveform.waveform && mp3.media ? (
+
+ ) : waveform.error ? (
+ "error loading this recording"
+ ) : (
+
+ )}
+
+
+
+
+
+
+ {transcript.response.longSummary ? (
+
- )}
- {waveform.waveform && mp3.media ? (
-
- ) : waveform.error ? (
- "error loading this recording"
) : (
-
- )}
-
-
-
-
-
-
- {transcript.response.longSummary ? (
-
+
+ {transcript.response.status == "processing" ? (
+
Loading Transcript
) : (
-
- {transcript.response.status == "processing" ? (
-
Loading Transcript
- ) : (
-
- There was an error generating the final summary, please
- come back later
-
- )}
-
+
+ There was an error generating the final summary, please come
+ back later
+
)}
-
+
+ )}
+
-
+
- >
- )}
+
+
+
+
+
+
>
);
}
diff --git a/www/app/[domain]/transcripts/useMp3.ts b/www/app/[domain]/transcripts/useMp3.ts
index 23249f94..58e0209d 100644
--- a/www/app/[domain]/transcripts/useMp3.ts
+++ b/www/app/[domain]/transcripts/useMp3.ts
@@ -16,26 +16,30 @@ const useMp3 = (id: string, waiting?: boolean): Mp3Response => {
const api = getApi(true);
const { api_url } = useContext(DomainContext);
const accessTokenInfo = useFiefAccessTokenInfo();
- const [serviceWorkerReady, setServiceWorkerReady] = useState(false);
+ const [serviceWorker, setServiceWorker] =
+ useState(null);
useEffect(() => {
if ("serviceWorker" in navigator) {
- navigator.serviceWorker.register("/service-worker.js").then(() => {
- setServiceWorkerReady(true);
+ navigator.serviceWorker.register("/service-worker.js").then((worker) => {
+ setServiceWorker(worker);
});
}
+ return () => {
+ serviceWorker?.unregister();
+ };
}, []);
useEffect(() => {
if (!navigator.serviceWorker) return;
if (!navigator.serviceWorker.controller) return;
- if (!serviceWorkerReady) return;
+ if (!serviceWorker) return;
// Send the token to the service worker
navigator.serviceWorker.controller.postMessage({
type: "SET_AUTH_TOKEN",
token: accessTokenInfo?.access_token,
});
- }, [navigator.serviceWorker, serviceWorkerReady, accessTokenInfo]);
+ }, [navigator.serviceWorker, !serviceWorker, accessTokenInfo]);
useEffect(() => {
if (!id || !api || later) return;
diff --git a/www/app/[domain]/transcripts/useWaveform.ts b/www/app/[domain]/transcripts/useWaveform.ts
index 4073b711..d2bd0fd6 100644
--- a/www/app/[domain]/transcripts/useWaveform.ts
+++ b/www/app/[domain]/transcripts/useWaveform.ts
@@ -1,8 +1,5 @@
import { useEffect, useState } from "react";
-import {
- DefaultApi,
- V1TranscriptGetAudioWaveformRequest,
-} from "../../api/apis/DefaultApi";
+import { V1TranscriptGetAudioWaveformRequest } from "../../api/apis/DefaultApi";
import { AudioWaveform } from "../../api";
import { useError } from "../../(errors)/errorContext";
import getApi from "../../lib/getApi";
diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts
index f289adbb..1e59781c 100644
--- a/www/app/[domain]/transcripts/useWebSockets.ts
+++ b/www/app/[domain]/transcripts/useWebSockets.ts
@@ -402,6 +402,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
console.debug("WebSocket connection closed");
switch (event.code) {
case 1000: // Normal Closure:
+ case 1005: // Closure by client FF
default:
setError(
new Error(`WebSocket closed unexpectedly with code: ${event.code}`),
diff --git a/www/app/lib/edgeConfig.ts b/www/app/lib/edgeConfig.ts
index 1140e555..5527121a 100644
--- a/www/app/lib/edgeConfig.ts
+++ b/www/app/lib/edgeConfig.ts
@@ -3,9 +3,9 @@ import { isDevelopment } from "./utils";
const localConfig = {
features: {
- requireLogin: false,
+ requireLogin: true,
privacy: true,
- browse: false,
+ browse: true,
},
api_url: "http://127.0.0.1:1250",
websocket_url: "ws://127.0.0.1:1250",
diff --git a/www/app/lib/errorUtils.ts b/www/app/lib/errorUtils.ts
index 81a39b5d..e9e5300d 100644
--- a/www/app/lib/errorUtils.ts
+++ b/www/app/lib/errorUtils.ts
@@ -1,5 +1,8 @@
function shouldShowError(error: Error | null | undefined) {
- if (error?.name == "ResponseError" && error["response"].status == 404)
+ if (
+ error?.name == "ResponseError" &&
+ (error["response"].status == 404 || error["response"].status == 403)
+ )
return false;
if (error?.name == "FetchError") return false;
return true;
diff --git a/www/app/lib/fief.ts b/www/app/lib/fief.ts
index 02db67f5..176aa847 100644
--- a/www/app/lib/fief.ts
+++ b/www/app/lib/fief.ts
@@ -67,7 +67,7 @@ export const getFiefAuthMiddleware = async (url) => {
parameters: {},
},
{
- matcher: "/transcripts/((?!new).*)",
+ matcher: "/transcripts/((?!new))",
parameters: {},
},
{
From f14e6f5a7f9e137ff9d43cf87e25c3ecd32688d4 Mon Sep 17 00:00:00 2001
From: Sara
Date: Wed, 22 Nov 2023 13:20:11 +0100
Subject: [PATCH 12/27] fix auth
---
www/app/(auth)/fiefWrapper.tsx | 15 +++++++++++----
www/app/[domain]/layout.tsx | 5 ++++-
.../[domain]/transcripts/[transcriptId]/page.tsx | 11 +++--------
.../transcripts/[transcriptId]/record/page.tsx | 6 +++---
www/app/[domain]/transcripts/createTranscript.ts | 2 +-
www/app/[domain]/transcripts/finalSummary.tsx | 3 +--
www/app/[domain]/transcripts/shareLink.tsx | 3 +--
www/app/[domain]/transcripts/transcriptTitle.tsx | 3 +--
www/app/[domain]/transcripts/useMp3.ts | 2 +-
www/app/[domain]/transcripts/useTopics.ts | 4 ++--
www/app/[domain]/transcripts/useTranscript.ts | 3 +--
www/app/[domain]/transcripts/useTranscriptList.ts | 2 +-
www/app/[domain]/transcripts/useWaveform.ts | 4 ++--
www/app/[domain]/transcripts/useWebRTC.ts | 3 +--
www/app/lib/fief.ts | 4 ----
www/app/lib/getApi.ts | 8 +++++---
16 files changed, 38 insertions(+), 40 deletions(-)
diff --git a/www/app/(auth)/fiefWrapper.tsx b/www/app/(auth)/fiefWrapper.tsx
index 187fef7c..bb38f5ee 100644
--- a/www/app/(auth)/fiefWrapper.tsx
+++ b/www/app/(auth)/fiefWrapper.tsx
@@ -1,11 +1,18 @@
"use client";
import { FiefAuthProvider } from "@fief/fief/nextjs/react";
+import { createContext } from "react";
-export default function FiefWrapper({ children }) {
+export const CookieContext = createContext<{ hasAuthCookie: boolean }>({
+ hasAuthCookie: false,
+});
+
+export default function FiefWrapper({ children, hasAuthCookie }) {
return (
-
- {children}
-
+
+
+ {children}
+
+
);
}
diff --git a/www/app/[domain]/layout.tsx b/www/app/[domain]/layout.tsx
index 3e881ac3..73cc4841 100644
--- a/www/app/[domain]/layout.tsx
+++ b/www/app/[domain]/layout.tsx
@@ -12,6 +12,8 @@ import Privacy from "../(aboutAndPrivacy)/privacy";
import { DomainContextProvider } from "./domainContext";
import { getConfig } from "../lib/edgeConfig";
import { ErrorBoundary } from "@sentry/nextjs";
+import { cookies } from "next/dist/client/components/headers";
+import { SESSION_COOKIE_NAME } from "../lib/fief";
const poppins = Poppins({ subsets: ["latin"], weight: ["200", "400", "600"] });
@@ -71,11 +73,12 @@ type LayoutProps = {
export default async function RootLayout({ children, params }: LayoutProps) {
const config = await getConfig(params.domain);
const { requireLogin, privacy, browse } = config.features;
+ const hasAuthCookie = !!cookies().get(SESSION_COOKIE_NAME);
return (
-
+
"something went really wrong"}>
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
index 23634b95..472c573b 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
@@ -22,15 +22,13 @@ type TranscriptDetails = {
};
};
-const protectedPath = false;
-
export default function TranscriptDetails(details: TranscriptDetails) {
const transcriptId = details.params.transcriptId;
const router = useRouter();
- const transcript = useTranscript(protectedPath, transcriptId);
- const topics = useTopics(protectedPath, transcriptId);
- const waveform = useWaveform(protectedPath, transcriptId);
+ const transcript = useTranscript(transcriptId);
+ const topics = useTopics(transcriptId);
+ const waveform = useWaveform(transcriptId);
const useActiveTopic = useState(null);
const mp3 = useMp3(transcriptId);
@@ -71,7 +69,6 @@ export default function TranscriptDetails(details: TranscriptDetails) {
{transcript?.response?.title && (
@@ -101,7 +98,6 @@ export default function TranscriptDetails(details: TranscriptDetails) {
{transcript.response.longSummary ? (
{
}
}, []);
- const transcript = useTranscript(true, details.params.transcriptId);
- const webRTC = useWebRTC(stream, details.params.transcriptId, true);
+ const transcript = useTranscript(details.params.transcriptId);
+ const webRTC = useWebRTC(stream, details.params.transcriptId);
const webSockets = useWebSockets(details.params.transcriptId);
const { audioDevices, getAudioStream } = useAudioDevice();
diff --git a/www/app/[domain]/transcripts/createTranscript.ts b/www/app/[domain]/transcripts/createTranscript.ts
index 0d96b8db..9ad1abe0 100644
--- a/www/app/[domain]/transcripts/createTranscript.ts
+++ b/www/app/[domain]/transcripts/createTranscript.ts
@@ -19,7 +19,7 @@ const useCreateTranscript = (): CreateTranscript => {
const [loading, setLoading] = useState(false);
const [error, setErrorState] = useState(null);
const { setError } = useError();
- const api = getApi(true);
+ const api = getApi();
const create = (params: V1TranscriptsCreateRequest["createTranscript"]) => {
if (loading || !api) return;
diff --git a/www/app/[domain]/transcripts/finalSummary.tsx b/www/app/[domain]/transcripts/finalSummary.tsx
index e0d0f1c9..47c757bf 100644
--- a/www/app/[domain]/transcripts/finalSummary.tsx
+++ b/www/app/[domain]/transcripts/finalSummary.tsx
@@ -5,7 +5,6 @@ import "../../styles/markdown.css";
import getApi from "../../lib/getApi";
type FinalSummaryProps = {
- protectedPath: boolean;
summary: string;
fullTranscript: string;
transcriptId: string;
@@ -18,7 +17,7 @@ export default function FinalSummary(props: FinalSummaryProps) {
const [isEditMode, setIsEditMode] = useState(false);
const [preEditSummary, setPreEditSummary] = useState(props.summary);
const [editedSummary, setEditedSummary] = useState(props.summary);
- const api = getApi(props.protectedPath);
+ const api = getApi();
const updateSummary = async (newSummary: string, transcriptId: string) => {
if (!api) return;
diff --git a/www/app/[domain]/transcripts/shareLink.tsx b/www/app/[domain]/transcripts/shareLink.tsx
index 82ef52c9..e6449bd3 100644
--- a/www/app/[domain]/transcripts/shareLink.tsx
+++ b/www/app/[domain]/transcripts/shareLink.tsx
@@ -8,7 +8,6 @@ import "../../styles/button.css";
import "../../styles/form.scss";
type ShareLinkProps = {
- protectedPath: boolean;
transcriptId: string;
userId: string | null;
shareMode: string;
@@ -21,8 +20,8 @@ const ShareLink = (props: ShareLinkProps) => {
const requireLogin = featureEnabled("requireLogin");
const [isOwner, setIsOwner] = useState(false);
const [shareMode, setShareMode] = useState(props.shareMode);
- const api = getApi(props.protectedPath);
const userinfo = useFiefUserinfo();
+ const api = getApi();
useEffect(() => {
setCurrentUrl(window.location.href);
diff --git a/www/app/[domain]/transcripts/transcriptTitle.tsx b/www/app/[domain]/transcripts/transcriptTitle.tsx
index d2f901fa..afc29e51 100644
--- a/www/app/[domain]/transcripts/transcriptTitle.tsx
+++ b/www/app/[domain]/transcripts/transcriptTitle.tsx
@@ -2,7 +2,6 @@ import { useState } from "react";
import getApi from "../../lib/getApi";
type TranscriptTitle = {
- protectedPath: boolean;
title: string;
transcriptId: string;
};
@@ -11,7 +10,7 @@ const TranscriptTitle = (props: TranscriptTitle) => {
const [displayedTitle, setDisplayedTitle] = useState(props.title);
const [preEditTitle, setPreEditTitle] = useState(props.title);
const [isEditing, setIsEditing] = useState(false);
- const api = getApi(props.protectedPath);
+ const api = getApi();
const updateTitle = async (newTitle: string, transcriptId: string) => {
if (!api) return;
diff --git a/www/app/[domain]/transcripts/useMp3.ts b/www/app/[domain]/transcripts/useMp3.ts
index 58e0209d..363a4190 100644
--- a/www/app/[domain]/transcripts/useMp3.ts
+++ b/www/app/[domain]/transcripts/useMp3.ts
@@ -13,7 +13,7 @@ const useMp3 = (id: string, waiting?: boolean): Mp3Response => {
const [media, setMedia] = useState(null);
const [later, setLater] = useState(waiting);
const [loading, setLoading] = useState(false);
- const api = getApi(true);
+ const api = getApi();
const { api_url } = useContext(DomainContext);
const accessTokenInfo = useFiefAccessTokenInfo();
const [serviceWorker, setServiceWorker] =
diff --git a/www/app/[domain]/transcripts/useTopics.ts b/www/app/[domain]/transcripts/useTopics.ts
index 01053019..de4097b3 100644
--- a/www/app/[domain]/transcripts/useTopics.ts
+++ b/www/app/[domain]/transcripts/useTopics.ts
@@ -14,12 +14,12 @@ type TranscriptTopics = {
error: Error | null;
};
-const useTopics = (protectedPath, id: string): TranscriptTopics => {
+const useTopics = (id: string): TranscriptTopics => {
const [topics, setTopics] = useState(null);
const [loading, setLoading] = useState(false);
const [error, setErrorState] = useState(null);
const { setError } = useError();
- const api = getApi(protectedPath);
+ const api = getApi();
useEffect(() => {
if (!id || !api) return;
diff --git a/www/app/[domain]/transcripts/useTranscript.ts b/www/app/[domain]/transcripts/useTranscript.ts
index 987e57f3..91700d7a 100644
--- a/www/app/[domain]/transcripts/useTranscript.ts
+++ b/www/app/[domain]/transcripts/useTranscript.ts
@@ -24,14 +24,13 @@ type SuccessTranscript = {
};
const useTranscript = (
- protectedPath: boolean,
id: string | null,
): ErrorTranscript | LoadingTranscript | SuccessTranscript => {
const [response, setResponse] = useState(null);
const [loading, setLoading] = useState(true);
const [error, setErrorState] = useState(null);
const { setError } = useError();
- const api = getApi(protectedPath);
+ const api = getApi();
useEffect(() => {
if (!id || !api) return;
diff --git a/www/app/[domain]/transcripts/useTranscriptList.ts b/www/app/[domain]/transcripts/useTranscriptList.ts
index cc8f4701..7b5abb37 100644
--- a/www/app/[domain]/transcripts/useTranscriptList.ts
+++ b/www/app/[domain]/transcripts/useTranscriptList.ts
@@ -15,7 +15,7 @@ const useTranscriptList = (page: number): TranscriptList => {
const [loading, setLoading] = useState(true);
const [error, setErrorState] = useState(null);
const { setError } = useError();
- const api = getApi(true);
+ const api = getApi();
useEffect(() => {
if (!api) return;
diff --git a/www/app/[domain]/transcripts/useWaveform.ts b/www/app/[domain]/transcripts/useWaveform.ts
index d2bd0fd6..f80ad78c 100644
--- a/www/app/[domain]/transcripts/useWaveform.ts
+++ b/www/app/[domain]/transcripts/useWaveform.ts
@@ -11,12 +11,12 @@ type AudioWaveFormResponse = {
error: Error | null;
};
-const useWaveform = (protectedPath, id: string): AudioWaveFormResponse => {
+const useWaveform = (id: string): AudioWaveFormResponse => {
const [waveform, setWaveform] = useState(null);
const [loading, setLoading] = useState(true);
const [error, setErrorState] = useState(null);
const { setError } = useError();
- const api = getApi(protectedPath);
+ const api = getApi();
useEffect(() => {
if (!id || !api) return;
diff --git a/www/app/[domain]/transcripts/useWebRTC.ts b/www/app/[domain]/transcripts/useWebRTC.ts
index f4421e4d..edd3bef0 100644
--- a/www/app/[domain]/transcripts/useWebRTC.ts
+++ b/www/app/[domain]/transcripts/useWebRTC.ts
@@ -10,11 +10,10 @@ import getApi from "../../lib/getApi";
const useWebRTC = (
stream: MediaStream | null,
transcriptId: string | null,
- protectedPath,
): Peer => {
const [peer, setPeer] = useState(null);
const { setError } = useError();
- const api = getApi(protectedPath);
+ const api = getApi();
useEffect(() => {
if (!stream || !transcriptId) {
diff --git a/www/app/lib/fief.ts b/www/app/lib/fief.ts
index 176aa847..3af5c30f 100644
--- a/www/app/lib/fief.ts
+++ b/www/app/lib/fief.ts
@@ -66,10 +66,6 @@ export const getFiefAuthMiddleware = async (url) => {
matcher: "/transcripts",
parameters: {},
},
- {
- matcher: "/transcripts/((?!new))",
- parameters: {},
- },
{
matcher: "/browse",
parameters: {},
diff --git a/www/app/lib/getApi.ts b/www/app/lib/getApi.ts
index 7392cc90..e1ece2a9 100644
--- a/www/app/lib/getApi.ts
+++ b/www/app/lib/getApi.ts
@@ -4,17 +4,19 @@ import { DefaultApi } from "../api/apis/DefaultApi";
import { useFiefAccessTokenInfo } from "@fief/fief/nextjs/react";
import { useContext, useEffect, useState } from "react";
import { DomainContext, featureEnabled } from "../[domain]/domainContext";
+import { CookieContext } from "../(auth)/fiefWrapper";
-export default function getApi(protectedPath: boolean): DefaultApi | undefined {
+export default function getApi(): DefaultApi | undefined {
const accessTokenInfo = useFiefAccessTokenInfo();
const api_url = useContext(DomainContext).api_url;
const requireLogin = featureEnabled("requireLogin");
const [api, setApi] = useState();
+ const { hasAuthCookie } = useContext(CookieContext);
if (!api_url) throw new Error("no API URL");
useEffect(() => {
- if (protectedPath && requireLogin && !accessTokenInfo) {
+ if (hasAuthCookie && requireLogin && !accessTokenInfo) {
return;
}
@@ -25,7 +27,7 @@ export default function getApi(protectedPath: boolean): DefaultApi | undefined {
: undefined,
});
setApi(new DefaultApi(apiConfiguration));
- }, [!accessTokenInfo, protectedPath]);
+ }, [!accessTokenInfo, hasAuthCookie]);
return api;
}
From 4226428f582821383e7fb1af06cfbadb919c1f7e Mon Sep 17 00:00:00 2001
From: Sara
Date: Wed, 22 Nov 2023 13:53:12 +0100
Subject: [PATCH 13/27] minor styling changes
---
.../transcripts/[transcriptId]/page.tsx | 2 ++
www/app/[domain]/transcripts/shareLink.tsx | 22 +++++++++++++++----
.../[domain]/transcripts/waveformLoading.tsx | 2 +-
3 files changed, 21 insertions(+), 5 deletions(-)
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
index 472c573b..54eca9f9 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
@@ -15,6 +15,8 @@ import TranscriptTitle from "../transcriptTitle";
import Player from "../player";
import WaveformLoading from "../waveformLoading";
import { useRouter } from "next/navigation";
+import { faSpinner } from "@fortawesome/free-solid-svg-icons";
+import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
type TranscriptDetails = {
params: {
diff --git a/www/app/[domain]/transcripts/shareLink.tsx b/www/app/[domain]/transcripts/shareLink.tsx
index e6449bd3..dd66d6cb 100644
--- a/www/app/[domain]/transcripts/shareLink.tsx
+++ b/www/app/[domain]/transcripts/shareLink.tsx
@@ -6,6 +6,8 @@ import SelectSearch from "react-select-search";
import "react-select-search/style.css";
import "../../styles/button.css";
import "../../styles/form.scss";
+import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
+import { faSpinner } from "@fortawesome/free-solid-svg-icons";
type ShareLinkProps = {
transcriptId: string;
@@ -20,6 +22,7 @@ const ShareLink = (props: ShareLinkProps) => {
const requireLogin = featureEnabled("requireLogin");
const [isOwner, setIsOwner] = useState(false);
const [shareMode, setShareMode] = useState(props.shareMode);
+ const [shareLoading, setShareLoading] = useState(false);
const userinfo = useFiefUserinfo();
const api = getApi();
@@ -46,6 +49,7 @@ const ShareLink = (props: ShareLinkProps) => {
const updateShareMode = async (selectedShareMode: string) => {
if (!api) return;
+ setShareLoading(true);
const updatedTranscript = await api.v1TranscriptUpdate({
transcriptId: props.transcriptId,
updateTranscript: {
@@ -53,6 +57,7 @@ const ShareLink = (props: ShareLinkProps) => {
},
});
setShareMode(updatedTranscript.shareMode);
+ setShareLoading(false);
};
const privacyEnabled = featureEnabled("privacy");
@@ -62,7 +67,7 @@ const ShareLink = (props: ShareLinkProps) => {
style={{ background: "rgba(96, 165, 250, 0.2)" }}
>
{requireLogin && (
-
+
{shareMode === "private" && (
This transcript is private and can only be accessed by you.
)}
@@ -76,7 +81,7 @@ const ShareLink = (props: ShareLinkProps) => {
)}
{isOwner && api && (
-
+
{
]}
value={shareMode}
onChange={updateShareMode}
+ closeOnSelect={true}
/>
-
+ {shareLoading && (
+
+
+
+ )}
+
)}
-
+
)}
{!requireLogin && (
<>
diff --git a/www/app/[domain]/transcripts/waveformLoading.tsx b/www/app/[domain]/transcripts/waveformLoading.tsx
index 68e0c80f..56540927 100644
--- a/www/app/[domain]/transcripts/waveformLoading.tsx
+++ b/www/app/[domain]/transcripts/waveformLoading.tsx
@@ -5,7 +5,7 @@ export default () => (
);
From aecc3a0c3bc7b4c6daf9474f8ef95a21726c22fc Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 15 Nov 2023 21:24:21 +0100
Subject: [PATCH 14/27] server: first attempts to split post pipeline as single
celery tasks
---
server/reflector/db/transcripts.py | 19 ++
.../reflector/pipelines/main_live_pipeline.py | 239 +++++++++++++++---
server/reflector/storage/base.py | 11 +-
server/reflector/storage/storage_aws.py | 20 +-
4 files changed, 241 insertions(+), 48 deletions(-)
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index f0dbc277..c0e8984b 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -106,6 +106,7 @@ class Transcript(BaseModel):
events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
+ audio_location: str = "local"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -140,6 +141,10 @@ class Transcript(BaseModel):
def audio_waveform_filename(self):
return self.data_path / "audio.json"
+ @property
+ def storage_audio_path(self):
+ return f"{self.id}/audio.mp3"
+
@property
def audio_waveform(self):
try:
@@ -283,5 +288,19 @@ class TranscriptController:
transcript.upsert_topic(topic)
await self.update(transcript, {"topics": transcript.topics_dump()})
+ async def move_mp3_to_storage(self, transcript: Transcript):
+ """
+ Move mp3 file to storage
+ """
+ from reflector.storage import Storage
+
+ storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE)
+ await storage.put_file(
+ transcript.storage_audio_path,
+ self.audio_mp3_filename.read_bytes(),
+ )
+
+ await self.update(transcript, {"audio_location": "storage"})
+
transcripts_controller = TranscriptController()
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 3a9d1868..e2f305c4 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -12,6 +12,7 @@ It is directly linked to our data model.
"""
import asyncio
+import functools
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
@@ -55,6 +56,22 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
+from structlog import Logger
+
+
+def asynctask(f):
+ @functools.wraps(f)
+ def wrapper(*args, **kwargs):
+ coro = f(*args, **kwargs)
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = None
+ if loop and loop.is_running():
+ return loop.run_until_complete(coro)
+ return asyncio.run(coro)
+
+ return wrapper
def broadcast_to_sockets(func):
@@ -75,6 +92,22 @@ def broadcast_to_sockets(func):
return wrapper
+def get_transcript(func):
+ """
+ Decorator to fetch the transcript from the database from the first argument
+ """
+
+ async def wrapper(self, **kwargs):
+ transcript_id = kwargs.pop("transcript_id")
+ transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
+ if not transcript:
+ raise Exception("Transcript {transcript_id} not found")
+ tlogger = logger.bind(transcript_id=transcript.id)
+ return await func(self, transcript=transcript, logger=tlogger, **kwargs)
+
+ return wrapper
+
+
class StrValue(BaseModel):
value: str
@@ -99,6 +132,19 @@ class PipelineMainBase(PipelineRunner):
raise Exception("Transcript not found")
return result
+ def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]:
+ return [
+ TitleSummaryWithIdProcessorType(
+ id=topic.id,
+ title=topic.title,
+ summary=topic.summary,
+ timestamp=topic.timestamp,
+ duration=topic.duration,
+ transcript=TranscriptProcessorType(words=topic.words),
+ )
+ for topic in transcript.topics
+ ]
+
@asynccontextmanager
async def transaction(self):
async with self._lock:
@@ -299,10 +345,7 @@ class PipelineMainLive(PipelineMainBase):
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
pipeline.logger.bind(transcript_id=transcript.id)
- pipeline.logger.info(
- "Pipeline main live created",
- transcript_id=self.transcript_id,
- )
+ pipeline.logger.info("Pipeline main live created")
return pipeline
@@ -310,55 +353,28 @@ class PipelineMainLive(PipelineMainBase):
# when the pipeline ends, connect to the post pipeline
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
- task_pipeline_main_post.delay(transcript_id=self.transcript_id)
+ pipeline_post(transcript_id=self.transcript_id)
class PipelineMainDiarization(PipelineMainBase):
"""
- Diarization is a long time process, so we do it in a separate pipeline
- When done, adjust the short and final summary
+ Diarize the audio and update topics
"""
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
- processors = []
- if settings.DIARIZATION_ENABLED:
- processors += [
- AudioDiarizationAutoProcessor(callback=self.on_topic),
- ]
-
- processors += [
- BroadcastProcessor(
- processors=[
- TranscriptFinalLongSummaryProcessor.as_threaded(
- callback=self.on_long_summary
- ),
- TranscriptFinalShortSummaryProcessor.as_threaded(
- callback=self.on_short_summary
- ),
- ]
- ),
- ]
- pipeline = Pipeline(*processors)
+ pipeline = Pipeline(
+ AudioDiarizationAutoProcessor(callback=self.on_topic),
+ )
pipeline.options = self
# now let's start the pipeline by pushing information to the
# first processor diarization processor
# XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript()
- topics = [
- TitleSummaryWithIdProcessorType(
- id=topic.id,
- title=topic.title,
- summary=topic.summary,
- timestamp=topic.timestamp,
- duration=topic.duration,
- transcript=TranscriptProcessorType(words=topic.words),
- )
- for topic in transcript.topics
- ]
+ topics = self.get_transcript_topics(transcript)
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
@@ -386,15 +402,49 @@ class PipelineMainDiarization(PipelineMainBase):
# as tempting to use pipeline.push, prefer to use the runner
# to let the start just do one job.
pipeline.logger.bind(transcript_id=transcript.id)
- pipeline.logger.info(
- "Pipeline main post created", transcript_id=self.transcript_id
- )
+ pipeline.logger.info("Diarization pipeline created")
self.push(audio_diarization_input)
self.flush()
return pipeline
+class PipelineMainSummaries(PipelineMainBase):
+ """
+ Generate summaries from the topics
+ """
+
+ async def create(self) -> Pipeline:
+ self.prepare()
+ pipeline = Pipeline(
+ BroadcastProcessor(
+ processors=[
+ TranscriptFinalLongSummaryProcessor.as_threaded(
+ callback=self.on_long_summary
+ ),
+ TranscriptFinalShortSummaryProcessor.as_threaded(
+ callback=self.on_short_summary
+ ),
+ ]
+ ),
+ )
+ pipeline.options = self
+
+ # get transcript
+ transcript = await self.get_transcript()
+ pipeline.logger.bind(transcript_id=transcript.id)
+ pipeline.logger.info("Summaries pipeline created")
+
+ # push topics
+ topics = await self.get_transcript_topics(transcript)
+ for topic in topics:
+ self.push(topic)
+
+ self.flush()
+
+ return pipeline
+
+
@shared_task
def task_pipeline_main_post(transcript_id: str):
logger.info(
@@ -403,3 +453,112 @@ def task_pipeline_main_post(transcript_id: str):
)
runner = PipelineMainDiarization(transcript_id=transcript_id)
runner.start_sync()
+
+
+@get_transcript
+async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
+ logger.info("Starting convert to mp3")
+
+ # If the audio wav is not available, just skip
+ wav_filename = transcript.audio_wav_filename
+ if not wav_filename.exists():
+ logger.warning("Wav file not found, may be already converted")
+ return
+
+ # Convert to mp3
+ mp3_filename = transcript.audio_mp3_filename
+
+ import av
+
+ input_container = av.open(wav_filename)
+ output_container = av.open(mp3_filename, "w")
+ input_audio_stream = input_container.streams.audio[0]
+ output_audio_stream = output_container.add_stream("mp3")
+ output_audio_stream.codec_context.set_parameters(
+ input_audio_stream.codec_context.parameters
+ )
+ for packet in input_container.demux(input_audio_stream):
+ for frame in packet.decode():
+ output_container.mux(frame)
+ input_container.close()
+ output_container.close()
+
+ logger.info("Convert to mp3 done")
+
+
+@get_transcript
+async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+ logger.info("Starting upload mp3")
+
+ # If the audio mp3 is not available, just skip
+ mp3_filename = transcript.audio_mp3_filename
+ if not mp3_filename.exists():
+ logger.warning("Mp3 file not found, may be already uploaded")
+ return
+
+ # Upload to external storage and delete the file
+ await transcripts_controller.move_to_storage(transcript)
+ await transcripts_controller.unlink_mp3(transcript)
+
+ logger.info("Upload mp3 done")
+
+
+@get_transcript
+@asynctask
+async def pipeline_diarization(transcript: Transcript, logger: Logger):
+ logger.info("Starting diarization")
+ runner = PipelineMainDiarization(transcript_id=transcript.id)
+ await runner.start()
+ logger.info("Diarization done")
+
+
+@get_transcript
+@asynctask
+async def pipeline_summaries(transcript: Transcript, logger: Logger):
+ logger.info("Starting summaries")
+ runner = PipelineMainSummaries(transcript_id=transcript.id)
+ await runner.start()
+ logger.info("Summaries done")
+
+
+# ===================================================================
+# Celery tasks that can be called from the API
+# ===================================================================
+
+
+@shared_task
+@asynctask
+async def task_pipeline_convert_to_mp3(transcript_id: str):
+ await pipeline_convert_to_mp3(transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_upload_mp3(transcript_id: str):
+ await pipeline_upload_mp3(transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_diarization(transcript_id: str):
+ await pipeline_diarization(transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_summaries(transcript_id: str):
+ await pipeline_summaries(transcript_id)
+
+
+def pipeline_post(transcript_id: str):
+ """
+ Run the post pipeline
+ """
+ chain_mp3_and_diarize = (
+ task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
+ | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
+ | task_pipeline_diarization.si(transcript_id=transcript_id)
+ )
+ chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id)
+ chain = chain_mp3_and_diarize | chain_summary
+ chain.delay()
diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py
index 5cdafdbf..7c44ff4d 100644
--- a/server/reflector/storage/base.py
+++ b/server/reflector/storage/base.py
@@ -1,6 +1,7 @@
+import importlib
+
from pydantic import BaseModel
from reflector.settings import settings
-import importlib
class FileResult(BaseModel):
@@ -17,14 +18,14 @@ class Storage:
cls._registry[name] = kclass
@classmethod
- def get_instance(cls, name, settings_prefix=""):
+ def get_instance(cls, name: str, settings_prefix: str = "", folder: str = ""):
if name not in cls._registry:
module_name = f"reflector.storage.storage_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
- config = {}
+ config = {"folder": folder}
name_upper = name.upper()
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
@@ -34,6 +35,10 @@ class Storage:
return cls._registry[name](**config)
+ def __init__(self):
+ self.folder = ""
+ super().__init__()
+
async def put_file(self, filename: str, data: bytes) -> FileResult:
return await self._put_file(filename, data)
diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py
index 09a9c383..5ab02903 100644
--- a/server/reflector/storage/storage_aws.py
+++ b/server/reflector/storage/storage_aws.py
@@ -1,6 +1,6 @@
import aioboto3
-from reflector.storage.base import Storage, FileResult
from reflector.logger import logger
+from reflector.storage.base import FileResult, Storage
class AwsStorage(Storage):
@@ -22,9 +22,14 @@ class AwsStorage(Storage):
super().__init__()
self.aws_bucket_name = aws_bucket_name
- self.aws_folder = ""
+ folder = ""
if "/" in aws_bucket_name:
- self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
+ self.aws_bucket_name, folder = aws_bucket_name.split("/", 1)
+ if folder:
+ if not self.folder:
+ self.folder = folder
+ else:
+ self.folder = f"{self.folder}/{folder}"
self.session = aioboto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
@@ -34,7 +39,7 @@ class AwsStorage(Storage):
async def _put_file(self, filename: str, data: bytes) -> FileResult:
bucket = self.aws_bucket_name
- folder = self.aws_folder
+ folder = self.folder
logger.info(f"Uploading {filename} to S3 {bucket}/{folder}")
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
@@ -44,6 +49,11 @@ class AwsStorage(Storage):
Body=data,
)
+ async def get_file_url(self, filename: str) -> FileResult:
+ bucket = self.aws_bucket_name
+ folder = self.folder
+ s3filename = f"{folder}/{filename}" if folder else filename
+ async with self.session.client("s3") as client:
presigned_url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": s3filename},
@@ -57,7 +67,7 @@ class AwsStorage(Storage):
async def _delete_file(self, filename: str):
bucket = self.aws_bucket_name
- folder = self.aws_folder
+ folder = self.folder
logger.info(f"Deleting {filename} from S3 {bucket}/{folder}")
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
From 88f443e2c25cd8bf9158b1c83bb8c76a3813d843 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 16 Nov 2023 14:32:18 +0100
Subject: [PATCH 15/27] server: revert change on storage folder
---
server/reflector/storage/base.py | 14 ++++++++------
server/reflector/storage/storage_aws.py | 22 +++++++---------------
2 files changed, 15 insertions(+), 21 deletions(-)
diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py
index 7c44ff4d..a457ddf8 100644
--- a/server/reflector/storage/base.py
+++ b/server/reflector/storage/base.py
@@ -18,14 +18,14 @@ class Storage:
cls._registry[name] = kclass
@classmethod
- def get_instance(cls, name: str, settings_prefix: str = "", folder: str = ""):
+ def get_instance(cls, name: str, settings_prefix: str = ""):
if name not in cls._registry:
module_name = f"reflector.storage.storage_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
- config = {"folder": folder}
+ config = {}
name_upper = name.upper()
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
@@ -35,10 +35,6 @@ class Storage:
return cls._registry[name](**config)
- def __init__(self):
- self.folder = ""
- super().__init__()
-
async def put_file(self, filename: str, data: bytes) -> FileResult:
return await self._put_file(filename, data)
@@ -50,3 +46,9 @@ class Storage:
async def _delete_file(self, filename: str):
raise NotImplementedError
+
+ async def get_file_url(self, filename: str) -> str:
+ return await self._get_file_url(filename)
+
+ async def _get_file_url(self, filename: str) -> str:
+ raise NotImplementedError
diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py
index 5ab02903..d2313293 100644
--- a/server/reflector/storage/storage_aws.py
+++ b/server/reflector/storage/storage_aws.py
@@ -22,14 +22,9 @@ class AwsStorage(Storage):
super().__init__()
self.aws_bucket_name = aws_bucket_name
- folder = ""
+ self.aws_folder = ""
if "/" in aws_bucket_name:
- self.aws_bucket_name, folder = aws_bucket_name.split("/", 1)
- if folder:
- if not self.folder:
- self.folder = folder
- else:
- self.folder = f"{self.folder}/{folder}"
+ self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
self.session = aioboto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
@@ -39,7 +34,7 @@ class AwsStorage(Storage):
async def _put_file(self, filename: str, data: bytes) -> FileResult:
bucket = self.aws_bucket_name
- folder = self.folder
+ folder = self.aws_folder
logger.info(f"Uploading {filename} to S3 {bucket}/{folder}")
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
@@ -49,9 +44,9 @@ class AwsStorage(Storage):
Body=data,
)
- async def get_file_url(self, filename: str) -> FileResult:
+ async def _get_file_url(self, filename: str) -> FileResult:
bucket = self.aws_bucket_name
- folder = self.folder
+ folder = self.aws_folder
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
presigned_url = await client.generate_presigned_url(
@@ -60,14 +55,11 @@ class AwsStorage(Storage):
ExpiresIn=3600,
)
- return FileResult(
- filename=filename,
- url=presigned_url,
- )
+ return presigned_url
async def _delete_file(self, filename: str):
bucket = self.aws_bucket_name
- folder = self.folder
+ folder = self.aws_folder
logger.info(f"Deleting {filename} from S3 {bucket}/{folder}")
s3filename = f"{folder}/{filename}" if folder else filename
async with self.session.client("s3") as client:
From 06b29d9bd4a1d1b7f6265756be57c5602888673a Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 16 Nov 2023 14:34:33 +0100
Subject: [PATCH 16/27] server: add audio_location and move to external storage
if possible
---
.../versions/f819277e5169_audio_location.py | 43 ++++
server/reflector/db/transcripts.py | 67 ++++-
.../reflector/pipelines/main_live_pipeline.py | 234 ++++++++++--------
server/reflector/pipelines/runner.py | 3 +-
server/reflector/settings.py | 5 +-
5 files changed, 238 insertions(+), 114 deletions(-)
create mode 100644 server/migrations/versions/f819277e5169_audio_location.py
diff --git a/server/migrations/versions/f819277e5169_audio_location.py b/server/migrations/versions/f819277e5169_audio_location.py
new file mode 100644
index 00000000..576b02bd
--- /dev/null
+++ b/server/migrations/versions/f819277e5169_audio_location.py
@@ -0,0 +1,43 @@
+"""audio_location
+
+Revision ID: f819277e5169
+Revises: 4814901632bc
+Create Date: 2023-11-16 10:29:09.351664
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "f819277e5169"
+down_revision: Union[str, None] = "4814901632bc"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column(
+ "transcript",
+ sa.Column(
+ "audio_location", sa.String(), server_default="local", nullable=False
+ ),
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column(
+ "transcript",
+ sa.Column(
+ "share_mode",
+ sa.VARCHAR(),
+ server_default=sa.text("'private'"),
+ nullable=False,
+ ),
+ )
+ # ### end Alembic commands ###
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index c0e8984b..44a6d56b 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
+from reflector.storage import Storage
transcripts = sqlalchemy.Table(
"transcript",
@@ -27,20 +28,33 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
+ sqlalchemy.Column(
+ "audio_location",
+ sqlalchemy.String,
+ nullable=False,
+ server_default="local",
+ ),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
)
-def generate_uuid4():
+def generate_uuid4() -> str:
return str(uuid4())
-def generate_transcript_name():
+def generate_transcript_name() -> str:
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
+def get_storage() -> Storage:
+ return Storage.get_instance(
+ name=settings.TRANSCRIPT_STORAGE_BACKEND,
+ settings_prefix="TRANSCRIPT_STORAGE_",
+ )
+
+
class AudioWaveform(BaseModel):
data: list[float]
@@ -133,6 +147,10 @@ class Transcript(BaseModel):
def data_path(self):
return Path(settings.DATA_DIR) / self.id
+ @property
+ def audio_wav_filename(self):
+ return self.data_path / "audio.wav"
+
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@@ -157,6 +175,40 @@ class Transcript(BaseModel):
return AudioWaveform(data=data)
+ async def get_audio_url(self) -> str:
+ if self.audio_location == "local":
+ return self._generate_local_audio_link()
+ elif self.audio_location == "storage":
+ return await self._generate_storage_audio_link()
+ raise Exception(f"Unknown audio location {self.audio_location}")
+
+ async def _generate_storage_audio_link(self) -> str:
+ return await get_storage().get_file_url(self.storage_audio_path)
+
+ def _generate_local_audio_link(self) -> str:
+ # we need to create an url to be used for diarization
+ # we can't use the audio_mp3_filename because it's not accessible
+ # from the diarization processor
+ from datetime import timedelta
+
+ from reflector.app import app
+ from reflector.views.transcripts import create_access_token
+
+ path = app.url_path_for(
+ "transcript_get_audio_mp3",
+ transcript_id=self.id,
+ )
+ url = f"{settings.BASE_URL}{path}"
+ if self.user_id:
+ # we pass token only if the user_id is set
+ # otherwise, the audio is public
+ token = create_access_token(
+ {"sub": self.user_id},
+ expires_delta=timedelta(minutes=15),
+ )
+ url += f"?token={token}"
+ return url
+
class TranscriptController:
async def get_all(
@@ -292,15 +344,18 @@ class TranscriptController:
"""
Move mp3 file to storage
"""
- from reflector.storage import Storage
- storage = Storage.get_instance(settings.TRANSCRIPT_STORAGE)
- await storage.put_file(
+ # store the audio on external storage
+ await get_storage().put_file(
transcript.storage_audio_path,
- self.audio_mp3_filename.read_bytes(),
+ transcript.audio_mp3_filename.read_bytes(),
)
+ # indicate on the transcript that the audio is now on storage
await self.update(transcript, {"audio_location": "storage"})
+ # unlink the local file
+ transcript.audio_mp3_filename.unlink(missing_ok=True)
+
transcripts_controller = TranscriptController()
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index e2f305c4..83b57949 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -14,12 +14,9 @@ It is directly linked to our data model.
import asyncio
import functools
from contextlib import asynccontextmanager
-from datetime import timedelta
-from pathlib import Path
-from celery import shared_task
+from celery import chord, group, shared_task
from pydantic import BaseModel
-from reflector.app import app
from reflector.db.transcripts import (
Transcript,
TranscriptDuration,
@@ -56,7 +53,7 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
-from structlog import Logger
+from structlog import BoundLogger as Logger
def asynctask(f):
@@ -97,13 +94,17 @@ def get_transcript(func):
Decorator to fetch the transcript from the database from the first argument
"""
- async def wrapper(self, **kwargs):
+ async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id")
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
if not transcript:
raise Exception("Transcript {transcript_id} not found")
tlogger = logger.bind(transcript_id=transcript.id)
- return await func(self, transcript=transcript, logger=tlogger, **kwargs)
+ try:
+ return await func(transcript=transcript, logger=tlogger, **kwargs)
+ except Exception as exc:
+ tlogger.error("Pipeline error", exc_info=exc)
+ raise
return wrapper
@@ -162,7 +163,7 @@ class PipelineMainBase(PipelineRunner):
"flush": "processing",
"error": "error",
}
- elif isinstance(self, PipelineMainDiarization):
+ elif isinstance(self, PipelineMainFinalSummaries):
status_mapping = {
"push": "processing",
"flush": "processing",
@@ -170,7 +171,8 @@ class PipelineMainBase(PipelineRunner):
"ended": "ended",
}
else:
- raise Exception(f"Runner {self.__class__} is missing status mapping")
+ # intermediate pipeline don't update status
+ return
# mutate to model status
status = status_mapping.get(status)
@@ -308,9 +310,10 @@ class PipelineMainBase(PipelineRunner):
class PipelineMainLive(PipelineMainBase):
- audio_filename: Path | None = None
- source_language: str = "en"
- target_language: str = "en"
+ """
+ Main pipeline for live streaming, attach to RTC connection
+ Any long post process should be done in the post pipeline
+ """
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
@@ -320,7 +323,7 @@ class PipelineMainLive(PipelineMainBase):
processors = [
AudioFileWriterProcessor(
- path=transcript.audio_mp3_filename,
+ path=transcript.audio_wav_filename,
on_duration=self.on_duration,
),
AudioChunkerProcessor(),
@@ -329,15 +332,11 @@ class PipelineMainLive(PipelineMainBase):
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
- BroadcastProcessor(
- processors=[
- TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
- AudioWaveformProcessor.as_threaded(
- audio_path=transcript.audio_mp3_filename,
- waveform_path=transcript.audio_waveform_filename,
- on_waveform=self.on_waveform,
- ),
- ]
+ # XXX move as a task
+ AudioWaveformProcessor.as_threaded(
+ audio_path=transcript.audio_mp3_filename,
+ waveform_path=transcript.audio_waveform_filename,
+ on_waveform=self.on_waveform,
),
]
pipeline = Pipeline(*processors)
@@ -374,28 +373,16 @@ class PipelineMainDiarization(PipelineMainBase):
# first processor diarization processor
# XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript()
+
+ # diarization works only if the file is uploaded to an external storage
+ if transcript.audio_location == "local":
+ pipeline.logger.info("Audio is local, skipping diarization")
+ return
+
topics = self.get_transcript_topics(transcript)
-
- # we need to create an url to be used for diarization
- # we can't use the audio_mp3_filename because it's not accessible
- # from the diarization processor
- from reflector.views.transcripts import create_access_token
-
- path = app.url_path_for(
- "transcript_get_audio_mp3",
- transcript_id=transcript.id,
- )
- url = f"{settings.BASE_URL}{path}"
- if transcript.user_id:
- # we pass token only if the user_id is set
- # otherwise, the audio is public
- token = create_access_token(
- {"sub": transcript.user_id},
- expires_delta=timedelta(minutes=15),
- )
- url += f"?token={token}"
+ audio_url = await transcript.get_audio_url()
audio_diarization_input = AudioDiarizationInput(
- audio_url=url,
+ audio_url=audio_url,
topics=topics,
)
@@ -409,14 +396,60 @@ class PipelineMainDiarization(PipelineMainBase):
return pipeline
-class PipelineMainSummaries(PipelineMainBase):
+class PipelineMainFromTopics(PipelineMainBase):
+ """
+ Pseudo class for generating a pipeline from topics
+ """
+
+ def get_processors(self) -> list:
+ raise NotImplementedError
+
+ async def create(self) -> Pipeline:
+ self.prepare()
+ processors = self.get_processors()
+ pipeline = Pipeline(*processors)
+ pipeline.options = self
+
+ # get transcript
+ transcript = await self.get_transcript()
+ pipeline.logger.bind(transcript_id=transcript.id)
+ pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
+
+ # push topics
+ topics = self.get_transcript_topics(transcript)
+ for topic in topics:
+ self.push(topic)
+
+ self.flush()
+
+ return pipeline
+
+
+class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
+ """
+ Generate title from the topics
+ """
+
+ def get_processors(self) -> list:
+ return [
+ BroadcastProcessor(
+ processors=[
+ TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
+ TranscriptFinalShortSummaryProcessor.as_threaded(
+ callback=self.on_short_summary
+ ),
+ ]
+ )
+ ]
+
+
+class PipelineMainFinalSummaries(PipelineMainFromTopics):
"""
Generate summaries from the topics
"""
- async def create(self) -> Pipeline:
- self.prepare()
- pipeline = Pipeline(
+ def get_processors(self) -> list:
+ return [
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -427,32 +460,7 @@ class PipelineMainSummaries(PipelineMainBase):
),
]
),
- )
- pipeline.options = self
-
- # get transcript
- transcript = await self.get_transcript()
- pipeline.logger.bind(transcript_id=transcript.id)
- pipeline.logger.info("Summaries pipeline created")
-
- # push topics
- topics = await self.get_transcript_topics(transcript)
- for topic in topics:
- self.push(topic)
-
- self.flush()
-
- return pipeline
-
-
-@shared_task
-def task_pipeline_main_post(transcript_id: str):
- logger.info(
- "Starting main post pipeline",
- transcript_id=transcript_id,
- )
- runner = PipelineMainDiarization(transcript_id=transcript_id)
- runner.start_sync()
+ ]
@get_transcript
@@ -470,24 +478,26 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
import av
- input_container = av.open(wav_filename)
- output_container = av.open(mp3_filename, "w")
- input_audio_stream = input_container.streams.audio[0]
- output_audio_stream = output_container.add_stream("mp3")
- output_audio_stream.codec_context.set_parameters(
- input_audio_stream.codec_context.parameters
- )
- for packet in input_container.demux(input_audio_stream):
- for frame in packet.decode():
- output_container.mux(frame)
- input_container.close()
- output_container.close()
+ with av.open(wav_filename.as_posix()) as in_container:
+ in_stream = in_container.streams.audio[0]
+ with av.open(mp3_filename.as_posix(), "w") as out_container:
+ out_stream = out_container.add_stream("mp3")
+ for frame in in_container.decode(in_stream):
+ for packet in out_stream.encode(frame):
+ out_container.mux(packet)
+
+ # Delete the wav file
+ transcript.audio_wav_filename.unlink(missing_ok=True)
logger.info("Convert to mp3 done")
@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+ if not settings.TRANSCRIPT_STORAGE_BACKEND:
+ logger.info("No storage backend configured, skipping mp3 upload")
+ return
+
logger.info("Starting upload mp3")
# If the audio mp3 is not available, just skip
@@ -497,27 +507,32 @@ async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
return
# Upload to external storage and delete the file
- await transcripts_controller.move_to_storage(transcript)
- await transcripts_controller.unlink_mp3(transcript)
+ await transcripts_controller.move_mp3_to_storage(transcript)
logger.info("Upload mp3 done")
@get_transcript
-@asynctask
async def pipeline_diarization(transcript: Transcript, logger: Logger):
logger.info("Starting diarization")
runner = PipelineMainDiarization(transcript_id=transcript.id)
- await runner.start()
+ await runner.run()
logger.info("Diarization done")
@get_transcript
-@asynctask
+async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
+ logger.info("Starting title and short summary")
+ runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
+ await runner.run()
+ logger.info("Title and short summary done")
+
+
+@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger):
logger.info("Starting summaries")
- runner = PipelineMainSummaries(transcript_id=transcript.id)
- await runner.start()
+ runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
+ await runner.run()
logger.info("Summaries done")
@@ -528,29 +543,35 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
@shared_task
@asynctask
-async def task_pipeline_convert_to_mp3(transcript_id: str):
- await pipeline_convert_to_mp3(transcript_id)
+async def task_pipeline_convert_to_mp3(*, transcript_id: str):
+ await pipeline_convert_to_mp3(transcript_id=transcript_id)
@shared_task
@asynctask
-async def task_pipeline_upload_mp3(transcript_id: str):
- await pipeline_upload_mp3(transcript_id)
+async def task_pipeline_upload_mp3(*, transcript_id: str):
+ await pipeline_upload_mp3(transcript_id=transcript_id)
@shared_task
@asynctask
-async def task_pipeline_diarization(transcript_id: str):
- await pipeline_diarization(transcript_id)
+async def task_pipeline_diarization(*, transcript_id: str):
+ await pipeline_diarization(transcript_id=transcript_id)
@shared_task
@asynctask
-async def task_pipeline_summaries(transcript_id: str):
- await pipeline_summaries(transcript_id)
+async def task_pipeline_title_and_short_summary(*, transcript_id: str):
+ await pipeline_title_and_short_summary(transcript_id=transcript_id)
-def pipeline_post(transcript_id: str):
+@shared_task
+@asynctask
+async def task_pipeline_final_summaries(*, transcript_id: str):
+ await pipeline_summaries(transcript_id=transcript_id)
+
+
+def pipeline_post(*, transcript_id: str):
"""
Run the post pipeline
"""
@@ -559,6 +580,15 @@ def pipeline_post(transcript_id: str):
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
| task_pipeline_diarization.si(transcript_id=transcript_id)
)
- chain_summary = task_pipeline_summaries.si(transcript_id=transcript_id)
- chain = chain_mp3_and_diarize | chain_summary
+ chain_title_preview = task_pipeline_title_and_short_summary.si(
+ transcript_id=transcript_id
+ )
+ chain_final_summaries = task_pipeline_final_summaries.si(
+ transcript_id=transcript_id
+ )
+
+ chain = chord(
+ group(chain_mp3_and_diarize, chain_title_preview),
+ chain_final_summaries,
+ )
chain.delay()
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index a1e137a7..4105d51f 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -119,8 +119,7 @@ class PipelineRunner(BaseModel):
self._logger.exception("Runner error")
await self._set_status("error")
self._ev_done.set()
- if self.on_ended:
- await self.on_ended()
+ raise
async def cmd_push(self, data):
if self._is_first_push:
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 65412310..2c68c4e5 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -54,7 +54,7 @@ class Settings(BaseSettings):
TRANSCRIPT_MODAL_API_KEY: str | None = None
# Audio transcription storage
- TRANSCRIPT_STORAGE_BACKEND: str = "aws"
+ TRANSCRIPT_STORAGE_BACKEND: str | None = None
# Storage configuration for AWS
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
@@ -62,9 +62,6 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
- # Transcript MP3 storage
- TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
-
# LLM
# available backend: openai, modal, oobabooga
LLM_BACKEND: str = "oobabooga"
From 5ffa931822ff5b7c5205f942ca2332f52b9ef892 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 16 Nov 2023 14:45:40 +0100
Subject: [PATCH 17/27] server: update backend tests results (rpc does not work
with chords)
---
.../migrations/versions/f819277e5169_audio_location.py | 10 +---------
server/tests/conftest.py | 8 +++++++-
2 files changed, 8 insertions(+), 10 deletions(-)
diff --git a/server/migrations/versions/f819277e5169_audio_location.py b/server/migrations/versions/f819277e5169_audio_location.py
index 576b02bd..061abec4 100644
--- a/server/migrations/versions/f819277e5169_audio_location.py
+++ b/server/migrations/versions/f819277e5169_audio_location.py
@@ -31,13 +31,5 @@ def upgrade() -> None:
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column(
- "transcript",
- sa.Column(
- "share_mode",
- sa.VARCHAR(),
- server_default=sa.text("'private'"),
- nullable=False,
- ),
- )
+ op.drop_column("transcript", "audio_location")
# ### end Alembic commands ###
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index aafca9fd..aaf42884 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -133,4 +133,10 @@ def celery_enable_logging():
@pytest.fixture(scope="session")
def celery_config():
- return {"broker_url": "memory://", "result_backend": "rpc"}
+ import tempfile
+
+ with tempfile.NamedTemporaryFile() as fd:
+ yield {
+ "broker_url": "memory://",
+ "result_backend": "db+sqlite://" + fd.name,
+ }
From 99b973f36f5fbbe1c193dd00dc2233fc82eddb4d Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 17 Nov 2023 14:27:53 +0100
Subject: [PATCH 18/27] server: fix tests
---
server/reflector/pipelines/runner.py | 8 ++++++
server/tests/conftest.py | 36 +++++++++++++++++++++----
server/tests/test_transcripts_rtc_ws.py | 4 +++
3 files changed, 43 insertions(+), 5 deletions(-)
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index 4105d51f..708a4265 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -106,6 +106,14 @@ class PipelineRunner(BaseModel):
if not self.pipeline:
self.pipeline = await self.create()
+ if not self.pipeline:
+ # no pipeline created in create, just finish it then.
+ await self._set_status("ended")
+ self._ev_done.set()
+ if self.on_ended:
+ await self.on_ended()
+ return
+
# start the loop
await self._set_status("started")
while not self._ev_done.is_set():
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index aaf42884..532ebff9 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -1,4 +1,5 @@
from unittest.mock import patch
+from tempfile import NamedTemporaryFile
import pytest
@@ -7,7 +8,6 @@ import pytest
@pytest.mark.asyncio
async def setup_database():
from reflector.settings import settings
- from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as f:
settings.DATABASE_URL = f"sqlite:///{f.name}"
@@ -103,6 +103,25 @@ async def dummy_llm():
yield
+@pytest.fixture
+async def dummy_storage():
+ from reflector.storage.base import Storage
+
+ class DummyStorage(Storage):
+ async def _put_file(self, *args, **kwargs):
+ pass
+
+ async def _delete_file(self, *args, **kwargs):
+ pass
+
+ async def _get_file_url(self, *args, **kwargs):
+ return "http://fake_server/audio.mp3"
+
+ with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
+ mock_storage.return_value = DummyStorage()
+ yield
+
+
@pytest.fixture
def nltk():
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
@@ -133,10 +152,17 @@ def celery_enable_logging():
@pytest.fixture(scope="session")
def celery_config():
- import tempfile
-
- with tempfile.NamedTemporaryFile() as fd:
+ with NamedTemporaryFile() as f:
yield {
"broker_url": "memory://",
- "result_backend": "db+sqlite://" + fd.name,
+ "result_backend": f"db+sqlite:///{f.name}",
}
+
+
+@pytest.fixture(scope="session")
+def fake_mp3_upload():
+ with patch(
+ "reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
+ ) as mock_move:
+ mock_move.return_value = True
+ yield
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index b33b1db5..8502a0d9 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket(
dummy_transcript,
dummy_processors,
dummy_diarization,
+ dummy_storage,
+ fake_mp3_upload,
ensure_casing,
appserver,
sentence_tokenize,
@@ -220,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_transcript,
dummy_processors,
dummy_diarization,
+ dummy_storage,
+ fake_mp3_upload,
ensure_casing,
appserver,
sentence_tokenize,
From 794d08c3a88d55a2ac6ea5faecd117697d838612 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Tue, 21 Nov 2023 14:46:16 +0100
Subject: [PATCH 19/27] server: redirect to storage url
---
server/reflector/views/transcripts.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 6909b8ae..7496b26c 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -233,6 +233,12 @@ async def transcript_get_audio_mp3(
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
+ if transcript.audio_location == "storage":
+ url = transcript.get_audio_url()
+ from fastapi.responses import RedirectResponse
+
+ return RedirectResponse(url=url, status_code=status.HTTP_302_FOUND)
+
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
From 0e5c0f66d91fa61f7bfee283a89f96f13dc57fef Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Tue, 21 Nov 2023 15:40:49 +0100
Subject: [PATCH 20/27] server: move waveform out of the live pipeline
---
.../reflector/pipelines/main_live_pipeline.py | 46 +++++++++++++++----
server/reflector/views/transcripts.py | 22 +++++++--
2 files changed, 54 insertions(+), 14 deletions(-)
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 83b57949..b182f421 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -332,12 +332,6 @@ class PipelineMainLive(PipelineMainBase):
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
- # XXX move as a task
- AudioWaveformProcessor.as_threaded(
- audio_path=transcript.audio_mp3_filename,
- waveform_path=transcript.audio_waveform_filename,
- on_waveform=self.on_waveform,
- ),
]
pipeline = Pipeline(*processors)
pipeline.options = self
@@ -406,12 +400,14 @@ class PipelineMainFromTopics(PipelineMainBase):
async def create(self) -> Pipeline:
self.prepare()
+
+ # get transcript
+ self._transcript = transcript = await self.get_transcript()
+
+ # create pipeline
processors = self.get_processors()
pipeline = Pipeline(*processors)
pipeline.options = self
-
- # get transcript
- transcript = await self.get_transcript()
pipeline.logger.bind(transcript_id=transcript.id)
pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
@@ -463,6 +459,29 @@ class PipelineMainFinalSummaries(PipelineMainFromTopics):
]
+class PipelineMainWaveform(PipelineMainFromTopics):
+ """
+ Generate waveform
+ """
+
+ def get_processors(self) -> list:
+ return [
+ AudioWaveformProcessor.as_threaded(
+ audio_path=self._transcript.audio_wav_filename,
+ waveform_path=self._transcript.audio_waveform_filename,
+ on_waveform=self.on_waveform,
+ ),
+ ]
+
+
+@get_transcript
+async def pipeline_waveform(transcript: Transcript, logger: Logger):
+ logger.info("Starting waveform")
+ runner = PipelineMainWaveform(transcript_id=transcript.id)
+ await runner.run()
+ logger.info("Waveform done")
+
+
@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
logger.info("Starting convert to mp3")
@@ -541,6 +560,12 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
# ===================================================================
+@shared_task
+@asynctask
+async def task_pipeline_waveform(*, transcript_id: str):
+ await pipeline_waveform(transcript_id=transcript_id)
+
+
@shared_task
@asynctask
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
@@ -576,7 +601,8 @@ def pipeline_post(*, transcript_id: str):
Run the post pipeline
"""
chain_mp3_and_diarize = (
- task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
+ task_pipeline_waveform.si(transcript_id=transcript_id)
+ | task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
| task_pipeline_diarization.si(transcript_id=transcript_id)
)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 7496b26c..125aa311 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -1,12 +1,14 @@
from datetime import datetime, timedelta
from typing import Annotated, Optional
+import httpx
import reflector.auth as auth
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
+ Response,
WebSocket,
WebSocketDisconnect,
status,
@@ -234,10 +236,22 @@ async def transcript_get_audio_mp3(
raise HTTPException(status_code=404, detail="Transcript not found")
if transcript.audio_location == "storage":
- url = transcript.get_audio_url()
- from fastapi.responses import RedirectResponse
+ # proxy S3 file, to prevent issue with CORS
+ url = await transcript.get_audio_url()
+ headers = {}
- return RedirectResponse(url=url, status_code=status.HTTP_302_FOUND)
+ copy_headers = ["range", "accept-encoding"]
+ for header in copy_headers:
+ if header in request.headers:
+ headers[header] = request.headers[header]
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.request(request.method, url, headers=headers)
+ return Response(
+ content=resp.content,
+ status_code=resp.status_code,
+ headers=resp.headers,
+ )
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
@@ -263,7 +277,7 @@ async def transcript_get_audio_waveform(
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
- if not transcript.audio_mp3_filename.exists():
+ if not transcript.audio_waveform_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return transcript.audio_waveform
From f8407874f77172c3aafa379f5089e80b00533064 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 23 Nov 2023 12:41:39 +0100
Subject: [PATCH 21/27] server: fixes share_mode script
---
.../versions/0fea6d96b096_add_share_mode.py | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/server/migrations/versions/0fea6d96b096_add_share_mode.py b/server/migrations/versions/0fea6d96b096_add_share_mode.py
index 52a72d48..48746c3b 100644
--- a/server/migrations/versions/0fea6d96b096_add_share_mode.py
+++ b/server/migrations/versions/0fea6d96b096_add_share_mode.py
@@ -1,7 +1,7 @@
"""add share_mode
Revision ID: 0fea6d96b096
-Revises: 38a927dcb099
+Revises: f819277e5169
Create Date: 2023-11-07 11:12:21.614198
"""
@@ -12,19 +12,22 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
-revision: str = '0fea6d96b096'
-down_revision: Union[str, None] = '38a927dcb099'
+revision: str = "0fea6d96b096"
+down_revision: Union[str, None] = "f819277e5169"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('transcript', sa.Column('share_mode', sa.String(), server_default='private', nullable=False))
+ op.add_column(
+ "transcript",
+ sa.Column("share_mode", sa.String(), server_default="private", nullable=False),
+ )
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_column('transcript', 'share_mode')
+ op.drop_column("transcript", "share_mode")
# ### end Alembic commands ###
From 3ebb21923bffa319ba9ec0c771de3623ccbeb324 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 29 Nov 2023 20:34:43 +0100
Subject: [PATCH 22/27] server: enhance diarization algorithm
---
.../reflector/processors/audio_diarization.py | 161 +++++++++++++++++-
.../tests/test_processor_audio_diarization.py | 140 +++++++++++++++
2 files changed, 294 insertions(+), 7 deletions(-)
create mode 100644 server/tests/test_processor_audio_diarization.py
diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py
index 82c6a553..69eab5b7 100644
--- a/server/reflector/processors/audio_diarization.py
+++ b/server/reflector/processors/audio_diarization.py
@@ -1,5 +1,5 @@
from reflector.processors.base import Processor
-from reflector.processors.types import AudioDiarizationInput, TitleSummary
+from reflector.processors.types import AudioDiarizationInput, TitleSummary, Word
class AudioDiarizationProcessor(Processor):
@@ -19,12 +19,12 @@ class AudioDiarizationProcessor(Processor):
# topics is a list[BaseModel] with an attribute words
# words is a list[BaseModel] with text, start and speaker attribute
- # mutate in place
- for topic in data.topics:
- for word in topic.transcript.words:
- for d in diarization:
- if d["start"] <= word.start <= d["end"]:
- word.speaker = d["speaker"]
+ # create a view of words based on topics
+ # the current algorithm is using words index, we cannot use a generator
+ words = list(self.iter_words_from_topics(data.topics))
+
+ # assign speaker to words (mutate the words list)
+ self.assign_speaker(words, diarization)
# emit them
for topic in data.topics:
@@ -32,3 +32,150 @@ class AudioDiarizationProcessor(Processor):
async def _diarize(self, data: AudioDiarizationInput):
raise NotImplementedError
+
+ def assign_speaker(self, words: list[Word], diarization: list[dict]):
+ self._diarization_remove_overlap(diarization)
+ self._diarization_remove_segment_without_words(words, diarization)
+ self._diarization_merge_same_speaker(words, diarization)
+ self._diarization_assign_speaker(words, diarization)
+
+ def iter_words_from_topics(self, topics: TitleSummary):
+ for topic in topics:
+ for word in topic.transcript.words:
+ yield word
+
+ def is_word_continuation(self, word_prev, word):
+ """
+ Return True if the word is a continuation of the previous word
+ by checking if the previous word is ending with a punctuation
+ or if the current word is starting with a capital letter
+ """
+ # is word_prev ending with a punctuation ?
+ if word_prev.text and word_prev.text[-1] in ".?!":
+ return False
+ elif word.text and word.text[0].isupper():
+ return False
+ return True
+
+ def _diarization_remove_overlap(self, diarization: list[dict]):
+ """
+ Remove overlap in diarization results
+
+ When using a diarization algorithm, it's possible to have overlapping segments
+ This function remove the overlap by keeping the longest segment
+
+ Warning: this function mutate the diarization list
+ """
+ # remove overlap by keeping the longest segment
+ diarization_idx = 0
+ while diarization_idx < len(diarization) - 1:
+ d = diarization[diarization_idx]
+ dnext = diarization[diarization_idx + 1]
+ if d["end"] > dnext["start"]:
+ # remove the shortest segment
+ if d["end"] - d["start"] > dnext["end"] - dnext["start"]:
+ # remove next segment
+ diarization.pop(diarization_idx + 1)
+ else:
+ # remove current segment
+ diarization.pop(diarization_idx)
+ else:
+ diarization_idx += 1
+
+ def _diarization_remove_segment_without_words(
+ self, words: list[Word], diarization: list[dict]
+ ):
+ """
+ Remove diarization segments without words
+
+ Warning: this function mutate the diarization list
+ """
+ # count the number of words for each diarization segment
+ diarization_count = []
+ for d in diarization:
+ start = d["start"]
+ end = d["end"]
+ count = 0
+ for word in words:
+ if start <= word.start < end:
+ count += 1
+ elif start < word.end <= end:
+ count += 1
+ diarization_count.append(count)
+
+ # remove diarization segments with no words
+ diarization_idx = 0
+ while diarization_idx < len(diarization):
+ if diarization_count[diarization_idx] == 0:
+ diarization.pop(diarization_idx)
+ diarization_count.pop(diarization_idx)
+ else:
+ diarization_idx += 1
+
+ def _diarization_merge_same_speaker(
+ self, words: list[Word], diarization: list[dict]
+ ):
+ """
+ Merge diarization contigous segments with the same speaker
+
+ Warning: this function mutate the diarization list
+ """
+ # merge segment with same speaker
+ diarization_idx = 0
+ while diarization_idx < len(diarization) - 1:
+ d = diarization[diarization_idx]
+ dnext = diarization[diarization_idx + 1]
+ if d["speaker"] == dnext["speaker"]:
+ diarization[diarization_idx]["end"] = dnext["end"]
+ diarization.pop(diarization_idx + 1)
+ else:
+ diarization_idx += 1
+
+ def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
+ """
+ Assign speaker to words based on diarization
+
+ Warning: this function mutate the words list
+ """
+
+ word_idx = 0
+ last_speaker = None
+ for d in diarization:
+ start = d["start"]
+ end = d["end"]
+ speaker = d["speaker"]
+
+ # diarization may start after the first set of words
+ # in this case, we assign the last speaker
+ for word in words[word_idx:]:
+ if word.start < start:
+ # speaker change, but what make sense for assigning the word ?
+ # If it's a new sentence, assign with the new speaker
+ # If it's a continuation, assign with the last speaker
+ is_continuation = False
+ if word_idx > 0 and word_idx < len(words) - 1:
+ is_continuation = self.is_word_continuation(
+ *words[word_idx - 1 : word_idx + 1]
+ )
+ if is_continuation:
+ word.speaker = last_speaker
+ else:
+ word.speaker = speaker
+ last_speaker = speaker
+ word_idx += 1
+ else:
+ break
+
+ # now continue to assign speaker until the word starts after the end
+ for word in words[word_idx:]:
+ if start <= word.start < end:
+ last_speaker = speaker
+ word.speaker = speaker
+ word_idx += 1
+ elif word.start > end:
+ break
+
+ # no more diarization available,
+ # assign last speaker to all words without speaker
+ for word in words[word_idx:]:
+ word.speaker = last_speaker
diff --git a/server/tests/test_processor_audio_diarization.py b/server/tests/test_processor_audio_diarization.py
new file mode 100644
index 00000000..00935a49
--- /dev/null
+++ b/server/tests/test_processor_audio_diarization.py
@@ -0,0 +1,140 @@
+import pytest
+from unittest import mock
+
+
+@pytest.mark.parametrize(
+ "name,diarization,expected",
+ [
+ [
+ "no overlap",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 1.0, "end": 2.0, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "same speaker",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 1.0, "end": 2.0, "speaker": "A"},
+ ],
+ ["A", "A", "A", "A"],
+ ],
+ [
+ # first segment is removed because it overlap
+ # with the second segment, and it is smaller
+ "overlap at 0.5s",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 0.5, "end": 2.0, "speaker": "B"},
+ ],
+ ["B", "B", "B", "B"],
+ ],
+ [
+ "junk segment at 0.5s for 0.2s",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 0.5, "end": 0.7, "speaker": "B"},
+ {"start": 1, "end": 2.0, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "start without diarization",
+ [
+ {"start": 0.5, "end": 1.0, "speaker": "A"},
+ {"start": 1.0, "end": 2.0, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "end missing diarization",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 1.0, "end": 1.5, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "continuation of next speaker",
+ [
+ {"start": 0.0, "end": 0.9, "speaker": "A"},
+ {"start": 1.5, "end": 2.0, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "continuation of previous speaker",
+ [
+ {"start": 0.0, "end": 0.5, "speaker": "A"},
+ {"start": 1.0, "end": 2.0, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "segment without words",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 1.0, "end": 2.0, "speaker": "B"},
+ {"start": 2.0, "end": 3.0, "speaker": "X"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ ],
+)
+@pytest.mark.asyncio
+async def test_processors_audio_diarization(event_loop, name, diarization, expected):
+ from reflector.processors.audio_diarization import AudioDiarizationProcessor
+ from reflector.processors.types import (
+ TitleSummaryWithId,
+ Transcript,
+ Word,
+ AudioDiarizationInput,
+ )
+
+ # create fake topic
+ topics = [
+ TitleSummaryWithId(
+ id="1",
+ title="Title1",
+ summary="Summary1",
+ timestamp=0.0,
+ duration=1.0,
+ transcript=Transcript(
+ words=[
+ Word(text="Word1", start=0.0, end=0.5),
+ Word(text="word2.", start=0.5, end=1.0),
+ ]
+ ),
+ ),
+ TitleSummaryWithId(
+ id="2",
+ title="Title2",
+ summary="Summary2",
+ timestamp=0.0,
+ duration=1.0,
+ transcript=Transcript(
+ words=[
+ Word(text="Word3", start=1.0, end=1.5),
+ Word(text="word4.", start=1.5, end=2.0),
+ ]
+ ),
+ ),
+ ]
+
+ diarizer = AudioDiarizationProcessor()
+ with mock.patch.object(diarizer, "_diarize") as mock_diarize:
+ mock_diarize.return_value = diarization
+
+ data = AudioDiarizationInput(
+ audio_url="https://example.com/audio.mp3",
+ topics=topics,
+ )
+ await diarizer._push(data)
+
+ # check that the speaker has been assigned to the words
+ assert topics[0].transcript.words[0].speaker == expected[0]
+ assert topics[0].transcript.words[1].speaker == expected[1]
+ assert topics[1].transcript.words[0].speaker == expected[2]
+ assert topics[1].transcript.words[1].speaker == expected[3]
From eae01c1495e63278a6c40ef44153b767fddd3ca6 Mon Sep 17 00:00:00 2001
From: projects-g <63178974+projects-g@users.noreply.github.com>
Date: Thu, 30 Nov 2023 22:00:06 +0530
Subject: [PATCH 23/27] Change diarization internal flow (#320)
* change diarization internal flow
---
server/gpu/modal/reflector_diarizer.py | 188 ++++++++++++++++++
.../processors/audio_diarization_modal.py | 2 +-
2 files changed, 189 insertions(+), 1 deletion(-)
create mode 100644 server/gpu/modal/reflector_diarizer.py
diff --git a/server/gpu/modal/reflector_diarizer.py b/server/gpu/modal/reflector_diarizer.py
new file mode 100644
index 00000000..7c316548
--- /dev/null
+++ b/server/gpu/modal/reflector_diarizer.py
@@ -0,0 +1,188 @@
+"""
+Reflector GPU backend - diarizer
+===================================
+"""
+
+import os
+
+import modal.gpu
+from modal import Image, Secret, Stub, asgi_app, method
+from pydantic import BaseModel
+
+PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.0"
+MODEL_DIR = "/root/diarization_models"
+
+stub = Stub(name="reflector-diarizer")
+
+
+def migrate_cache_llm():
+ """
+ XXX The cache for model files in Transformers v4.22.0 has been updated.
+ Migrating your old cache. This is a one-time only operation. You can
+ interrupt this and resume the migration later on by calling
+ `transformers.utils.move_cache()`.
+ """
+ from transformers.utils.hub import move_cache
+
+ print("Moving LLM cache")
+ move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
+ print("LLM cache moved")
+
+
+def download_pyannote_audio():
+ from pyannote.audio import Pipeline
+ Pipeline.from_pretrained(
+ "pyannote/speaker-diarization-3.0",
+ cache_dir=MODEL_DIR,
+ use_auth_token="***REMOVED***"
+ )
+
+
+diarizer_image = (
+ Image.debian_slim(python_version="3.10.8")
+ .pip_install(
+ "pyannote.audio",
+ "requests",
+ "onnx",
+ "torchaudio",
+ "onnxruntime-gpu",
+ "torch==2.0.0",
+ "transformers==4.34.0",
+ "sentencepiece",
+ "protobuf",
+ "numpy",
+ "huggingface_hub",
+ "hf-transfer"
+ )
+ .run_function(migrate_cache_llm)
+ .run_function(download_pyannote_audio)
+ .env(
+ {
+ "LD_LIBRARY_PATH": (
+ "/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
+ "/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
+ )
+ }
+ )
+)
+
+
+@stub.cls(
+ gpu=modal.gpu.A100(memory=40),
+ timeout=60 * 10,
+ container_idle_timeout=60 * 3,
+ allow_concurrent_inputs=6,
+ image=diarizer_image,
+)
+class Diarizer:
+ def __enter__(self):
+ import torch
+ from pyannote.audio import Pipeline
+
+ self.use_gpu = torch.cuda.is_available()
+ self.device = "cuda" if self.use_gpu else "cpu"
+ self.diarization_pipeline = Pipeline.from_pretrained(
+ "pyannote/speaker-diarization-3.0",
+ cache_dir=MODEL_DIR
+ )
+ self.diarization_pipeline.to(torch.device(self.device))
+
+ @method()
+ def diarize(
+ self,
+ audio_data: str,
+ audio_suffix: str,
+ timestamp: float
+ ):
+ import tempfile
+
+ import torchaudio
+
+ with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
+ fp.write(audio_data)
+
+ print("Diarizing audio")
+ waveform, sample_rate = torchaudio.load(fp.name)
+ diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate})
+
+ words = []
+ for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
+ words.append(
+ {
+ "start": round(timestamp + diarization_segment.start, 3),
+ "end": round(timestamp + diarization_segment.end, 3),
+ "speaker": int(speaker[-2:])
+ }
+ )
+ print("Diarization complete")
+ return {
+ "diarization": words
+ }
+
+# -------------------------------------------------------------------
+# Web API
+# -------------------------------------------------------------------
+
+
+@stub.function(
+ timeout=60 * 10,
+ container_idle_timeout=60 * 3,
+ allow_concurrent_inputs=40,
+ secrets=[
+ Secret.from_name("reflector-gpu"),
+ ],
+ image=diarizer_image
+)
+@asgi_app()
+def web():
+ import requests
+ from fastapi import Depends, FastAPI, HTTPException, status
+ from fastapi.security import OAuth2PasswordBearer
+
+ diarizerstub = Diarizer()
+
+ app = FastAPI()
+
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+ def apikey_auth(apikey: str = Depends(oauth2_scheme)):
+ if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid API key",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+ def validate_audio_file(audio_file_url: str):
+ # Check if the audio file exists
+ response = requests.head(audio_file_url, allow_redirects=True)
+ if response.status_code == 404:
+ raise HTTPException(
+ status_code=response.status_code,
+ detail="The audio file does not exist."
+ )
+
+ class DiarizationResponse(BaseModel):
+ result: dict
+
+ @app.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)])
+ def diarize(
+ audio_file_url: str,
+ timestamp: float = 0.0
+ ) -> HTTPException | DiarizationResponse:
+ # Currently the uploaded files are in mp3 format
+ audio_suffix = "mp3"
+
+ print("Downloading audio file")
+ response = requests.get(audio_file_url, allow_redirects=True)
+ print("Audio file downloaded successfully")
+
+ func = diarizerstub.diarize.spawn(
+ audio_data=response.content,
+ audio_suffix=audio_suffix,
+ timestamp=timestamp
+ )
+ result = func.get()
+ return result
+
+ return app
diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py
index 53de2501..511b7f70 100644
--- a/server/reflector/processors/audio_diarization_modal.py
+++ b/server/reflector/processors/audio_diarization_modal.py
@@ -31,7 +31,7 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
follow_redirects=True,
)
response.raise_for_status()
- return response.json()["text"]
+ return response.json()["diarization"]
AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)
From 7ac6d2521737248f81aa0cd31483fbbe8216cedd Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 30 Nov 2023 17:30:08 +0100
Subject: [PATCH 24/27] server: add participant API
Also break out views into different files for easier reading
---
.../versions/125031f7cb78_participants.py | 30 +++
server/reflector/app.py | 10 +
server/reflector/db/transcripts.py | 56 ++++-
server/reflector/views/transcripts.py | 206 +-----------------
server/reflector/views/transcripts_audio.py | 109 +++++++++
.../views/transcripts_participants.py | 142 ++++++++++++
server/reflector/views/transcripts_webrtc.py | 37 ++++
.../reflector/views/transcripts_websocket.py | 53 +++++
server/reflector/views/types.py | 5 +
server/tests/test_transcripts_participants.py | 164 ++++++++++++++
10 files changed, 610 insertions(+), 202 deletions(-)
create mode 100644 server/migrations/versions/125031f7cb78_participants.py
create mode 100644 server/reflector/views/transcripts_audio.py
create mode 100644 server/reflector/views/transcripts_participants.py
create mode 100644 server/reflector/views/transcripts_webrtc.py
create mode 100644 server/reflector/views/transcripts_websocket.py
create mode 100644 server/reflector/views/types.py
create mode 100644 server/tests/test_transcripts_participants.py
diff --git a/server/migrations/versions/125031f7cb78_participants.py b/server/migrations/versions/125031f7cb78_participants.py
new file mode 100644
index 00000000..c345b083
--- /dev/null
+++ b/server/migrations/versions/125031f7cb78_participants.py
@@ -0,0 +1,30 @@
+"""participants
+
+Revision ID: 125031f7cb78
+Revises: 0fea6d96b096
+Create Date: 2023-11-30 15:56:03.341466
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '125031f7cb78'
+down_revision: Union[str, None] = '0fea6d96b096'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('transcript', sa.Column('participants', sa.JSON(), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('transcript', 'participants')
+ # ### end Alembic commands ###
diff --git a/server/reflector/app.py b/server/reflector/app.py
index 5bfffeca..8f45efd5 100644
--- a/server/reflector/app.py
+++ b/server/reflector/app.py
@@ -13,6 +13,12 @@ from reflector.metrics import metrics_init
from reflector.settings import settings
from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router
+from reflector.views.transcripts_audio import router as transcripts_audio_router
+from reflector.views.transcripts_participants import (
+ router as transcripts_participants_router,
+)
+from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
+from reflector.views.transcripts_websocket import router as transcripts_websocket_router
from reflector.views.user import router as user_router
try:
@@ -60,6 +66,10 @@ metrics_init(app, instrumentator)
# register views
app.include_router(rtc_offer_router)
app.include_router(transcripts_router, prefix="/v1")
+app.include_router(transcripts_audio_router, prefix="/v1")
+app.include_router(transcripts_participants_router, prefix="/v1")
+app.include_router(transcripts_websocket_router, prefix="/v1")
+app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")
add_pagination(app)
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 0fba82ef..44688eaa 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -7,7 +7,7 @@ from uuid import uuid4
import sqlalchemy
from fastapi import HTTPException
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
@@ -27,6 +27,7 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
+ sqlalchemy.Column("participants", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column(
@@ -112,6 +113,13 @@ class TranscriptEvent(BaseModel):
data: dict
+class TranscriptParticipant(BaseModel):
+ model_config = ConfigDict(from_attributes=True)
+ id: str = Field(default_factory=generate_uuid4)
+ speaker: int | None
+ name: str
+
+
class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
@@ -125,6 +133,7 @@ class Transcript(BaseModel):
long_summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
+ participants: list[TranscriptParticipant] = []
source_language: str = "en"
target_language: str = "en"
share_mode: Literal["private", "semi-private", "public"] = "private"
@@ -142,12 +151,34 @@ class Transcript(BaseModel):
else:
self.topics.append(topic)
+ def upsert_participant(self, participant: TranscriptParticipant):
+ index = next(
+ (i for i, p in enumerate(self.participants) if p.id == participant.id),
+ None,
+ )
+ if index is not None:
+ self.participants[index] = participant
+ else:
+ self.participants.append(participant)
+ return participant
+
+ def delete_participant(self, participant_id: str):
+ index = next(
+ (i for i, p in enumerate(self.participants) if p.id == participant_id),
+ None,
+ )
+ if index is not None:
+ del self.participants[index]
+
def events_dump(self, mode="json"):
return [event.model_dump(mode=mode) for event in self.events]
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
+ def participants_dump(self, mode="json"):
+ return [participant.model_dump(mode=mode) for participant in self.participants]
+
def unlink(self):
self.data_path.unlink(missing_ok=True)
@@ -410,5 +441,28 @@ class TranscriptController:
# unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True)
+ async def upsert_participant(
+ self,
+ transcript: Transcript,
+ participant: TranscriptParticipant,
+ ) -> TranscriptParticipant:
+ """
+ Add/update a participant to a transcript
+ """
+ result = transcript.upsert_participant(participant)
+ await self.update(transcript, {"participants": transcript.participants_dump()})
+ return result
+
+ async def delete_participant(
+ self,
+ transcript: Transcript,
+ participant_id: str,
+ ):
+ """
+ Delete a participant from a transcript
+ """
+ transcript.delete_participant(participant_id)
+ await self.update(transcript, {"participants": transcript.participants_dump()})
+
transcripts_controller = TranscriptController()
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 44b55629..9e62192b 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -1,33 +1,19 @@
from datetime import datetime, timedelta
from typing import Annotated, Literal, Optional
-import httpx
import reflector.auth as auth
-from fastapi import (
- APIRouter,
- Depends,
- HTTPException,
- Request,
- Response,
- WebSocket,
- WebSocketDisconnect,
- status,
-)
+from fastapi import APIRouter, Depends, HTTPException
from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate
from jose import jwt
from pydantic import BaseModel, Field
from reflector.db.transcripts import (
- AudioWaveform,
+ TranscriptParticipant,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings
-from reflector.ws_manager import get_ws_manager
-
-from ._range_requests_response import range_requests_response
-from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter()
@@ -62,6 +48,7 @@ class GetTranscript(BaseModel):
share_mode: str = Field("private")
source_language: str | None
target_language: str | None
+ participants: list[TranscriptParticipant] | None
class CreateTranscript(BaseModel):
@@ -77,6 +64,7 @@ class UpdateTranscript(BaseModel):
short_summary: Optional[str] = Field(None)
long_summary: Optional[str] = Field(None)
share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None)
+ participants: Optional[list[TranscriptParticipant]] = Field(None)
class DeletionStatus(BaseModel):
@@ -192,19 +180,7 @@ async def transcript_update(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
- values = {}
- if info.name is not None:
- values["name"] = info.name
- if info.locked is not None:
- values["locked"] = info.locked
- if info.long_summary is not None:
- values["long_summary"] = info.long_summary
- if info.short_summary is not None:
- values["short_summary"] = info.short_summary
- if info.title is not None:
- values["title"] = info.title
- if info.share_mode is not None:
- values["share_mode"] = info.share_mode
+ values = info.dict(exclude_unset=True)
await transcripts_controller.update(transcript, values)
return transcript
@@ -222,97 +198,6 @@ async def transcript_delete(
return DeletionStatus(status="ok")
-@router.get("/transcripts/{transcript_id}/audio/mp3")
-@router.head("/transcripts/{transcript_id}/audio/mp3")
-async def transcript_get_audio_mp3(
- request: Request,
- transcript_id: str,
- user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
- token: str | None = None,
-):
- user_id = user["sub"] if user else None
- if not user_id and token:
- unauthorized_exception = HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Invalid or expired token",
- headers={"WWW-Authenticate": "Bearer"},
- )
- try:
- payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
- user_id: str = payload.get("sub")
- except jwt.JWTError:
- raise unauthorized_exception
-
- transcript = await transcripts_controller.get_by_id_for_http(
- transcript_id, user_id=user_id
- )
-
- if transcript.audio_location == "storage":
- # proxy S3 file, to prevent issue with CORS
- url = await transcript.get_audio_url()
- headers = {}
-
- copy_headers = ["range", "accept-encoding"]
- for header in copy_headers:
- if header in request.headers:
- headers[header] = request.headers[header]
-
- async with httpx.AsyncClient() as client:
- resp = await client.request(request.method, url, headers=headers)
- return Response(
- content=resp.content,
- status_code=resp.status_code,
- headers=resp.headers,
- )
-
- if transcript.audio_location == "storage":
- # proxy S3 file, to prevent issue with CORS
- url = await transcript.get_audio_url()
- headers = {}
-
- copy_headers = ["range", "accept-encoding"]
- for header in copy_headers:
- if header in request.headers:
- headers[header] = request.headers[header]
-
- async with httpx.AsyncClient() as client:
- resp = await client.request(request.method, url, headers=headers)
- return Response(
- content=resp.content,
- status_code=resp.status_code,
- headers=resp.headers,
- )
-
- if not transcript.audio_mp3_filename.exists():
- raise HTTPException(status_code=500, detail="Audio not found")
-
- truncated_id = str(transcript.id).split("-")[0]
- filename = f"recording_{truncated_id}.mp3"
-
- return range_requests_response(
- request,
- transcript.audio_mp3_filename,
- content_type="audio/mpeg",
- content_disposition=f"attachment; filename={filename}",
- )
-
-
-@router.get("/transcripts/{transcript_id}/audio/waveform")
-async def transcript_get_audio_waveform(
- transcript_id: str,
- user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-) -> AudioWaveform:
- user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id_for_http(
- transcript_id, user_id=user_id
- )
-
- if not transcript.audio_waveform_filename.exists():
- raise HTTPException(status_code=404, detail="Audio not found")
-
- return transcript.audio_waveform
-
-
@router.get(
"/transcripts/{transcript_id}/topics",
response_model=list[GetTranscriptTopic],
@@ -330,84 +215,3 @@ async def transcript_get_topics(
return [
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
]
-
-
-# ==============================================================
-# Websocket
-# ==============================================================
-
-
-@router.get("/transcripts/{transcript_id}/events")
-async def transcript_get_websocket_events(transcript_id: str):
- pass
-
-
-@router.websocket("/transcripts/{transcript_id}/events")
-async def transcript_events_websocket(
- transcript_id: str,
- websocket: WebSocket,
- # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-):
- # user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id(transcript_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
-
- # connect to websocket manager
- # use ts:transcript_id as room id
- room_id = f"ts:{transcript_id}"
- ws_manager = get_ws_manager()
- await ws_manager.add_user_to_room(room_id, websocket)
-
- try:
- # on first connection, send all events only to the current user
- for event in transcript.events:
- # for now, do not send TRANSCRIPT or STATUS options - theses are live event
- # not necessary to be sent to the client; but keep the rest
- name = event.event
- if name in ("TRANSCRIPT", "STATUS"):
- continue
- await websocket.send_json(event.model_dump(mode="json"))
-
- # XXX if transcript is final (locked=True and status=ended)
- # XXX send a final event to the client and close the connection
-
- # endless loop to wait for new events
- # we do not have command system now,
- while True:
- await websocket.receive()
- except (RuntimeError, WebSocketDisconnect):
- await ws_manager.remove_user_from_room(room_id, websocket)
-
-
-# ==============================================================
-# Web RTC
-# ==============================================================
-
-
-@router.post("/transcripts/{transcript_id}/record/webrtc")
-async def transcript_record_webrtc(
- transcript_id: str,
- params: RtcOffer,
- request: Request,
- user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-):
- user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id_for_http(
- transcript_id, user_id=user_id
- )
-
- if transcript.locked:
- raise HTTPException(status_code=400, detail="Transcript is locked")
-
- # create a pipeline runner
- from reflector.pipelines.main_live_pipeline import PipelineMainLive
-
- pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
-
- # FIXME do not allow multiple recording at the same time
- return await rtc_offer_base(
- params,
- request,
- pipeline_runner=pipeline_runner,
- )
diff --git a/server/reflector/views/transcripts_audio.py b/server/reflector/views/transcripts_audio.py
new file mode 100644
index 00000000..a174d992
--- /dev/null
+++ b/server/reflector/views/transcripts_audio.py
@@ -0,0 +1,109 @@
+"""
+Transcripts audio related endpoints
+===================================
+
+"""
+from typing import Annotated, Optional
+
+import httpx
+import reflector.auth as auth
+from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from jose import jwt
+from reflector.db.transcripts import AudioWaveform, transcripts_controller
+from reflector.settings import settings
+from reflector.views.transcripts import ALGORITHM
+
+from ._range_requests_response import range_requests_response
+
+router = APIRouter()
+
+
+@router.get("/transcripts/{transcript_id}/audio/mp3")
+@router.head("/transcripts/{transcript_id}/audio/mp3")
+async def transcript_get_audio_mp3(
+ request: Request,
+ transcript_id: str,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+ token: str | None = None,
+):
+ user_id = user["sub"] if user else None
+ if not user_id and token:
+ unauthorized_exception = HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or expired token",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ try:
+ payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
+ user_id: str = payload.get("sub")
+ except jwt.JWTError:
+ raise unauthorized_exception
+
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ if transcript.audio_location == "storage":
+ # proxy S3 file, to prevent issue with CORS
+ url = await transcript.get_audio_url()
+ headers = {}
+
+ copy_headers = ["range", "accept-encoding"]
+ for header in copy_headers:
+ if header in request.headers:
+ headers[header] = request.headers[header]
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.request(request.method, url, headers=headers)
+ return Response(
+ content=resp.content,
+ status_code=resp.status_code,
+ headers=resp.headers,
+ )
+
+ if transcript.audio_location == "storage":
+ # proxy S3 file, to prevent issue with CORS
+ url = await transcript.get_audio_url()
+ headers = {}
+
+ copy_headers = ["range", "accept-encoding"]
+ for header in copy_headers:
+ if header in request.headers:
+ headers[header] = request.headers[header]
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.request(request.method, url, headers=headers)
+ return Response(
+ content=resp.content,
+ status_code=resp.status_code,
+ headers=resp.headers,
+ )
+
+ if not transcript.audio_mp3_filename.exists():
+ raise HTTPException(status_code=500, detail="Audio not found")
+
+ truncated_id = str(transcript.id).split("-")[0]
+ filename = f"recording_{truncated_id}.mp3"
+
+ return range_requests_response(
+ request,
+ transcript.audio_mp3_filename,
+ content_type="audio/mpeg",
+ content_disposition=f"attachment; filename={filename}",
+ )
+
+
+@router.get("/transcripts/{transcript_id}/audio/waveform")
+async def transcript_get_audio_waveform(
+ transcript_id: str,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> AudioWaveform:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ if not transcript.audio_waveform_filename.exists():
+ raise HTTPException(status_code=404, detail="Audio not found")
+
+ return transcript.audio_waveform
diff --git a/server/reflector/views/transcripts_participants.py b/server/reflector/views/transcripts_participants.py
new file mode 100644
index 00000000..318d6018
--- /dev/null
+++ b/server/reflector/views/transcripts_participants.py
@@ -0,0 +1,142 @@
+"""
+Transcript participants API endpoints
+=====================================
+
+"""
+from typing import Annotated, Optional
+
+import reflector.auth as auth
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel, ConfigDict, Field
+from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
+from reflector.views.types import DeletionStatus
+
+router = APIRouter()
+
+
+class Participant(BaseModel):
+ model_config = ConfigDict(from_attributes=True)
+ id: str
+ speaker: int | None
+ name: str
+
+
+class CreateParticipant(BaseModel):
+ speaker: Optional[int] = Field(None)
+ name: str
+
+
+class UpdateParticipant(BaseModel):
+ speaker: Optional[int] = Field(None)
+ name: Optional[str] = Field(None)
+
+
+@router.get("/transcripts/{transcript_id}/participants")
+async def transcript_get_participants(
+ transcript_id: str,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> list[Participant]:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ return [
+ Participant.model_validate(participant)
+ for participant in transcript.participants
+ ]
+
+
+@router.post("/transcripts/{transcript_id}/participants")
+async def transcript_add_participant(
+ transcript_id: str,
+ participant: CreateParticipant,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> Participant:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ # ensure the speaker is unique
+ for p in transcript.participants:
+ if p.speaker == participant.speaker:
+ raise HTTPException(
+ status_code=400,
+ detail="Speaker already assigned",
+ )
+
+ obj = await transcripts_controller.upsert_participant(
+ transcript, TranscriptParticipant(**participant.dict())
+ )
+ return Participant.model_validate(obj)
+
+
+@router.get("/transcripts/{transcript_id}/participants/{participant_id}")
+async def transcript_get_participant(
+ transcript_id: str,
+ participant_id: str,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> Participant:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ for p in transcript.participants:
+ if p.id == participant_id:
+ return Participant.model_validate(p)
+
+ raise HTTPException(status_code=404, detail="Participant not found")
+
+
+@router.patch("/transcripts/{transcript_id}/participants/{participant_id}")
+async def transcript_update_participant(
+ transcript_id: str,
+ participant_id: str,
+ participant: UpdateParticipant,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> Participant:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ # ensure the speaker is unique
+ for p in transcript.participants:
+ if p.speaker == participant.speaker and p.id != participant_id:
+ raise HTTPException(
+ status_code=400,
+ detail="Speaker already assigned",
+ )
+
+ # find the participant
+ obj = None
+ for p in transcript.participants:
+ if p.id == participant_id:
+ obj = p
+ break
+
+ if not obj:
+ raise HTTPException(status_code=404, detail="Participant not found")
+
+ # update participant but just the fields that are set
+ fields = participant.dict(exclude_unset=True)
+ obj = obj.copy(update=fields)
+
+ await transcripts_controller.upsert_participant(transcript, obj)
+ return Participant.model_validate(obj)
+
+
+@router.delete("/transcripts/{transcript_id}/participants/{participant_id}")
+async def transcript_delete_participant(
+ transcript_id: str,
+ participant_id: str,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> DeletionStatus:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+ await transcripts_controller.delete_participant(transcript, participant_id)
+ return DeletionStatus(status="ok")
diff --git a/server/reflector/views/transcripts_webrtc.py b/server/reflector/views/transcripts_webrtc.py
new file mode 100644
index 00000000..af451411
--- /dev/null
+++ b/server/reflector/views/transcripts_webrtc.py
@@ -0,0 +1,37 @@
+from typing import Annotated, Optional
+
+import reflector.auth as auth
+from fastapi import APIRouter, Depends, HTTPException, Request
+from reflector.db.transcripts import transcripts_controller
+
+from .rtc_offer import RtcOffer, rtc_offer_base
+
+router = APIRouter()
+
+
+@router.post("/transcripts/{transcript_id}/record/webrtc")
+async def transcript_record_webrtc(
+ transcript_id: str,
+ params: RtcOffer,
+ request: Request,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ if transcript.locked:
+ raise HTTPException(status_code=400, detail="Transcript is locked")
+
+ # create a pipeline runner
+ from reflector.pipelines.main_live_pipeline import PipelineMainLive
+
+ pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
+
+ # FIXME do not allow multiple recording at the same time
+ return await rtc_offer_base(
+ params,
+ request,
+ pipeline_runner=pipeline_runner,
+ )
diff --git a/server/reflector/views/transcripts_websocket.py b/server/reflector/views/transcripts_websocket.py
new file mode 100644
index 00000000..65571aab
--- /dev/null
+++ b/server/reflector/views/transcripts_websocket.py
@@ -0,0 +1,53 @@
+"""
+Transcripts websocket API
+=========================
+
+"""
+from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
+from reflector.db.transcripts import transcripts_controller
+from reflector.ws_manager import get_ws_manager
+
+router = APIRouter()
+
+
+@router.get("/transcripts/{transcript_id}/events")
+async def transcript_get_websocket_events(transcript_id: str):
+ pass
+
+
+@router.websocket("/transcripts/{transcript_id}/events")
+async def transcript_events_websocket(
+ transcript_id: str,
+ websocket: WebSocket,
+ # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+ # user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id(transcript_id)
+ if not transcript:
+ raise HTTPException(status_code=404, detail="Transcript not found")
+
+ # connect to websocket manager
+ # use ts:transcript_id as room id
+ room_id = f"ts:{transcript_id}"
+ ws_manager = get_ws_manager()
+ await ws_manager.add_user_to_room(room_id, websocket)
+
+ try:
+ # on first connection, send all events only to the current user
+ for event in transcript.events:
+ # for now, do not send TRANSCRIPT or STATUS options - theses are live event
+ # not necessary to be sent to the client; but keep the rest
+ name = event.event
+ if name in ("TRANSCRIPT", "STATUS"):
+ continue
+ await websocket.send_json(event.model_dump(mode="json"))
+
+ # XXX if transcript is final (locked=True and status=ended)
+ # XXX send a final event to the client and close the connection
+
+ # endless loop to wait for new events
+ # we do not have command system now,
+ while True:
+ await websocket.receive()
+ except (RuntimeError, WebSocketDisconnect):
+ await ws_manager.remove_user_from_room(room_id, websocket)
diff --git a/server/reflector/views/types.py b/server/reflector/views/types.py
new file mode 100644
index 00000000..70361131
--- /dev/null
+++ b/server/reflector/views/types.py
@@ -0,0 +1,5 @@
+from pydantic import BaseModel
+
+
+class DeletionStatus(BaseModel):
+ status: str
diff --git a/server/tests/test_transcripts_participants.py b/server/tests/test_transcripts_participants.py
new file mode 100644
index 00000000..b55b16a8
--- /dev/null
+++ b/server/tests/test_transcripts_participants.py
@@ -0,0 +1,164 @@
+import pytest
+from httpx import AsyncClient
+
+
+@pytest.mark.asyncio
+async def test_transcript_participants():
+ from reflector.app import app
+
+ async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+ response = await ac.post("/transcripts", json={"name": "test"})
+ assert response.status_code == 200
+ assert response.json()["participants"] == []
+
+ # create a participant
+ transcript_id = response.json()["id"]
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants", json={"name": "test"}
+ )
+ assert response.status_code == 200
+ assert response.json()["id"] is not None
+ assert response.json()["speaker"] is None
+ assert response.json()["name"] == "test"
+
+ # create another one with a speaker
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test2", "speaker": 1},
+ )
+ assert response.status_code == 200
+ assert response.json()["id"] is not None
+ assert response.json()["speaker"] == 1
+ assert response.json()["name"] == "test2"
+
+ # get all participants via transcript
+ response = await ac.get(f"/transcripts/{transcript_id}")
+ assert response.status_code == 200
+ assert len(response.json()["participants"]) == 2
+
+ # get participants via participants endpoint
+ response = await ac.get(f"/transcripts/{transcript_id}/participants")
+ assert response.status_code == 200
+ assert len(response.json()) == 2
+
+
+@pytest.mark.asyncio
+async def test_transcript_participants_same_speaker():
+ from reflector.app import app
+
+ async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+ response = await ac.post("/transcripts", json={"name": "test"})
+ assert response.status_code == 200
+ assert response.json()["participants"] == []
+ transcript_id = response.json()["id"]
+
+ # create a participant
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test", "speaker": 1},
+ )
+ assert response.status_code == 200
+ assert response.json()["speaker"] == 1
+
+ # create another one with the same speaker
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test2", "speaker": 1},
+ )
+ assert response.status_code == 400
+
+
+@pytest.mark.asyncio
+async def test_transcript_participants_update_name():
+ from reflector.app import app
+
+ async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+ response = await ac.post("/transcripts", json={"name": "test"})
+ assert response.status_code == 200
+ assert response.json()["participants"] == []
+ transcript_id = response.json()["id"]
+
+ # create a participant
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test", "speaker": 1},
+ )
+ assert response.status_code == 200
+ assert response.json()["speaker"] == 1
+
+ # update the participant
+ participant_id = response.json()["id"]
+ response = await ac.patch(
+ f"/transcripts/{transcript_id}/participants/{participant_id}",
+ json={"name": "test2"},
+ )
+ assert response.status_code == 200
+ assert response.json()["name"] == "test2"
+
+ # verify the participant was updated
+ response = await ac.get(
+ f"/transcripts/{transcript_id}/participants/{participant_id}"
+ )
+ assert response.status_code == 200
+ assert response.json()["name"] == "test2"
+
+ # verify the participant was updated in transcript
+ response = await ac.get(f"/transcripts/{transcript_id}")
+ assert response.status_code == 200
+ assert len(response.json()["participants"]) == 1
+ assert response.json()["participants"][0]["name"] == "test2"
+
+
+@pytest.mark.asyncio
+async def test_transcript_participants_update_speaker():
+ from reflector.app import app
+
+ async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+ response = await ac.post("/transcripts", json={"name": "test"})
+ assert response.status_code == 200
+ assert response.json()["participants"] == []
+ transcript_id = response.json()["id"]
+
+ # create a participant
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test", "speaker": 1},
+ )
+ assert response.status_code == 200
+ participant1_id = response.json()["id"]
+
+ # create another participant
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test2", "speaker": 2},
+ )
+ assert response.status_code == 200
+ participant2_id = response.json()["id"]
+
+ # update the participant, refused as speaker is already taken
+ response = await ac.patch(
+ f"/transcripts/{transcript_id}/participants/{participant2_id}",
+ json={"speaker": 1},
+ )
+ assert response.status_code == 400
+
+ # delete the participant 1
+ response = await ac.delete(
+ f"/transcripts/{transcript_id}/participants/{participant1_id}"
+ )
+ assert response.status_code == 200
+
+ # update the participant 2 again, should be accepted now
+ response = await ac.patch(
+ f"/transcripts/{transcript_id}/participants/{participant2_id}",
+ json={"speaker": 1},
+ )
+ assert response.status_code == 200
+
+ # ensure participant2 name is still there
+ response = await ac.get(
+ f"/transcripts/{transcript_id}/participants/{participant2_id}"
+ )
+ assert response.status_code == 200
+ assert response.json()["name"] == "test2"
+ assert response.json()["speaker"] == 1
From 8b1b71940f18ddf7a669845acd55b444dd5d005d Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 30 Nov 2023 19:25:09 +0100
Subject: [PATCH 25/27] hotfix/server: update diarization settings to increase
timeout, reduce idle timeout on the minimum
---
server/gpu/modal/reflector_diarizer.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/server/gpu/modal/reflector_diarizer.py b/server/gpu/modal/reflector_diarizer.py
index 7c316548..b1989a11 100644
--- a/server/gpu/modal/reflector_diarizer.py
+++ b/server/gpu/modal/reflector_diarizer.py
@@ -69,9 +69,9 @@ diarizer_image = (
@stub.cls(
gpu=modal.gpu.A100(memory=40),
- timeout=60 * 10,
- container_idle_timeout=60 * 3,
- allow_concurrent_inputs=6,
+ timeout=60 * 30,
+ container_idle_timeout=60,
+ allow_concurrent_inputs=1,
image=diarizer_image,
)
class Diarizer:
From f9771427e2f7c363ae775006d61e7cb3f506faf4 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 30 Nov 2023 19:43:19 +0100
Subject: [PATCH 26/27] server: add healthcheck for worker
---
README.md | 14 +++++++++++---
server/reflector/settings.py | 3 +++
server/reflector/worker/app.py | 15 +++++++++++++++
server/reflector/worker/healthcheck.py | 18 ++++++++++++++++++
server/runserver.sh | 2 ++
5 files changed, 49 insertions(+), 3 deletions(-)
create mode 100644 server/reflector/worker/healthcheck.py
diff --git a/README.md b/README.md
index cb75c76b..8830c79e 100644
--- a/README.md
+++ b/README.md
@@ -133,17 +133,25 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
```
-### Start the project
+### Start the API/Backend
-Use:
+Start the API server:
```bash
poetry run python3 -m reflector.app
```
-And start the background worker
+Start the background worker:
+```bash
celery -A reflector.worker.app worker --loglevel=info
+```
+
+For crontab (only healthcheck for now), start the celery beat:
+
+```bash
+celery -A reflector.worker.app beat
+```
#### Using docker
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 2c68c4e5..d0ddc91a 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -128,5 +128,8 @@ class Settings(BaseSettings):
# Profiling
PROFILING: bool = False
+ # Healthcheck
+ HEALTHCHECK_URL: str | None = None
+
settings = Settings()
diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py
index e1000364..689623ce 100644
--- a/server/reflector/worker/app.py
+++ b/server/reflector/worker/app.py
@@ -1,6 +1,8 @@
+import structlog
from celery import Celery
from reflector.settings import settings
+logger = structlog.get_logger(__name__)
app = Celery(__name__)
app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
@@ -8,5 +10,18 @@ app.conf.broker_connection_retry_on_startup = True
app.autodiscover_tasks(
[
"reflector.pipelines.main_live_pipeline",
+ "reflector.worker.healthcheck",
]
)
+
+# crontab
+app.conf.beat_schedule = {}
+
+if settings.HEALTHCHECK_URL:
+ app.conf.beat_schedule["healthcheck_ping"] = {
+ "task": "reflector.worker.healthcheck.healthcheck_ping",
+ "schedule": 60.0 * 10,
+ }
+ logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
+else:
+ logger.warning("Healthcheck disabled, no url configured")
diff --git a/server/reflector/worker/healthcheck.py b/server/reflector/worker/healthcheck.py
new file mode 100644
index 00000000..e4ce6bc3
--- /dev/null
+++ b/server/reflector/worker/healthcheck.py
@@ -0,0 +1,18 @@
+import httpx
+import structlog
+from celery import shared_task
+from reflector.settings import settings
+
+logger = structlog.get_logger(__name__)
+
+
+@shared_task
+def healthcheck_ping():
+ url = settings.HEALTHCHECK_URL
+ if not url:
+ return
+ try:
+ print("pinging healthcheck url", url)
+ httpx.get(url, timeout=10)
+ except Exception as e:
+ logger.error("healthcheck_ping", error=str(e))
diff --git a/server/runserver.sh b/server/runserver.sh
index b0c3f138..31cce123 100755
--- a/server/runserver.sh
+++ b/server/runserver.sh
@@ -9,6 +9,8 @@ if [ "${ENTRYPOINT}" = "server" ]; then
python -m reflector.app
elif [ "${ENTRYPOINT}" = "worker" ]; then
celery -A reflector.worker.app worker --loglevel=info
+elif [ "${ENTRYPOINT}" = "beat" ]; then
+ celery -A reflector.worker.app beat --loglevel=info
else
echo "Unknown command"
fi
From 84a1350df7eca523743e5e602ad73cecab058ccf Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 1 Dec 2023 18:18:09 +0100
Subject: [PATCH 27/27] hotfix/server: fix participants loading on old meetings
---
server/reflector/db/transcripts.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 44688eaa..970393d5 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -133,7 +133,7 @@ class Transcript(BaseModel):
long_summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
- participants: list[TranscriptParticipant] = []
+ participants: list[TranscriptParticipant] | None = []
source_language: str = "en"
target_language: str = "en"
share_mode: Literal["private", "semi-private", "public"] = "private"