From b3232543763364dfd93fbae2c56ea89e385f91a6 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 19 Oct 2023 21:05:13 +0200
Subject: [PATCH 01/41] server: move out profanity filter to transcript, and
implement segmentation
---
.../reflector/processors/audio_transcript.py | 10 ---
.../processors/audio_transcript_modal.py | 3 -
.../processors/audio_transcript_whisper.py | 1 -
.../reflector/processors/transcript_liner.py | 1 -
server/reflector/processors/types.py | 65 ++++++++++++++++++-
server/reflector/views/transcripts.py | 17 ++++-
6 files changed, 78 insertions(+), 19 deletions(-)
diff --git a/server/reflector/processors/audio_transcript.py b/server/reflector/processors/audio_transcript.py
index f029b587..3f9dc85b 100644
--- a/server/reflector/processors/audio_transcript.py
+++ b/server/reflector/processors/audio_transcript.py
@@ -1,6 +1,4 @@
-from profanityfilter import ProfanityFilter
from prometheus_client import Counter, Histogram
-
from reflector.processors.base import Processor
from reflector.processors.types import AudioFile, Transcript
@@ -40,8 +38,6 @@ class AudioTranscriptProcessor(Processor):
self.m_transcript_call = self.m_transcript_call.labels(name)
self.m_transcript_success = self.m_transcript_success.labels(name)
self.m_transcript_failure = self.m_transcript_failure.labels(name)
- self.profanity_filter = ProfanityFilter()
- self.profanity_filter.set_censor("*")
super().__init__(*args, **kwargs)
async def _push(self, data: AudioFile):
@@ -60,9 +56,3 @@ class AudioTranscriptProcessor(Processor):
async def _transcript(self, data: AudioFile):
raise NotImplementedError
-
- def filter_profanity(self, text: str) -> str:
- """
- Remove censored words from the transcript
- """
- return self.profanity_filter.censor(text)
diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py
index 201ed9d4..23c9d74e 100644
--- a/server/reflector/processors/audio_transcript_modal.py
+++ b/server/reflector/processors/audio_transcript_modal.py
@@ -48,10 +48,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
)
response.raise_for_status()
result = response.json()
- text = result["text"][source_language]
- text = self.filter_profanity(text)
transcript = Transcript(
- text=text,
words=[
Word(
text=word["text"],
diff --git a/server/reflector/processors/audio_transcript_whisper.py b/server/reflector/processors/audio_transcript_whisper.py
index e3bd595b..cd96e01a 100644
--- a/server/reflector/processors/audio_transcript_whisper.py
+++ b/server/reflector/processors/audio_transcript_whisper.py
@@ -30,7 +30,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
ts = data.timestamp
for segment in segments:
- transcript.text += segment.text
for word in segment.words:
transcript.words.append(
Word(
diff --git a/server/reflector/processors/transcript_liner.py b/server/reflector/processors/transcript_liner.py
index c1aa14a0..b4e7b5e3 100644
--- a/server/reflector/processors/transcript_liner.py
+++ b/server/reflector/processors/transcript_liner.py
@@ -36,7 +36,6 @@ class TranscriptLinerProcessor(Processor):
# cut to the next .
partial = Transcript(words=[])
for word in self.transcript.words[:]:
- partial.text += word.text
partial.words.append(word)
if not self.is_sentence_terminated(word.text):
continue
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index e867becf..686c5785 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -2,8 +2,12 @@ import io
import tempfile
from pathlib import Path
+from profanityfilter import ProfanityFilter
from pydantic import BaseModel, PrivateAttr
+profanity_filter = ProfanityFilter()
+profanity_filter.set_censor("*")
+
class AudioFile(BaseModel):
name: str
@@ -43,13 +47,29 @@ class Word(BaseModel):
text: str
start: float
end: float
+ speaker: int = 0
+
+
+class TranscriptSegment(BaseModel):
+ text: str
+ start: float
+ speaker: int = 0
class Transcript(BaseModel):
- text: str = ""
translation: str | None = None
words: list[Word] = None
+ @property
+ def raw_text(self):
+ # Uncensored text
+ return "".join([word.text for word in self.words])
+
+ @property
+ def text(self):
+ # Censored text
+ return profanity_filter.censor(self.raw_text).strip()
+
@property
def human_timestamp(self):
minutes = int(self.timestamp / 60)
@@ -74,7 +94,6 @@ class Transcript(BaseModel):
self.words = other.words
else:
self.words.extend(other.words)
- self.text += other.text
def add_offset(self, offset: float):
for word in self.words:
@@ -87,6 +106,48 @@ class Transcript(BaseModel):
]
return Transcript(text=self.text, translation=self.translation, words=words)
+ def as_segments(self):
+ # from a list of word, create a list of segments
+ # join the word that are less than 2 seconds apart
+ # but separate if the speaker changes, or if the punctuation is a . , ; : ? !
+ segments = []
+ current_segment = None
+ last_word = None
+ BLANK_TIME_SECS = 2
+ MAX_SEGMENT_LENGTH = 80
+ for word in self.words:
+ if current_segment is None:
+ current_segment = TranscriptSegment(
+ text=word.text,
+ start=word.start,
+ speaker=word.speaker,
+ )
+ continue
+ is_blank = False
+ if last_word:
+ is_blank = word.start - last_word.end > BLANK_TIME_SECS
+ if (
+ word.speaker != current_segment.speaker
+ or (
+ word.text in ".;:?!…"
+ and len(current_segment.text) > MAX_SEGMENT_LENGTH
+ )
+ or is_blank
+ ):
+ # check which condition triggered
+ segments.append(current_segment)
+ current_segment = TranscriptSegment(
+ text=word.text,
+ start=word.start,
+ speaker=word.speaker,
+ )
+ else:
+ current_segment.text += word.text
+ last_word = word
+ if current_segment:
+ segments.append(current_segment)
+ return segments
+
class TitleSummary(BaseModel):
title: str
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index a7e01b8c..0a068c17 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -49,12 +49,18 @@ class TranscriptText(BaseModel):
translation: str | None
+class TranscriptSegmentTopic(BaseModel):
+ speaker: int
+ text: str
+ timestamp: float
+
+
class TranscriptTopic(BaseModel):
id: str = Field(default_factory=generate_uuid4)
title: str
summary: str
- transcript: str | None = None
timestamp: float
+ segments: list[TranscriptSegmentTopic] = []
class TranscriptFinalShortSummary(BaseModel):
@@ -523,8 +529,15 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
topic = TranscriptTopic(
title=data.title,
summary=data.summary,
- transcript=data.transcript.text,
timestamp=data.timestamp,
+ segments=[
+ TranscriptSegmentTopic(
+ speaker=segment.speaker,
+ text=segment.text,
+ timestamp=segment.start,
+ )
+ for segment in data.transcript.as_segments()
+ ],
)
resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic)
From 6d074ed4570dbf85b304d7d5771ca6fe94c02114 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 19 Oct 2023 21:05:49 +0200
Subject: [PATCH 02/41] www: update frontend to support new transcript format
in topics
---
.../test_processor_transcript_segment.py | 146 ++++++++++++++++++
www/app/[domain]/transcripts/topicList.tsx | 12 +-
.../[domain]/transcripts/webSocketTypes.ts | 8 +-
www/app/api/.openapi-generator/FILES | 1 +
www/app/api/models/TranscriptSegmentTopic.ts | 88 +++++++++++
www/app/api/models/TranscriptTopic.ts | 8 +-
www/app/api/models/index.ts | 1 +
7 files changed, 258 insertions(+), 6 deletions(-)
create mode 100644 server/tests/test_processor_transcript_segment.py
create mode 100644 www/app/api/models/TranscriptSegmentTopic.ts
diff --git a/server/tests/test_processor_transcript_segment.py b/server/tests/test_processor_transcript_segment.py
new file mode 100644
index 00000000..3bb6182f
--- /dev/null
+++ b/server/tests/test_processor_transcript_segment.py
@@ -0,0 +1,146 @@
+def test_processor_transcript_segment():
+ from reflector.processors.types import Transcript, Word
+
+ transcript = Transcript(
+ words=[
+ Word(text=" the", start=5.12, end=5.48, speaker=0),
+ Word(text=" different", start=5.48, end=5.8, speaker=0),
+ Word(text=" projects", start=5.8, end=6.3, speaker=0),
+ Word(text=" that", start=6.3, end=6.5, speaker=0),
+ Word(text=" are", start=6.5, end=6.58, speaker=0),
+ Word(text=" going", start=6.58, end=6.82, speaker=0),
+ Word(text=" on", start=6.82, end=7.26, speaker=0),
+ Word(text=" to", start=7.26, end=7.4, speaker=0),
+ Word(text=" give", start=7.4, end=7.54, speaker=0),
+ Word(text=" you", start=7.54, end=7.9, speaker=0),
+ Word(text=" context", start=7.9, end=8.24, speaker=0),
+ Word(text=" and", start=8.24, end=8.66, speaker=0),
+ Word(text=" I", start=8.66, end=8.72, speaker=0),
+ Word(text=" think", start=8.72, end=8.82, speaker=0),
+ Word(text=" that's", start=8.82, end=9.04, speaker=0),
+ Word(text=" what", start=9.04, end=9.12, speaker=0),
+ Word(text=" we'll", start=9.12, end=9.24, speaker=0),
+ Word(text=" do", start=9.24, end=9.32, speaker=0),
+ Word(text=" this", start=9.32, end=9.52, speaker=0),
+ Word(text=" week.", start=9.52, end=9.76, speaker=0),
+ Word(text=" Um,", start=10.24, end=10.62, speaker=0),
+ Word(text=" so,", start=11.36, end=11.94, speaker=0),
+ Word(text=" um,", start=12.46, end=12.92, speaker=0),
+ Word(text=" what", start=13.74, end=13.94, speaker=0),
+ Word(text=" we're", start=13.94, end=14.1, speaker=0),
+ Word(text=" going", start=14.1, end=14.24, speaker=0),
+ Word(text=" to", start=14.24, end=14.34, speaker=0),
+ Word(text=" do", start=14.34, end=14.8, speaker=0),
+ Word(text=" at", start=14.8, end=14.98, speaker=0),
+ Word(text=" H", start=14.98, end=15.04, speaker=0),
+ Word(text=" of", start=15.04, end=15.16, speaker=0),
+ Word(text=" you,", start=15.16, end=15.26, speaker=0),
+ Word(text=" maybe.", start=15.28, end=15.34, speaker=0),
+ Word(text=" you", start=15.36, end=15.52, speaker=0),
+ Word(text=" can", start=15.52, end=15.62, speaker=0),
+ Word(text=" introduce", start=15.62, end=15.98, speaker=0),
+ Word(text=" yourself", start=15.98, end=16.42, speaker=0),
+ Word(text=" to", start=16.42, end=16.68, speaker=0),
+ Word(text=" the", start=16.68, end=16.72, speaker=0),
+ Word(text=" team", start=16.72, end=17.52, speaker=0),
+ Word(text=" quickly", start=17.87, end=18.65, speaker=0),
+ Word(text=" and", start=18.65, end=19.63, speaker=0),
+ Word(text=" Oh,", start=20.91, end=21.55, speaker=0),
+ Word(text=" this", start=21.67, end=21.83, speaker=0),
+ Word(text=" is", start=21.83, end=22.17, speaker=0),
+ Word(text=" a", start=22.17, end=22.35, speaker=0),
+ Word(text=" reflector", start=22.35, end=22.89, speaker=0),
+ Word(text=" translating", start=22.89, end=23.33, speaker=0),
+ Word(text=" into", start=23.33, end=23.73, speaker=0),
+ Word(text=" French", start=23.73, end=23.95, speaker=0),
+ Word(text=" for", start=23.95, end=24.15, speaker=0),
+ Word(text=" me.", start=24.15, end=24.43, speaker=0),
+ Word(text=" This", start=27.87, end=28.19, speaker=0),
+ Word(text=" is", start=28.19, end=28.45, speaker=0),
+ Word(text=" all", start=28.45, end=28.79, speaker=0),
+ Word(text=" the", start=28.79, end=29.15, speaker=0),
+ Word(text=" way,", start=29.15, end=29.15, speaker=0),
+ Word(text=" please,", start=29.53, end=29.59, speaker=0),
+ Word(text=" please,", start=29.73, end=29.77, speaker=0),
+ Word(text=" please,", start=29.77, end=29.83, speaker=0),
+ Word(text=" please.", start=29.83, end=29.97, speaker=0),
+ Word(text=" Yeah,", start=29.97, end=30.17, speaker=0),
+ Word(text=" that's", start=30.25, end=30.33, speaker=0),
+ Word(text=" all", start=30.33, end=30.49, speaker=0),
+ Word(text=" it's", start=30.49, end=30.69, speaker=0),
+ Word(text=" right.", start=30.69, end=30.69, speaker=0),
+ Word(text=" Right.", start=30.72, end=30.98, speaker=1),
+ Word(text=" Yeah,", start=31.56, end=31.72, speaker=2),
+ Word(text=" that's", start=31.86, end=31.98, speaker=2),
+ Word(text=" right.", start=31.98, end=32.2, speaker=2),
+ Word(text=" Because", start=32.38, end=32.46, speaker=0),
+ Word(text=" I", start=32.46, end=32.58, speaker=0),
+ Word(text=" thought", start=32.58, end=32.78, speaker=0),
+ Word(text=" I'd", start=32.78, end=33.0, speaker=0),
+ Word(text=" be", start=33.0, end=33.02, speaker=0),
+ Word(text=" able", start=33.02, end=33.18, speaker=0),
+ Word(text=" to", start=33.18, end=33.34, speaker=0),
+ Word(text=" pull", start=33.34, end=33.52, speaker=0),
+ Word(text=" out.", start=33.52, end=33.68, speaker=0),
+ Word(text=" Yeah,", start=33.7, end=33.9, speaker=0),
+ Word(text=" that", start=33.9, end=34.02, speaker=0),
+ Word(text=" was", start=34.02, end=34.24, speaker=0),
+ Word(text=" the", start=34.24, end=34.34, speaker=0),
+ Word(text=" one", start=34.34, end=34.44, speaker=0),
+ Word(text=" before", start=34.44, end=34.7, speaker=0),
+ Word(text=" that.", start=34.7, end=35.24, speaker=0),
+ Word(text=" Friends,", start=35.84, end=36.46, speaker=0),
+ Word(text=" if", start=36.64, end=36.7, speaker=0),
+ Word(text=" you", start=36.7, end=36.7, speaker=0),
+ Word(text=" have", start=36.7, end=37.24, speaker=0),
+ Word(text=" tell", start=37.24, end=37.44, speaker=0),
+ Word(text=" us", start=37.44, end=37.68, speaker=0),
+ Word(text=" if", start=37.68, end=37.82, speaker=0),
+ Word(text=" it's", start=37.82, end=38.04, speaker=0),
+ Word(text=" good,", start=38.04, end=38.58, speaker=0),
+ Word(text=" exceptionally", start=38.96, end=39.1, speaker=0),
+ Word(text=" good", start=39.1, end=39.6, speaker=0),
+ Word(text=" and", start=39.6, end=39.86, speaker=0),
+ Word(text=" tell", start=39.86, end=40.0, speaker=0),
+ Word(text=" us", start=40.0, end=40.06, speaker=0),
+ Word(text=" when", start=40.06, end=40.2, speaker=0),
+ Word(text=" it's", start=40.2, end=40.34, speaker=0),
+ Word(text=" exceptionally", start=40.34, end=40.6, speaker=0),
+ Word(text=" bad.", start=40.6, end=40.94, speaker=0),
+ Word(text=" We", start=40.96, end=41.26, speaker=0),
+ Word(text=" don't", start=41.26, end=41.44, speaker=0),
+ Word(text=" need", start=41.44, end=41.66, speaker=0),
+ Word(text=" that", start=41.66, end=41.82, speaker=0),
+ Word(text=" at", start=41.82, end=41.94, speaker=0),
+ Word(text=" the", start=41.94, end=41.98, speaker=0),
+ Word(text=" middle", start=41.98, end=42.18, speaker=0),
+ Word(text=" of", start=42.18, end=42.36, speaker=0),
+ Word(text=" age.", start=42.36, end=42.7, speaker=0),
+ Word(text=" Okay,", start=43.26, end=43.44, speaker=0),
+ Word(text=" yeah,", start=43.68, end=43.76, speaker=0),
+ Word(text=" that", start=43.78, end=44.3, speaker=0),
+ Word(text=" sentence", start=44.3, end=44.72, speaker=0),
+ Word(text=" right", start=44.72, end=45.1, speaker=0),
+ Word(text=" before.", start=45.1, end=45.56, speaker=0),
+ Word(text=" it", start=46.08, end=46.36, speaker=0),
+ Word(text=" realizing", start=46.36, end=47.0, speaker=0),
+ Word(text=" that", start=47.0, end=47.28, speaker=0),
+ Word(text=" I", start=47.28, end=47.28, speaker=0),
+ Word(text=" was", start=47.28, end=47.64, speaker=0),
+ Word(text=" saying", start=47.64, end=48.06, speaker=0),
+ Word(text=" that", start=48.06, end=48.44, speaker=0),
+ Word(text=" it's", start=48.44, end=48.54, speaker=0),
+ Word(text=" interesting", start=48.54, end=48.78, speaker=0),
+ Word(text=" that", start=48.78, end=48.96, speaker=0),
+ Word(text=" it's", start=48.96, end=49.08, speaker=0),
+ Word(text=" translating", start=49.08, end=49.32, speaker=0),
+ Word(text=" the", start=49.32, end=49.56, speaker=0),
+ Word(text=" French", start=49.56, end=49.76, speaker=0),
+ Word(text=" was", start=49.76, end=50.16, speaker=0),
+ Word(text=" completely", start=50.16, end=50.4, speaker=0),
+ Word(text=" wrong.", start=50.4, end=50.7, speaker=0),
+ ]
+ )
+
+ for segment in transcript.as_segments():
+ print(segment)
diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx
index e5de09c8..4fedacb0 100644
--- a/www/app/[domain]/transcripts/topicList.tsx
+++ b/www/app/[domain]/transcripts/topicList.tsx
@@ -103,7 +103,17 @@ export function TopicList({
/>
{activeTopic?.id == topic.id && (
- {topic.transcript}
+
+ {topic.segments.map((segment, index) => (
+
+ [{formatTime(segment.timestamp)}] Speaker{" "}
+ {segment.speaker}: {segment.text}
+
+ ))}
+
)}
))}
diff --git a/www/app/[domain]/transcripts/webSocketTypes.ts b/www/app/[domain]/transcripts/webSocketTypes.ts
index 450b3b1c..3e02e26a 100644
--- a/www/app/[domain]/transcripts/webSocketTypes.ts
+++ b/www/app/[domain]/transcripts/webSocketTypes.ts
@@ -1,9 +1,15 @@
+export type SegmentTopic = {
+ speaker: number;
+ start: number;
+ text: string;
+};
+
export type Topic = {
timestamp: number;
title: string;
- transcript: string;
summary: string;
id: string;
+ segments: SegmentTopic[];
};
export type Transcript = {
diff --git a/www/app/api/.openapi-generator/FILES b/www/app/api/.openapi-generator/FILES
index 16763a8d..9192cb1f 100644
--- a/www/app/api/.openapi-generator/FILES
+++ b/www/app/api/.openapi-generator/FILES
@@ -8,6 +8,7 @@ models/GetTranscript.ts
models/HTTPValidationError.ts
models/PageGetTranscript.ts
models/RtcOffer.ts
+models/TranscriptSegmentTopic.ts
models/TranscriptTopic.ts
models/UpdateTranscript.ts
models/UserInfo.ts
diff --git a/www/app/api/models/TranscriptSegmentTopic.ts b/www/app/api/models/TranscriptSegmentTopic.ts
new file mode 100644
index 00000000..73496a67
--- /dev/null
+++ b/www/app/api/models/TranscriptSegmentTopic.ts
@@ -0,0 +1,88 @@
+/* tslint:disable */
+/* eslint-disable */
+/**
+ * FastAPI
+ * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
+ *
+ * The version of the OpenAPI document: 0.1.0
+ *
+ *
+ * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
+ * https://openapi-generator.tech
+ * Do not edit the class manually.
+ */
+
+import { exists, mapValues } from "../runtime";
+/**
+ *
+ * @export
+ * @interface TranscriptSegmentTopic
+ */
+export interface TranscriptSegmentTopic {
+ /**
+ *
+ * @type {any}
+ * @memberof TranscriptSegmentTopic
+ */
+ speaker: any | null;
+ /**
+ *
+ * @type {any}
+ * @memberof TranscriptSegmentTopic
+ */
+ text: any | null;
+ /**
+ *
+ * @type {any}
+ * @memberof TranscriptSegmentTopic
+ */
+ timestamp: any | null;
+}
+
+/**
+ * Check if a given object implements the TranscriptSegmentTopic interface.
+ */
+export function instanceOfTranscriptSegmentTopic(value: object): boolean {
+ let isInstance = true;
+ isInstance = isInstance && "speaker" in value;
+ isInstance = isInstance && "text" in value;
+ isInstance = isInstance && "timestamp" in value;
+
+ return isInstance;
+}
+
+export function TranscriptSegmentTopicFromJSON(
+ json: any,
+): TranscriptSegmentTopic {
+ return TranscriptSegmentTopicFromJSONTyped(json, false);
+}
+
+export function TranscriptSegmentTopicFromJSONTyped(
+ json: any,
+ ignoreDiscriminator: boolean,
+): TranscriptSegmentTopic {
+ if (json === undefined || json === null) {
+ return json;
+ }
+ return {
+ speaker: json["speaker"],
+ text: json["text"],
+ timestamp: json["timestamp"],
+ };
+}
+
+export function TranscriptSegmentTopicToJSON(
+ value?: TranscriptSegmentTopic | null,
+): any {
+ if (value === undefined) {
+ return undefined;
+ }
+ if (value === null) {
+ return null;
+ }
+ return {
+ speaker: value.speaker,
+ text: value.text,
+ timestamp: value.timestamp,
+ };
+}
diff --git a/www/app/api/models/TranscriptTopic.ts b/www/app/api/models/TranscriptTopic.ts
index 8b395374..f22496b2 100644
--- a/www/app/api/models/TranscriptTopic.ts
+++ b/www/app/api/models/TranscriptTopic.ts
@@ -42,13 +42,13 @@ export interface TranscriptTopic {
* @type {any}
* @memberof TranscriptTopic
*/
- transcript?: any | null;
+ timestamp: any | null;
/**
*
* @type {any}
* @memberof TranscriptTopic
*/
- timestamp: any | null;
+ segments?: any | null;
}
/**
@@ -78,8 +78,8 @@ export function TranscriptTopicFromJSONTyped(
id: !exists(json, "id") ? undefined : json["id"],
title: json["title"],
summary: json["summary"],
- transcript: !exists(json, "transcript") ? undefined : json["transcript"],
timestamp: json["timestamp"],
+ segments: !exists(json, "segments") ? undefined : json["segments"],
};
}
@@ -94,7 +94,7 @@ export function TranscriptTopicToJSON(value?: TranscriptTopic | null): any {
id: value.id,
title: value.title,
summary: value.summary,
- transcript: value.transcript,
timestamp: value.timestamp,
+ segments: value.segments,
};
}
diff --git a/www/app/api/models/index.ts b/www/app/api/models/index.ts
index 99874641..6a7d7318 100644
--- a/www/app/api/models/index.ts
+++ b/www/app/api/models/index.ts
@@ -7,6 +7,7 @@ export * from "./GetTranscript";
export * from "./HTTPValidationError";
export * from "./PageGetTranscript";
export * from "./RtcOffer";
+export * from "./TranscriptSegmentTopic";
export * from "./TranscriptTopic";
export * from "./UpdateTranscript";
export * from "./UserInfo";
From 00eb9bbf3c5ccdb8f7f3f9364ae952d821897759 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 20 Oct 2023 16:06:35 +0200
Subject: [PATCH 03/41] server: improve split algorithm
---
server/reflector/processors/types.py | 35 ++++++++++------------------
1 file changed, 12 insertions(+), 23 deletions(-)
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index 686c5785..ba0cccf9 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -1,10 +1,13 @@
import io
+import re
import tempfile
from pathlib import Path
from profanityfilter import ProfanityFilter
from pydantic import BaseModel, PrivateAttr
+PUNC_RE = re.compile(r"[.;:?!…]")
+
profanity_filter = ProfanityFilter()
profanity_filter.set_censor("*")
@@ -106,15 +109,14 @@ class Transcript(BaseModel):
]
return Transcript(text=self.text, translation=self.translation, words=words)
- def as_segments(self):
+ def as_segments(self) -> list[TranscriptSegment]:
# from a list of word, create a list of segments
# join the word that are less than 2 seconds apart
# but separate if the speaker changes, or if the punctuation is a . , ; : ? !
segments = []
current_segment = None
- last_word = None
- BLANK_TIME_SECS = 2
- MAX_SEGMENT_LENGTH = 80
+ MAX_SEGMENT_LENGTH = 120
+
for word in self.words:
if current_segment is None:
current_segment = TranscriptSegment(
@@ -123,27 +125,14 @@ class Transcript(BaseModel):
speaker=word.speaker,
)
continue
- is_blank = False
- if last_word:
- is_blank = word.start - last_word.end > BLANK_TIME_SECS
- if (
- word.speaker != current_segment.speaker
- or (
- word.text in ".;:?!…"
- and len(current_segment.text) > MAX_SEGMENT_LENGTH
- )
- or is_blank
+ current_segment.text += word.text
+
+ have_punc = PUNC_RE.search(word.text)
+ if word.speaker != current_segment.speaker or (
+ have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH)
):
- # check which condition triggered
segments.append(current_segment)
- current_segment = TranscriptSegment(
- text=word.text,
- start=word.start,
- speaker=word.speaker,
- )
- else:
- current_segment.text += word.text
- last_word = word
+ current_segment = None
if current_segment:
segments.append(current_segment)
return segments
From 21e408b32391f1cefd69401e40ba558e421cdef8 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 20 Oct 2023 16:07:12 +0200
Subject: [PATCH 04/41] server: include transcript words in database, but keep
 backward-compatible API
---
server/reflector/views/transcripts.py | 80 +++++++++++++++++++++++----
1 file changed, 69 insertions(+), 11 deletions(-)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 0a068c17..1d9fd4bd 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -17,6 +17,8 @@ from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from reflector.db import database, transcripts
from reflector.logger import logger
+from reflector.processors.types import Transcript as ProcessorTranscript
+from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool
@@ -60,7 +62,8 @@ class TranscriptTopic(BaseModel):
title: str
summary: str
timestamp: float
- segments: list[TranscriptSegmentTopic] = []
+ text: str | None = None
+ words: list[ProcessorWord] = []
class TranscriptFinalShortSummary(BaseModel):
@@ -304,6 +307,53 @@ async def transcripts_create(
# ==============================================================
+class GetTranscriptSegmentTopic(BaseModel):
+ text: str
+ start: float
+ speaker: int
+
+
+class GetTranscriptTopic(BaseModel):
+ title: str
+ summary: str
+ timestamp: float
+ text: str
+ segments: list[GetTranscriptSegmentTopic] = []
+
+ @classmethod
+ def from_transcript_topic(cls, topic: TranscriptTopic):
+ if not topic.words:
+ # In previous version, words were missing
+ # Just output a segment with speaker 0
+ text = topic.text
+ segments = [
+ GetTranscriptSegmentTopic(
+ text=topic.text,
+ start=topic.timestamp,
+ speaker=0,
+ )
+ ]
+ else:
+ # New versions include words
+ transcript = ProcessorTranscript(words=topic.words)
+ text = transcript.text
+ segments = [
+ GetTranscriptSegmentTopic(
+ text=segment.text,
+ start=segment.start,
+ speaker=segment.speaker,
+ )
+ for segment in transcript.as_segments()
+ ]
+ return cls(
+ title=topic.title,
+ summary=topic.summary,
+ timestamp=topic.timestamp,
+ text=text,
+ segments=segments,
+ )
+
+
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(
transcript_id: str,
@@ -412,7 +462,10 @@ async def transcript_get_audio_waveform(
return transcript.audio_waveform
-@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
+@router.get(
+ "/transcripts/{transcript_id}/topics",
+ response_model=list[GetTranscriptTopic],
+)
async def transcript_get_topics(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
@@ -421,7 +474,11 @@ async def transcript_get_topics(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
- return transcript.topics
+
+ # convert to GetTranscriptTopic
+ return [
+ GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
+ ]
@router.get("/transcripts/{transcript_id}/events")
@@ -498,6 +555,13 @@ async def transcript_events_websocket(
async def handle_rtc_event(event: PipelineEvent, args, data):
+ try:
+ return await handle_rtc_event_once(event, args, data)
+ except Exception:
+ logger.exception("Error handling RTC event")
+
+
+async def handle_rtc_event_once(event: PipelineEvent, args, data):
# OFC the current implementation is not good,
# but it's just a POC before persistence. It won't query the
# transcript from the database for each event.
@@ -530,14 +594,8 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
- segments=[
- TranscriptSegmentTopic(
- speaker=segment.speaker,
- text=segment.text,
- timestamp=segment.start,
- )
- for segment in data.transcript.as_segments()
- ],
+ text=data.transcript.text,
+ words=data.transcript.words,
)
resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic)
From 01d7add6cc32d941b9e4e4e4a71616ad86e46531 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 20 Oct 2023 16:07:29 +0200
Subject: [PATCH 05/41] www: update openapi and display
---
www/app/[domain]/transcripts/topicList.tsx | 29 +++--
.../[domain]/transcripts/webSocketTypes.ts | 1 +
www/app/api/.openapi-generator/FILES | 4 +-
.../api/models/GetTranscriptSegmentTopic.ts | 88 +++++++++++++++
www/app/api/models/GetTranscriptTopic.ts | 103 ++++++++++++++++++
www/app/api/models/index.ts | 4 +-
6 files changed, 216 insertions(+), 13 deletions(-)
create mode 100644 www/app/api/models/GetTranscriptSegmentTopic.ts
create mode 100644 www/app/api/models/GetTranscriptTopic.ts
diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx
index 4fedacb0..d10cb13f 100644
--- a/www/app/[domain]/transcripts/topicList.tsx
+++ b/www/app/[domain]/transcripts/topicList.tsx
@@ -104,15 +104,26 @@ export function TopicList({
{activeTopic?.id == topic.id && (
- {topic.segments.map((segment, index) => (
-
- [{formatTime(segment.timestamp)}] Speaker{" "}
- {segment.speaker}: {segment.text}
-
- ))}
+ {topic.segments ? (
+ <>
+ {topic.segments.map((segment, index: number) => (
+
+
+ [{formatTime(segment.start)}]
+
+
+ Speaker {segment.speaker}
+
+ {segment.text}
+
+ ))}
+ >
+ ) : (
+ <>{topic.text}>
+ )}
)}
diff --git a/www/app/[domain]/transcripts/webSocketTypes.ts b/www/app/[domain]/transcripts/webSocketTypes.ts
index 3e02e26a..abc67b33 100644
--- a/www/app/[domain]/transcripts/webSocketTypes.ts
+++ b/www/app/[domain]/transcripts/webSocketTypes.ts
@@ -9,6 +9,7 @@ export type Topic = {
title: string;
summary: string;
id: string;
+ text: string;
segments: SegmentTopic[];
};
diff --git a/www/app/api/.openapi-generator/FILES b/www/app/api/.openapi-generator/FILES
index 9192cb1f..532a6a16 100644
--- a/www/app/api/.openapi-generator/FILES
+++ b/www/app/api/.openapi-generator/FILES
@@ -5,11 +5,11 @@ models/AudioWaveform.ts
models/CreateTranscript.ts
models/DeletionStatus.ts
models/GetTranscript.ts
+models/GetTranscriptSegmentTopic.ts
+models/GetTranscriptTopic.ts
models/HTTPValidationError.ts
models/PageGetTranscript.ts
models/RtcOffer.ts
-models/TranscriptSegmentTopic.ts
-models/TranscriptTopic.ts
models/UpdateTranscript.ts
models/UserInfo.ts
models/ValidationError.ts
diff --git a/www/app/api/models/GetTranscriptSegmentTopic.ts b/www/app/api/models/GetTranscriptSegmentTopic.ts
new file mode 100644
index 00000000..cc2049b1
--- /dev/null
+++ b/www/app/api/models/GetTranscriptSegmentTopic.ts
@@ -0,0 +1,88 @@
+/* tslint:disable */
+/* eslint-disable */
+/**
+ * FastAPI
+ * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
+ *
+ * The version of the OpenAPI document: 0.1.0
+ *
+ *
+ * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
+ * https://openapi-generator.tech
+ * Do not edit the class manually.
+ */
+
+import { exists, mapValues } from "../runtime";
+/**
+ *
+ * @export
+ * @interface GetTranscriptSegmentTopic
+ */
+export interface GetTranscriptSegmentTopic {
+ /**
+ *
+ * @type {any}
+ * @memberof GetTranscriptSegmentTopic
+ */
+ text: any | null;
+ /**
+ *
+ * @type {any}
+ * @memberof GetTranscriptSegmentTopic
+ */
+ start: any | null;
+ /**
+ *
+ * @type {any}
+ * @memberof GetTranscriptSegmentTopic
+ */
+ speaker: any | null;
+}
+
+/**
+ * Check if a given object implements the GetTranscriptSegmentTopic interface.
+ */
+export function instanceOfGetTranscriptSegmentTopic(value: object): boolean {
+ let isInstance = true;
+ isInstance = isInstance && "text" in value;
+ isInstance = isInstance && "start" in value;
+ isInstance = isInstance && "speaker" in value;
+
+ return isInstance;
+}
+
+export function GetTranscriptSegmentTopicFromJSON(
+ json: any,
+): GetTranscriptSegmentTopic {
+ return GetTranscriptSegmentTopicFromJSONTyped(json, false);
+}
+
+export function GetTranscriptSegmentTopicFromJSONTyped(
+ json: any,
+ ignoreDiscriminator: boolean,
+): GetTranscriptSegmentTopic {
+ if (json === undefined || json === null) {
+ return json;
+ }
+ return {
+ text: json["text"],
+ start: json["start"],
+ speaker: json["speaker"],
+ };
+}
+
+export function GetTranscriptSegmentTopicToJSON(
+ value?: GetTranscriptSegmentTopic | null,
+): any {
+ if (value === undefined) {
+ return undefined;
+ }
+ if (value === null) {
+ return null;
+ }
+ return {
+ text: value.text,
+ start: value.start,
+ speaker: value.speaker,
+ };
+}
diff --git a/www/app/api/models/GetTranscriptTopic.ts b/www/app/api/models/GetTranscriptTopic.ts
new file mode 100644
index 00000000..7a7d4c90
--- /dev/null
+++ b/www/app/api/models/GetTranscriptTopic.ts
@@ -0,0 +1,103 @@
+/* tslint:disable */
+/* eslint-disable */
+/**
+ * FastAPI
+ * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
+ *
+ * The version of the OpenAPI document: 0.1.0
+ *
+ *
+ * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
+ * https://openapi-generator.tech
+ * Do not edit the class manually.
+ */
+
+import { exists, mapValues } from "../runtime";
+/**
+ *
+ * @export
+ * @interface GetTranscriptTopic
+ */
+export interface GetTranscriptTopic {
+ /**
+ *
+ * @type {any}
+ * @memberof GetTranscriptTopic
+ */
+ title: any | null;
+ /**
+ *
+ * @type {any}
+ * @memberof GetTranscriptTopic
+ */
+ summary: any | null;
+ /**
+ *
+ * @type {any}
+ * @memberof GetTranscriptTopic
+ */
+ timestamp: any | null;
+ /**
+ *
+ * @type {any}
+ * @memberof GetTranscriptTopic
+ */
+ text: any | null;
+ /**
+ *
+ * @type {any}
+ * @memberof GetTranscriptTopic
+ */
+ segments?: any | null;
+}
+
+/**
+ * Check if a given object implements the GetTranscriptTopic interface.
+ */
+export function instanceOfGetTranscriptTopic(value: object): boolean {
+ let isInstance = true;
+ isInstance = isInstance && "title" in value;
+ isInstance = isInstance && "summary" in value;
+ isInstance = isInstance && "timestamp" in value;
+ isInstance = isInstance && "text" in value;
+
+ return isInstance;
+}
+
+export function GetTranscriptTopicFromJSON(json: any): GetTranscriptTopic {
+ return GetTranscriptTopicFromJSONTyped(json, false);
+}
+
+export function GetTranscriptTopicFromJSONTyped(
+ json: any,
+ ignoreDiscriminator: boolean,
+): GetTranscriptTopic {
+ if (json === undefined || json === null) {
+ return json;
+ }
+ return {
+ title: json["title"],
+ summary: json["summary"],
+ timestamp: json["timestamp"],
+ text: json["text"],
+ segments: !exists(json, "segments") ? undefined : json["segments"],
+ };
+}
+
+export function GetTranscriptTopicToJSON(
+ value?: GetTranscriptTopic | null,
+): any {
+ if (value === undefined) {
+ return undefined;
+ }
+ if (value === null) {
+ return null;
+ }
+ return {
+ title: value.title,
+ summary: value.summary,
+ timestamp: value.timestamp,
+ text: value.text,
+ segments: value.segments,
+ };
+}
diff --git a/www/app/api/models/index.ts b/www/app/api/models/index.ts
index 6a7d7318..9e691b09 100644
--- a/www/app/api/models/index.ts
+++ b/www/app/api/models/index.ts
@@ -4,11 +4,11 @@ export * from "./AudioWaveform";
export * from "./CreateTranscript";
export * from "./DeletionStatus";
export * from "./GetTranscript";
+export * from "./GetTranscriptSegmentTopic";
+export * from "./GetTranscriptTopic";
export * from "./HTTPValidationError";
export * from "./PageGetTranscript";
export * from "./RtcOffer";
-export * from "./TranscriptSegmentTopic";
-export * from "./TranscriptTopic";
export * from "./UpdateTranscript";
export * from "./UserInfo";
export * from "./ValidationError";
From f4cffc0e66c4d42aa69644b1456a0b63ca3ff8e7 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 20 Oct 2023 16:14:30 +0200
Subject: [PATCH 06/41] server: add tests on segmentation and fix issue with
speaker
---
server/reflector/processors/types.py | 20 ++++++++++++++++---
.../test_processor_transcript_segment.py | 19 ++++++++++++++++--
2 files changed, 34 insertions(+), 5 deletions(-)
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index ba0cccf9..d2c32d17 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -125,16 +125,30 @@ class Transcript(BaseModel):
speaker=word.speaker,
)
continue
+
+ # If the word is attached to another speaker, push the current segment
+ # and start a new one
+ if word.speaker != current_segment.speaker:
+ segments.append(current_segment)
+ current_segment = TranscriptSegment(
+ text=word.text,
+ start=word.start,
+ speaker=word.speaker,
+ )
+ continue
+
+ # if the word is the end of a sentence, and we have enough content,
+ # add the word to the current segment and push it
current_segment.text += word.text
have_punc = PUNC_RE.search(word.text)
- if word.speaker != current_segment.speaker or (
- have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH)
- ):
+ if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
segments.append(current_segment)
current_segment = None
+
if current_segment:
segments.append(current_segment)
+
return segments
diff --git a/server/tests/test_processor_transcript_segment.py b/server/tests/test_processor_transcript_segment.py
index 3bb6182f..6fde0dd1 100644
--- a/server/tests/test_processor_transcript_segment.py
+++ b/server/tests/test_processor_transcript_segment.py
@@ -142,5 +142,20 @@ def test_processor_transcript_segment():
]
)
- for segment in transcript.as_segments():
- print(segment)
+ segments = transcript.as_segments()
+ assert len(segments) == 7
+
+ # check speaker order
+ assert segments[0].speaker == 0
+ assert segments[1].speaker == 0
+ assert segments[2].speaker == 0
+ assert segments[3].speaker == 1
+ assert segments[4].speaker == 2
+ assert segments[5].speaker == 0
+ assert segments[6].speaker == 0
+
+ # check the timing (first entry, and first of others speakers)
+ assert segments[0].start == 5.12
+ assert segments[3].start == 30.72
+ assert segments[4].start == 31.56
+ assert segments[5].start == 32.38
From 5f00673ebddd91a2105d44871b597de7d941a988 Mon Sep 17 00:00:00 2001
From: Koper
Date: Wed, 25 Oct 2023 15:13:52 +0100
Subject: [PATCH 07/41] Implemented speaker color functions
---
www/app/[domain]/transcripts/topicList.tsx | 16 +-
www/app/[domain]/transcripts/useWebSockets.ts | 167 ++++++++++++++++--
www/app/lib/utils.ts | 122 +++++++++++++
3 files changed, 286 insertions(+), 19 deletions(-)
diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx
index d10cb13f..ef4a2889 100644
--- a/www/app/[domain]/transcripts/topicList.tsx
+++ b/www/app/[domain]/transcripts/topicList.tsx
@@ -7,6 +7,7 @@ import {
import { formatTime } from "../../lib/time";
import ScrollToBottom from "./scrollToBottom";
import { Topic } from "./webSocketTypes";
+import { generateHighContrastColor } from "../lib/utils";
type TopicListProps = {
topics: Topic[];
@@ -114,9 +115,18 @@ export function TopicList({
[{formatTime(segment.start)}]
-
- Speaker {segment.speaker}
-
+
+ {" "}
+ (Speaker {segment.speaker}):
+ {" "}
{segment.text}
))}
diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts
index 6bd7bf48..8196749e 100644
--- a/www/app/[domain]/transcripts/useWebSockets.ts
+++ b/www/app/[domain]/transcripts/useWebSockets.ts
@@ -56,38 +56,116 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
timestamp: 10,
summary: "This is test topic 1",
title: "Topic 1: Introduction to Quantum Mechanics",
- transcript:
- "A brief overview of quantum mechanics and its principles.",
+ text: "A brief overview of quantum mechanics and its principles.",
+ segments: [
+ {
+ speaker: 1,
+ start: 0,
+ text: "This is the transcription of an example title",
+ },
+ {
+ speaker: 2,
+ start: 10,
+ text: "This is the second speaker",
+ },
+
+ {
+ speaker: 3,
+ start: 90,
+ text: "This is the third speaker",
+ },
+ {
+ speaker: 4,
+ start: 90,
+ text: "This is the fourth speaker",
+ },
+ {
+ speaker: 5,
+ start: 123,
+ text: "This is the fifth speaker",
+ },
+ {
+ speaker: 6,
+ start: 300,
+ text: "This is the sixth speaker",
+ },
+ ],
},
{
id: "2",
timestamp: 20,
summary: "This is test topic 2",
title: "Topic 2: Machine Learning Algorithms",
- transcript:
- "Understanding the different types of machine learning algorithms.",
+ text: "Understanding the different types of machine learning algorithms.",
+ segments: [
+ {
+ speaker: 1,
+ start: 0,
+ text: "This is the transcription of an example title",
+ },
+ {
+ speaker: 2,
+ start: 10,
+ text: "This is the second speaker",
+ },
+ ],
},
{
id: "3",
timestamp: 30,
summary: "This is test topic 3",
title: "Topic 3: Mental Health Awareness",
- transcript: "Ways to improve mental health and reduce stigma.",
+ text: "Ways to improve mental health and reduce stigma.",
+ segments: [
+ {
+ speaker: 1,
+ start: 0,
+ text: "This is the transcription of an example title",
+ },
+ {
+ speaker: 2,
+ start: 10,
+ text: "This is the second speaker",
+ },
+ ],
},
{
id: "4",
timestamp: 40,
summary: "This is test topic 4",
title: "Topic 4: Basics of Productivity",
- transcript: "Tips and tricks to increase daily productivity.",
+ text: "Tips and tricks to increase daily productivity.",
+ segments: [
+ {
+ speaker: 1,
+ start: 0,
+ text: "This is the transcription of an example title",
+ },
+ {
+ speaker: 2,
+ start: 10,
+ text: "This is the second speaker",
+ },
+ ],
},
{
id: "5",
timestamp: 50,
summary: "This is test topic 5",
title: "Topic 5: Future of Aviation",
- transcript:
- "Exploring the advancements and possibilities in aviation.",
+ text: "Exploring the advancements and possibilities in aviation.",
+ segments: [
+ {
+ speaker: 1,
+ start: 0,
+ text: "This is the transcription of an example title",
+ },
+ {
+ speaker: 2,
+ start: 10,
+ text: "This is the second speaker",
+ },
+ ],
},
]);
@@ -104,8 +182,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 1",
title:
"Topic 1: Introduction to Quantum Mechanics, a brief overview of quantum mechanics and its principles.",
- transcript:
- "A brief overview of quantum mechanics and its principles.",
+ text: "A brief overview of quantum mechanics and its principles.",
+ segments: [
+ {
+ speaker: 1,
+ start: 0,
+ text: "This is the transcription of an example title",
+ },
+ {
+ speaker: 2,
+ start: 10,
+ text: "This is the second speaker",
+ },
+ ],
},
{
id: "2",
@@ -113,8 +202,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 2",
title:
"Topic 2: Machine Learning Algorithms, understanding the different types of machine learning algorithms.",
- transcript:
- "Understanding the different types of machine learning algorithms.",
+ text: "Understanding the different types of machine learning algorithms.",
+ segments: [
+ {
+ speaker: 1,
+ start: 0,
+ text: "This is the transcription of an example title",
+ },
+ {
+ speaker: 2,
+ start: 10,
+ text: "This is the second speaker",
+ },
+ ],
},
{
id: "3",
@@ -122,7 +222,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 3",
title:
"Topic 3: Mental Health Awareness, ways to improve mental health and reduce stigma.",
- transcript: "Ways to improve mental health and reduce stigma.",
+ text: "Ways to improve mental health and reduce stigma.",
+ segments: [
+ {
+ speaker: 1,
+ start: 0,
+ text: "This is the transcription of an example title",
+ },
+ {
+ speaker: 2,
+ start: 10,
+ text: "This is the second speaker",
+ },
+ ],
},
{
id: "4",
@@ -130,7 +242,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 4",
title:
"Topic 4: Basics of Productivity, tips and tricks to increase daily productivity.",
- transcript: "Tips and tricks to increase daily productivity.",
+ text: "Tips and tricks to increase daily productivity.",
+ segments: [
+ {
+ speaker: 1,
+ start: 0,
+ text: "This is the transcription of an example title",
+ },
+ {
+ speaker: 2,
+ start: 10,
+ text: "This is the second speaker",
+ },
+ ],
},
{
id: "5",
@@ -138,8 +262,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 5",
title:
"Topic 5: Future of Aviation, exploring the advancements and possibilities in aviation.",
- transcript:
- "Exploring the advancements and possibilities in aviation.",
+ text: "Exploring the advancements and possibilities in aviation.",
+ segments: [
+ {
+ speaker: 1,
+ start: 0,
+ text: "This is the transcription of an example title",
+ },
+ {
+ speaker: 2,
+ start: 10,
+ text: "This is the second speaker",
+ },
+ ],
},
]);
diff --git a/www/app/lib/utils.ts b/www/app/lib/utils.ts
index db775f07..43ed8bf8 100644
--- a/www/app/lib/utils.ts
+++ b/www/app/lib/utils.ts
@@ -1,3 +1,125 @@
export function isDevelopment() {
return process.env.NEXT_PUBLIC_ENV === "development";
}
+
+// Function to calculate an approximate WCAG contrast ratio (gamma-2.2 luminance approximation)
+export const getContrastRatio = (
+ foreground: [number, number, number],
+ background: [number, number, number],
+) => {
+ const [r1, g1, b1] = foreground;
+ const [r2, g2, b2] = background;
+
+ const lum1 =
+ 0.2126 * Math.pow(r1 / 255, 2.2) +
+ 0.7152 * Math.pow(g1 / 255, 2.2) +
+ 0.0722 * Math.pow(b1 / 255, 2.2);
+ const lum2 =
+ 0.2126 * Math.pow(r2 / 255, 2.2) +
+ 0.7152 * Math.pow(g2 / 255, 2.2) +
+ 0.0722 * Math.pow(b2 / 255, 2.2);
+
+ return (Math.max(lum1, lum2) + 0.05) / (Math.min(lum1, lum2) + 0.05);
+};
+
+// Function to hash string into 32-bit integer
+// 🔴 DO NOT USE FOR CRYPTOGRAPHY PURPOSES 🔴
+
+export function murmurhash3_32_gc(key: string, seed: number = 0) {
+ let remainder, bytes, h1, h1b, c1, c2, k1, i;
+
+ remainder = key.length & 3; // key.length % 4
+ bytes = key.length - remainder;
+ h1 = seed;
+ c1 = 0xcc9e2d51;
+ c2 = 0x1b873593;
+ i = 0;
+
+ while (i < bytes) {
+ k1 =
+ (key.charCodeAt(i) & 0xff) |
+ ((key.charCodeAt(++i) & 0xff) << 8) |
+ ((key.charCodeAt(++i) & 0xff) << 16) |
+ ((key.charCodeAt(++i) & 0xff) << 24);
+
+ ++i;
+
+ k1 =
+ ((k1 & 0xffff) * c1 + ((((k1 >>> 16) * c1) & 0xffff) << 16)) & 0xffffffff;
+ k1 = (k1 << 15) | (k1 >>> 17);
+ k1 =
+ ((k1 & 0xffff) * c2 + ((((k1 >>> 16) * c2) & 0xffff) << 16)) & 0xffffffff;
+
+ h1 ^= k1;
+ h1 = (h1 << 13) | (h1 >>> 19);
+ h1b =
+ ((h1 & 0xffff) * 5 + ((((h1 >>> 16) * 5) & 0xffff) << 16)) & 0xffffffff;
+ h1 = (h1b & 0xffff) + 0x6b64 + ((((h1b >>> 16) + 0xe654) & 0xffff) << 16);
+ }
+
+ k1 = 0;
+
+ switch (remainder) {
+ case 3:
+ k1 ^= (key.charCodeAt(i + 2) & 0xff) << 16;
+ case 2:
+ k1 ^= (key.charCodeAt(i + 1) & 0xff) << 8;
+ case 1:
+ k1 ^= key.charCodeAt(i) & 0xff;
+
+ k1 =
+ ((k1 & 0xffff) * c1 + ((((k1 >>> 16) * c1) & 0xffff) << 16)) &
+ 0xffffffff;
+ k1 = (k1 << 15) | (k1 >>> 17);
+ k1 =
+ ((k1 & 0xffff) * c2 + ((((k1 >>> 16) * c2) & 0xffff) << 16)) &
+ 0xffffffff;
+ h1 ^= k1;
+ }
+
+ h1 ^= key.length;
+
+ h1 ^= h1 >>> 16;
+ h1 =
+ ((h1 & 0xffff) * 0x85ebca6b +
+ ((((h1 >>> 16) * 0x85ebca6b) & 0xffff) << 16)) &
+ 0xffffffff;
+ h1 ^= h1 >>> 13;
+ h1 =
+ ((h1 & 0xffff) * 0xc2b2ae35 +
+ ((((h1 >>> 16) * 0xc2b2ae35) & 0xffff) << 16)) &
+ 0xffffffff;
+ h1 ^= h1 >>> 16;
+
+ return h1 >>> 0;
+}
+
+// Generates a color intended to have high contrast with the given background color (optional); a single inversion pass does not guarantee the 4.5:1 WCAG ratio
+
+export const generateHighContrastColor = (
+ name: string,
+ backgroundColor: [number, number, number] | null = null,
+) => {
+ const hash = murmurhash3_32_gc(name);
+ console.log(name, hash);
+
+ let red = (hash & 0xff0000) >> 16;
+ let green = (hash & 0x00ff00) >> 8;
+ let blue = hash & 0x0000ff;
+
+ const getCssColor = (red: number, green: number, blue: number) =>
+ `rgb(${red}, ${green}, ${blue})`;
+
+ if (!backgroundColor) return getCssColor(red, green, blue);
+
+ const contrast = getContrastRatio([red, green, blue], backgroundColor);
+
+ // Adjust the color to achieve better contrast if necessary (WCAG recommends at least 4.5:1 for text)
+ if (contrast < 4.5) {
+ red = Math.abs(255 - red);
+ green = Math.abs(255 - green);
+ blue = Math.abs(255 - blue);
+ }
+
+ return getCssColor(red, green, blue);
+};
From 16a857927286cf82c9c766d4cec3d5ef9d4fdd5a Mon Sep 17 00:00:00 2001
From: Koper
Date: Wed, 25 Oct 2023 15:17:34 +0100
Subject: [PATCH 08/41] Removed console.log
---
www/app/lib/utils.ts | 2 --
1 file changed, 2 deletions(-)
diff --git a/www/app/lib/utils.ts b/www/app/lib/utils.ts
index 43ed8bf8..6b72ddea 100644
--- a/www/app/lib/utils.ts
+++ b/www/app/lib/utils.ts
@@ -101,8 +101,6 @@ export const generateHighContrastColor = (
backgroundColor: [number, number, number] | null = null,
) => {
const hash = murmurhash3_32_gc(name);
- console.log(name, hash);
-
let red = (hash & 0xff0000) >> 16;
let green = (hash & 0x00ff00) >> 8;
let blue = hash & 0x0000ff;
From 8bebb2a76907bb9095bd2d3d244e7f92afd8e38d Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 20 Oct 2023 18:00:59 +0200
Subject: [PATCH 09/41] server: start moving to an external celery task
---
server/poetry.lock | 210 +++++++++++++++++++++-
server/pyproject.toml | 1 +
server/reflector/settings.py | 4 +
server/reflector/tasks/boot.py | 2 +
server/reflector/tasks/post_transcript.py | 170 ++++++++++++++++++
server/reflector/tasks/worker.py | 6 +
6 files changed, 392 insertions(+), 1 deletion(-)
create mode 100644 server/reflector/tasks/boot.py
create mode 100644 server/reflector/tasks/post_transcript.py
create mode 100644 server/reflector/tasks/worker.py
diff --git a/server/poetry.lock b/server/poetry.lock
index 330c23e3..0df46097 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -308,6 +308,20 @@ typing-extensions = ">=4"
[package.extras]
tz = ["python-dateutil"]
+[[package]]
+name = "amqp"
+version = "5.1.1"
+description = "Low-level AMQP client for Python (fork of amqplib)."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "amqp-5.1.1-py3-none-any.whl", hash = "sha256:6f0956d2c23d8fa6e7691934d8c3930eadb44972cbbd1a7ae3a520f735d43359"},
+ {file = "amqp-5.1.1.tar.gz", hash = "sha256:2c1b13fecc0893e946c65cbd5f36427861cffa4ea2201d8f6fca22e2a373b5e2"},
+]
+
+[package.dependencies]
+vine = ">=5.0.0"
+
[[package]]
name = "annotated-types"
version = "0.6.0"
@@ -474,6 +488,17 @@ files = [
{file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"},
]
+[[package]]
+name = "billiard"
+version = "4.1.0"
+description = "Python multiprocessing fork with improvements and bugfixes"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "billiard-4.1.0-py3-none-any.whl", hash = "sha256:0f50d6be051c6b2b75bfbc8bfd85af195c5739c281d3f5b86a5640c65563614a"},
+ {file = "billiard-4.1.0.tar.gz", hash = "sha256:1ad2eeae8e28053d729ba3373d34d9d6e210f6e4d8bf0a9c64f92bd053f1edf5"},
+]
+
[[package]]
name = "black"
version = "23.9.1"
@@ -556,6 +581,61 @@ urllib3 = ">=1.25.4,<1.27"
[package.extras]
crt = ["awscrt (==0.16.26)"]
+[[package]]
+name = "celery"
+version = "5.3.4"
+description = "Distributed Task Queue."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "celery-5.3.4-py3-none-any.whl", hash = "sha256:1e6ed40af72695464ce98ca2c201ad0ef8fd192246f6c9eac8bba343b980ad34"},
+ {file = "celery-5.3.4.tar.gz", hash = "sha256:9023df6a8962da79eb30c0c84d5f4863d9793a466354cc931d7f72423996de28"},
+]
+
+[package.dependencies]
+billiard = ">=4.1.0,<5.0"
+click = ">=8.1.2,<9.0"
+click-didyoumean = ">=0.3.0"
+click-plugins = ">=1.1.1"
+click-repl = ">=0.2.0"
+kombu = ">=5.3.2,<6.0"
+python-dateutil = ">=2.8.2"
+tzdata = ">=2022.7"
+vine = ">=5.0.0,<6.0"
+
+[package.extras]
+arangodb = ["pyArango (>=2.0.2)"]
+auth = ["cryptography (==41.0.3)"]
+azureblockblob = ["azure-storage-blob (>=12.15.0)"]
+brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"]
+cassandra = ["cassandra-driver (>=3.25.0,<4)"]
+consul = ["python-consul2 (==0.1.5)"]
+cosmosdbsql = ["pydocumentdb (==2.3.5)"]
+couchbase = ["couchbase (>=3.0.0)"]
+couchdb = ["pycouchdb (==1.14.2)"]
+django = ["Django (>=2.2.28)"]
+dynamodb = ["boto3 (>=1.26.143)"]
+elasticsearch = ["elasticsearch (<8.0)"]
+eventlet = ["eventlet (>=0.32.0)"]
+gevent = ["gevent (>=1.5.0)"]
+librabbitmq = ["librabbitmq (>=2.0.0)"]
+memcache = ["pylibmc (==1.6.3)"]
+mongodb = ["pymongo[srv] (>=4.0.2)"]
+msgpack = ["msgpack (==1.0.5)"]
+pymemcache = ["python-memcached (==1.59)"]
+pyro = ["pyro4 (==4.82)"]
+pytest = ["pytest-celery (==0.0.0)"]
+redis = ["redis (>=4.5.2,!=4.5.5,<5.0.0)"]
+s3 = ["boto3 (>=1.26.143)"]
+slmq = ["softlayer-messaging (>=1.0.3)"]
+solar = ["ephem (==4.1.4)"]
+sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"]
+sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.0)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"]
+tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"]
+yaml = ["PyYAML (>=3.10)"]
+zookeeper = ["kazoo (>=1.3.1)"]
+zstd = ["zstandard (==0.21.0)"]
+
[[package]]
name = "certifi"
version = "2023.7.22"
@@ -744,6 +824,55 @@ files = [
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
+[[package]]
+name = "click-didyoumean"
+version = "0.3.0"
+description = "Enables git-like *did-you-mean* feature in click"
+optional = false
+python-versions = ">=3.6.2,<4.0.0"
+files = [
+ {file = "click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"},
+ {file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"},
+]
+
+[package.dependencies]
+click = ">=7"
+
+[[package]]
+name = "click-plugins"
+version = "1.1.1"
+description = "An extension module for click to enable registering CLI commands via setuptools entry-points."
+optional = false
+python-versions = "*"
+files = [
+ {file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"},
+ {file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"},
+]
+
+[package.dependencies]
+click = ">=4.0"
+
+[package.extras]
+dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"]
+
+[[package]]
+name = "click-repl"
+version = "0.3.0"
+description = "REPL plugin for Click"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9"},
+ {file = "click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812"},
+]
+
+[package.dependencies]
+click = ">=7.0"
+prompt-toolkit = ">=3.0.36"
+
+[package.extras]
+testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"]
+
[[package]]
name = "colorama"
version = "0.4.6"
@@ -1624,6 +1753,38 @@ files = [
cryptography = ">=3.4"
deprecated = "*"
+[[package]]
+name = "kombu"
+version = "5.3.2"
+description = "Messaging library for Python."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "kombu-5.3.2-py3-none-any.whl", hash = "sha256:b753c9cfc9b1e976e637a7cbc1a65d446a22e45546cd996ea28f932082b7dc9e"},
+ {file = "kombu-5.3.2.tar.gz", hash = "sha256:0ba213f630a2cb2772728aef56ac6883dc3a2f13435e10048f6e97d48506dbbd"},
+]
+
+[package.dependencies]
+amqp = ">=5.1.1,<6.0.0"
+vine = "*"
+
+[package.extras]
+azureservicebus = ["azure-servicebus (>=7.10.0)"]
+azurestoragequeues = ["azure-identity (>=1.12.0)", "azure-storage-queue (>=12.6.0)"]
+confluentkafka = ["confluent-kafka (==2.1.1)"]
+consul = ["python-consul2"]
+librabbitmq = ["librabbitmq (>=2.0.0)"]
+mongodb = ["pymongo (>=4.1.1)"]
+msgpack = ["msgpack"]
+pyro = ["pyro4"]
+qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"]
+redis = ["redis (>=4.5.2)"]
+slmq = ["softlayer-messaging (>=1.0.3)"]
+sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"]
+sqs = ["boto3 (>=1.26.143)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"]
+yaml = ["PyYAML (>=3.10)"]
+zookeeper = ["kazoo (>=2.8.0)"]
+
[[package]]
name = "levenshtein"
version = "0.21.1"
@@ -2151,6 +2312,20 @@ files = [
fastapi = ">=0.38.1,<1.0.0"
prometheus-client = ">=0.8.0,<1.0.0"
+[[package]]
+name = "prompt-toolkit"
+version = "3.0.39"
+description = "Library for building powerful interactive command lines in Python"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+ {file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"},
+ {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"},
+]
+
+[package.dependencies]
+wcwidth = "*"
+
[[package]]
name = "protobuf"
version = "4.24.4"
@@ -3438,6 +3613,17 @@ files = [
{file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"},
]
+[[package]]
+name = "tzdata"
+version = "2023.3"
+description = "Provider of IANA time zone data"
+optional = false
+python-versions = ">=2"
+files = [
+ {file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"},
+ {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"},
+]
+
[[package]]
name = "urllib3"
version = "1.26.17"
@@ -3523,6 +3709,17 @@ dev = ["Cython (>=0.29.32,<0.30.0)", "Sphinx (>=4.1.2,<4.2.0)", "aiohttp", "flak
docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
test = ["Cython (>=0.29.32,<0.30.0)", "aiohttp", "flake8 (>=3.9.2,<3.10.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=22.0.0,<22.1.0)", "pycodestyle (>=2.7.0,<2.8.0)"]
+[[package]]
+name = "vine"
+version = "5.0.0"
+description = "Promises, promises, promises."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "vine-5.0.0-py2.py3-none-any.whl", hash = "sha256:4c9dceab6f76ed92105027c49c823800dd33cacce13bdedc5b914e3514b7fb30"},
+ {file = "vine-5.0.0.tar.gz", hash = "sha256:7d3b1624a953da82ef63462013bbd271d3eb75751489f9807598e8f340bd637e"},
+]
+
[[package]]
name = "watchfiles"
version = "0.20.0"
@@ -3557,6 +3754,17 @@ files = [
[package.dependencies]
anyio = ">=3.0.0"
+[[package]]
+name = "wcwidth"
+version = "0.2.8"
+description = "Measures the displayed width of unicode strings in a terminal"
+optional = false
+python-versions = "*"
+files = [
+ {file = "wcwidth-0.2.8-py2.py3-none-any.whl", hash = "sha256:77f719e01648ed600dfa5402c347481c0992263b81a027344f3e1ba25493a704"},
+ {file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"},
+]
+
[[package]]
name = "websockets"
version = "11.0.3"
@@ -3838,4 +4046,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "61578467a70980ff9c2dc0cd787b6410b91d7c5fd2bb4c46b6951ec82690ef67"
+content-hash = "fda9f13784a64add559abb2266d60eeef8f28d2b5f369633630f4fed14daa99c"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index e3b44774..7b1b7936 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -33,6 +33,7 @@ prometheus-fastapi-instrumentator = "^6.1.0"
sentencepiece = "^0.1.99"
protobuf = "^4.24.3"
profanityfilter = "^2.0.6"
+celery = "^5.3.4"
[tool.poetry.group.dev.dependencies]
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index e0ffd826..1503948a 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -113,5 +113,9 @@ class Settings(BaseSettings):
# Min transcript length to generate topic + summary
MIN_TRANSCRIPT_LENGTH: int = 750
+ # Celery
+ CELERY_BROKER_URL: str = "redis://localhost:6379/1"
+ CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
+
settings = Settings()
diff --git a/server/reflector/tasks/boot.py b/server/reflector/tasks/boot.py
new file mode 100644
index 00000000..88cc2d6f
--- /dev/null
+++ b/server/reflector/tasks/boot.py
@@ -0,0 +1,2 @@
+import reflector.tasks.post_transcript # noqa
+import reflector.tasks.worker # noqa
diff --git a/server/reflector/tasks/post_transcript.py b/server/reflector/tasks/post_transcript.py
new file mode 100644
index 00000000..1cfbc664
--- /dev/null
+++ b/server/reflector/tasks/post_transcript.py
@@ -0,0 +1,170 @@
+from reflector.logger import logger
+from reflector.processors import (
+ Pipeline,
+ Processor,
+ TranscriptFinalLongSummaryProcessor,
+ TranscriptFinalShortSummaryProcessor,
+ TranscriptFinalTitleProcessor,
+)
+from reflector.processors.base import BroadcastProcessor
+from reflector.processors.types import (
+ FinalLongSummary,
+ FinalShortSummary,
+ FinalTitle,
+ TitleSummary,
+)
+from reflector.processors.types import Transcript as ProcessorTranscript
+from reflector.tasks.worker import celery
+from reflector.views.rtc_offer import PipelineEvent, TranscriptionContext
+from reflector.views.transcripts import Transcript, transcripts_controller
+
+
+class TranscriptAudioDiarizationProcessor(Processor):
+ INPUT_TYPE = Transcript
+ OUTPUT_TYPE = TitleSummary
+
+ async def _push(self, data: Transcript):
+ # Gather diarization data
+ diarization = [
+ {"start": 0.0, "stop": 4.9, "speaker": 2},
+ {"start": 5.6, "stop": 6.7, "speaker": 2},
+ {"start": 7.3, "stop": 8.9, "speaker": 2},
+ {"start": 7.3, "stop": 7.9, "speaker": 0},
+ {"start": 9.4, "stop": 11.2, "speaker": 2},
+ {"start": 9.7, "stop": 10.0, "speaker": 0},
+ {"start": 10.0, "stop": 10.1, "speaker": 0},
+ {"start": 11.7, "stop": 16.1, "speaker": 2},
+ {"start": 11.8, "stop": 12.1, "speaker": 1},
+ {"start": 16.4, "stop": 21.0, "speaker": 2},
+ {"start": 21.1, "stop": 22.6, "speaker": 2},
+ {"start": 24.7, "stop": 31.9, "speaker": 2},
+ {"start": 32.0, "stop": 32.8, "speaker": 1},
+ {"start": 33.4, "stop": 37.8, "speaker": 2},
+ {"start": 37.9, "stop": 40.3, "speaker": 0},
+ {"start": 39.2, "stop": 40.4, "speaker": 2},
+ {"start": 40.7, "stop": 41.4, "speaker": 0},
+ {"start": 41.6, "stop": 45.7, "speaker": 2},
+ {"start": 46.4, "stop": 53.1, "speaker": 2},
+ {"start": 53.6, "stop": 56.5, "speaker": 2},
+ {"start": 54.9, "stop": 75.4, "speaker": 1},
+ {"start": 57.3, "stop": 58.0, "speaker": 2},
+ {"start": 65.7, "stop": 66.0, "speaker": 2},
+ {"start": 75.8, "stop": 78.8, "speaker": 1},
+ {"start": 79.0, "stop": 82.6, "speaker": 1},
+ {"start": 83.2, "stop": 83.3, "speaker": 1},
+ {"start": 84.5, "stop": 94.3, "speaker": 1},
+ {"start": 95.1, "stop": 100.7, "speaker": 1},
+ {"start": 100.7, "stop": 102.0, "speaker": 0},
+ {"start": 100.7, "stop": 101.8, "speaker": 1},
+ {"start": 102.0, "stop": 103.0, "speaker": 1},
+ {"start": 103.0, "stop": 103.7, "speaker": 0},
+ {"start": 103.7, "stop": 103.8, "speaker": 1},
+ {"start": 103.8, "stop": 113.9, "speaker": 0},
+ {"start": 114.7, "stop": 117.0, "speaker": 0},
+ {"start": 117.0, "stop": 117.4, "speaker": 1},
+ ]
+
+ # now reapply speaker to topics (if any)
+ # topics is a list[BaseModel] with an attribute words
+ # words is a list[BaseModel] with text, start and speaker attribute
+
+ # mutate in place
+ for topic in data.topics:
+ for word in topic.words:
+ for d in diarization:
+ if d["start"] <= word.start <= d["stop"]:
+ word.speaker = d["speaker"]
+
+ topics = data.topics[:]
+
+ await transcripts_controller.update(
+ data,
+ {
+ "topics": [topic.model_dump(mode="json") for topic in data.topics],
+ },
+ )
+
+ # emit them
+ for topic in topics:
+ transcript = ProcessorTranscript(words=topic.words)
+ await self.emit(
+ TitleSummary(
+ title=topic.title,
+ summary=topic.summary,
+ timestamp=topic.timestamp,
+ duration=0,
+ transcript=transcript,
+ )
+ )
+
+
+@celery.task(name="post_transcript")
+async def post_transcript_pipeline(transcript_id: str):
+ # get transcript
+ transcript = await transcripts_controller.get_by_id(transcript_id)
+ if not transcript:
+ logger.error("Transcript not found", transcript_id=transcript_id)
+ return
+
+ ctx = TranscriptionContext(logger=logger.bind(transcript_id=transcript_id))
+ event_callback = None
+ event_callback_args = None
+
+ async def on_final_short_summary(summary: FinalShortSummary):
+ ctx.logger.info("FinalShortSummary", final_short_summary=summary)
+
+ # send to callback (eg. websocket)
+ if event_callback:
+ await event_callback(
+ event=PipelineEvent.FINAL_SHORT_SUMMARY,
+ args=event_callback_args,
+ data=summary,
+ )
+
+ async def on_final_long_summary(summary: FinalLongSummary):
+ ctx.logger.info("FinalLongSummary", final_summary=summary)
+
+ # send to callback (eg. websocket)
+ if event_callback:
+ await event_callback(
+ event=PipelineEvent.FINAL_LONG_SUMMARY,
+ args=event_callback_args,
+ data=summary,
+ )
+
+ async def on_final_title(title: FinalTitle):
+ ctx.logger.info("FinalTitle", final_title=title)
+
+ # send to callback (eg. websocket)
+ if event_callback:
+ await event_callback(
+ event=PipelineEvent.FINAL_TITLE,
+ args=event_callback_args,
+ data=title,
+ )
+
+ ctx.logger.info("Starting pipeline (diarization)")
+ ctx.pipeline = Pipeline(
+ TranscriptAudioDiarizationProcessor(),
+ BroadcastProcessor(
+ processors=[
+ TranscriptFinalTitleProcessor.as_threaded(),
+ TranscriptFinalLongSummaryProcessor.as_threaded(),
+ TranscriptFinalShortSummaryProcessor.as_threaded(),
+ ]
+ ),
+ )
+
+ await ctx.pipeline.push(transcript)
+ await ctx.pipeline.flush()
+
+
+if __name__ == "__main__":
+ import argparse
+ import asyncio
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("transcript_id", type=str)
+ args = parser.parse_args()
+
+ asyncio.run(post_transcript_pipeline(args.transcript_id))
diff --git a/server/reflector/tasks/worker.py b/server/reflector/tasks/worker.py
new file mode 100644
index 00000000..4379a1b7
--- /dev/null
+++ b/server/reflector/tasks/worker.py
@@ -0,0 +1,6 @@
+from celery import Celery
+from reflector.settings import settings
+
+celery = Celery(__name__)
+celery.conf.broker_url = settings.CELERY_BROKER_URL
+celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
From 00c06b7971e935f4c681a4e737c9815f6aff9133 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 25 Oct 2023 16:56:43 +0200
Subject: [PATCH 10/41] server: use redis pubsub for interprocess websocket
communication
---
server/docker-compose.yml | 22 +++--
server/poetry.lock | 20 +++-
server/pyproject.toml | 1 +
server/reflector/settings.py | 4 +
server/reflector/views/transcripts.py | 61 ++++---------
server/reflector/ws_manager.py | 127 ++++++++++++++++++++++++++
6 files changed, 183 insertions(+), 52 deletions(-)
create mode 100644 server/reflector/ws_manager.py
diff --git a/server/docker-compose.yml b/server/docker-compose.yml
index 374130fa..4e5a21e8 100644
--- a/server/docker-compose.yml
+++ b/server/docker-compose.yml
@@ -1,15 +1,19 @@
version: "3.9"
services:
- server:
- build:
- context: .
+ # server:
+ # build:
+ # context: .
+ # ports:
+ # - 1250:1250
+ # environment:
+ # LLM_URL: "${LLM_URL}"
+ # MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
+ # volumes:
+ # - model-cache:/root/.cache
+ redis:
+ image: redis:7.2
ports:
- - 1250:1250
- environment:
- LLM_URL: "${LLM_URL}"
- MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
- volumes:
- - model-cache:/root/.cache
+ - 6379:6379
volumes:
model-cache:
diff --git a/server/poetry.lock b/server/poetry.lock
index 0df46097..35d98382 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -2919,6 +2919,24 @@ files = [
[package.extras]
full = ["numpy"]
+[[package]]
+name = "redis"
+version = "5.0.1"
+description = "Python client for Redis database and key-value store"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"},
+ {file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"},
+]
+
+[package.dependencies]
+async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""}
+
+[package.extras]
+hiredis = ["hiredis (>=1.0.0)"]
+ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
+
[[package]]
name = "regex"
version = "2023.10.3"
@@ -4046,4 +4064,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "fda9f13784a64add559abb2266d60eeef8f28d2b5f369633630f4fed14daa99c"
+content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 7b1b7936..ed231a4f 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -34,6 +34,7 @@ sentencepiece = "^0.1.99"
protobuf = "^4.24.3"
profanityfilter = "^2.0.6"
celery = "^5.3.4"
+redis = "^5.0.1"
[tool.poetry.group.dev.dependencies]
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 1503948a..d7cc2c33 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -117,5 +117,9 @@ class Settings(BaseSettings):
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
+ # Redis
+ REDIS_HOST: str = "localhost"
+ REDIS_PORT: int = 6379
+
settings = Settings()
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 1d9fd4bd..9480461f 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -21,12 +21,14 @@ from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
+from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
router = APIRouter()
+ws_manager = get_ws_manager()
# ==============================================================
# Models to move to a database, but required for the API to work
@@ -487,40 +489,10 @@ async def transcript_get_websocket_events(transcript_id: str):
# ==============================================================
-# Websocket Manager
+# Websocket
# ==============================================================
-class WebsocketManager:
- def __init__(self):
- self.active_connections = {}
-
- async def connect(self, transcript_id: str, websocket: WebSocket):
- await websocket.accept()
- if transcript_id not in self.active_connections:
- self.active_connections[transcript_id] = []
- self.active_connections[transcript_id].append(websocket)
-
- def disconnect(self, transcript_id: str, websocket: WebSocket):
- if transcript_id not in self.active_connections:
- return
- self.active_connections[transcript_id].remove(websocket)
- if not self.active_connections[transcript_id]:
- del self.active_connections[transcript_id]
-
- async def send_json(self, transcript_id: str, message):
- if transcript_id not in self.active_connections:
- return
- for connection in self.active_connections[transcript_id][:]:
- try:
- await connection.send_json(message)
- except Exception:
- self.active_connections[transcript_id].remove(connection)
-
-
-ws_manager = WebsocketManager()
-
-
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
transcript_id: str,
@@ -532,21 +504,25 @@ async def transcript_events_websocket(
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
- await ws_manager.connect(transcript_id, websocket)
+ # connect to websocket manager
+ # use ts:transcript_id as room id
+ room_id = f"ts:{transcript_id}"
+ await ws_manager.add_user_to_room(room_id, websocket)
- # on first connection, send all events
- for event in transcript.events:
- await websocket.send_json(event.model_dump(mode="json"))
-
- # XXX if transcript is final (locked=True and status=ended)
- # XXX send a final event to the client and close the connection
-
- # endless loop to wait for new events
try:
+ # on first connection, send all events only to the current user
+ for event in transcript.events:
+ await websocket.send_json(event.model_dump(mode="json"))
+
+ # XXX if transcript is final (locked=True and status=ended)
+ # XXX send a final event to the client and close the connection
+
+ # endless loop to wait for new events
+ # there is no command system yet, so incoming messages are ignored
while True:
await websocket.receive()
except (RuntimeError, WebSocketDisconnect):
- ws_manager.disconnect(transcript_id, websocket)
+ await ws_manager.remove_user_from_room(room_id, websocket)
# ==============================================================
@@ -658,7 +634,8 @@ async def handle_rtc_event_once(event: PipelineEvent, args, data):
return
# transmit to websocket clients
- await ws_manager.send_json(transcript_id, resp.model_dump(mode="json"))
+ room_id = f"ts:{transcript_id}"
+ await ws_manager.send_json(room_id, resp.model_dump(mode="json"))
@router.post("/transcripts/{transcript_id}/record/webrtc")
diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py
new file mode 100644
index 00000000..43475c1d
--- /dev/null
+++ b/server/reflector/ws_manager.py
@@ -0,0 +1,127 @@
+"""
+Websocket manager
+=================
+
+This module contains the WebsocketManager class, which is responsible for
+managing websockets and handling websocket connections.
+
+It uses the RedisPubSubManager class to subscribe to Redis channels and
+broadcast messages to all connected websockets.
+"""
+
+import asyncio
+import json
+
+import redis.asyncio as redis
+from fastapi import WebSocket
+
+ws_manager = None
+
+
+class RedisPubSubManager:
+ def __init__(self, host="localhost", port=6379):
+ self.redis_host = host
+ self.redis_port = port
+ self.redis_connection = None
+ self.pubsub = None
+
+ async def get_redis_connection(self) -> redis.Redis:
+ return redis.Redis(
+ host=self.redis_host,
+ port=self.redis_port,
+ auto_close_connection_pool=False,
+ )
+
+ async def connect(self) -> None:
+ self.redis_connection = await self.get_redis_connection()
+ self.pubsub = self.redis_connection.pubsub()
+
+ async def disconnect(self) -> None:
+ if self.redis_connection is None:
+ return
+ await self.redis_connection.close()
+ self.redis_connection = None
+
+ async def send_json(self, room_id: str, message: str) -> None:
+ message = json.dumps(message)
+ await self.redis_connection.publish(room_id, message)
+
+ async def subscribe(self, room_id: str) -> redis.Redis:
+ await self.pubsub.subscribe(room_id)
+ return self.pubsub
+
+ async def unsubscribe(self, room_id: str) -> None:
+ await self.pubsub.unsubscribe(room_id)
+
+
+class WebsocketManager:
+ def __init__(self, pubsub_client: RedisPubSubManager = None):
+ self.rooms: dict = {}
+ self.pubsub_client = pubsub_client
+
+ async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
+ await websocket.accept()
+
+ if room_id in self.rooms:
+ self.rooms[room_id].append(websocket)
+ else:
+ self.rooms[room_id] = [websocket]
+
+ await self.pubsub_client.connect()
+ pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
+ asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
+
+ async def send_json(self, room_id: str, message: dict) -> None:
+ await self.pubsub_client.send_json(room_id, message)
+
+ async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
+ self.rooms[room_id].remove(websocket)
+
+ if len(self.rooms[room_id]) == 0:
+ del self.rooms[room_id]
+ await self.pubsub_client.unsubscribe(room_id)
+
+ async def _pubsub_data_reader(self, pubsub_subscriber):
+ while True:
+ message = await pubsub_subscriber.get_message(
+ ignore_subscribe_messages=True
+ )
+ if message is not None:
+ room_id = message["channel"].decode("utf-8")
+ all_sockets = self.rooms[room_id]
+ for socket in all_sockets:
+ data = json.loads(message["data"].decode("utf-8"))
+ await socket.send_json(data)
+
+
+def get_pubsub_client() -> RedisPubSubManager:
+ """
+ Returns the RedisPubSubManager instance for managing Redis pubsub.
+ """
+ from reflector.settings import settings
+
+ return RedisPubSubManager(
+ host=settings.REDIS_HOST,
+ port=settings.REDIS_PORT,
+ )
+
+
+def get_ws_manager() -> WebsocketManager:
+ """
+ Returns the WebsocketManager instance for managing websockets.
+
+ This function initializes and returns the WebsocketManager instance,
+ which is responsible for managing websockets and handling websocket
+ connections.
+
+ Returns:
+ WebsocketManager: The initialized WebsocketManager instance.
+
+ Raises:
+ ImportError: If the 'reflector.settings' module cannot be imported.
+ RedisConnectionError: If there is an error connecting to the Redis server.
+ """
+ global ws_manager
+ pubsub_client = get_pubsub_client()
+ ws_manager = WebsocketManager(pubsub_client=pubsub_client)
+ return ws_manager
From 367912869d2a0834dbb4d4ccc8d9793dbf119d2f Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 25 Oct 2023 19:49:15 +0200
Subject: [PATCH 11/41] server: make processors in broadcast execute in
parallel
---
server/reflector/processors/base.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py
index 6771e11e..46bfb4a5 100644
--- a/server/reflector/processors/base.py
+++ b/server/reflector/processors/base.py
@@ -290,12 +290,12 @@ class BroadcastProcessor(Processor):
processor.set_pipeline(pipeline)
async def _push(self, data):
- for processor in self.processors:
- await processor.push(data)
+ coros = [processor.push(data) for processor in self.processors]
+ await asyncio.gather(*coros)
async def _flush(self):
- for processor in self.processors:
- await processor.flush()
+ coros = [processor.flush() for processor in self.processors]
+ await asyncio.gather(*coros)
def connect(self, processor: Processor):
for processor in self.processors:
@@ -333,6 +333,7 @@ class Pipeline(Processor):
self.logger.info("Pipeline created")
self.processors = processors
+ self.options = None
self.prefs = {}
for processor in processors:
From a45b30ee70d080fead2d6ae82ee1251c729a35bf Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 25 Oct 2023 19:50:08 +0200
Subject: [PATCH 12/41] www: ensure login waited before recording
if you refresh the record page, it does not work and returns 404 because the transcript is accessed without a token
---
.../[domain]/transcripts/[transcriptId]/page.tsx | 2 +-
.../transcripts/[transcriptId]/record/page.tsx | 15 ++++++++++++---
2 files changed, 13 insertions(+), 4 deletions(-)
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
index d4f40428..3e30b97f 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
@@ -35,7 +35,7 @@ export default function TranscriptDetails(details: TranscriptDetails) {
useEffect(() => {
if (requireLogin && !isAuthenticated) return;
setTranscriptId(details.params.transcriptId);
- }, [api]);
+ }, [api, details.params.transcriptId, isAuthenticated]);
if (transcript?.error /** || topics?.error || waveform?.error **/) {
return (
diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
index 8e31327c..51a318a4 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
@@ -14,6 +14,8 @@ import DisconnectedIndicator from "../../disconnectedIndicator";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { faGear } from "@fortawesome/free-solid-svg-icons";
import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock";
+import { featRequireLogin } from "../../../../../app/lib/utils";
+import { useFiefIsAuthenticated } from "@fief/fief/nextjs/react";
type TranscriptDetails = {
params: {
@@ -36,16 +38,23 @@ const TranscriptRecord = (details: TranscriptDetails) => {
}
}, []);
+ const isAuthenticated = useFiefIsAuthenticated();
const api = getApi();
- const transcript = useTranscript(api, details.params.transcriptId);
- const webRTC = useWebRTC(stream, details.params.transcriptId, api);
- const webSockets = useWebSockets(details.params.transcriptId);
+ const [transcriptId, setTranscriptId] = useState("");
+ const transcript = useTranscript(api, transcriptId);
+ const webRTC = useWebRTC(stream, transcriptId, api);
+ const webSockets = useWebSockets(transcriptId);
const { audioDevices, getAudioStream } = useAudioDevice();
const [hasRecorded, setHasRecorded] = useState(false);
const [transcriptStarted, setTranscriptStarted] = useState(false);
+ useEffect(() => {
+ if (featRequireLogin() && !isAuthenticated) return;
+ setTranscriptId(details.params.transcriptId);
+ }, [api, details.params.transcriptId, isAuthenticated]);
+
useEffect(() => {
if (!transcriptStarted && webSockets.transcriptText.length !== 0)
setTranscriptStarted(true);
From 433c0500ccc137c56ec11284df46a679699c5982 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 25 Oct 2023 19:50:27 +0200
Subject: [PATCH 13/41] server: refactor to separate websocket management +
start pipeline runner
---
server/reflector/views/rtc_offer.py | 160 ++++++++++++++++++++------
server/reflector/views/transcripts.py | 17 ++-
server/reflector/ws_manager.py | 23 ++--
3 files changed, 148 insertions(+), 52 deletions(-)
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index 5662d989..48d804cc 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -2,6 +2,7 @@ import asyncio
from enum import StrEnum
from json import dumps, loads
from pathlib import Path
+from typing import Callable
import av
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
@@ -38,7 +39,7 @@ m_rtc_sessions = Gauge("rtc_sessions", "Number of active RTC sessions")
class TranscriptionContext(object):
def __init__(self, logger):
self.logger = logger
- self.pipeline = None
+ self.pipeline_runner = None
self.data_channel = None
self.status = "idle"
self.topics = []
@@ -60,7 +61,7 @@ class AudioStreamTrack(MediaStreamTrack):
ctx = self.ctx
frame = await self.track.recv()
try:
- await ctx.pipeline.push(frame)
+ await ctx.pipeline_runner.push(frame)
except Exception as e:
ctx.logger.error("Pipeline error", error=e)
return frame
@@ -84,6 +85,113 @@ class PipelineEvent(StrEnum):
FINAL_TITLE = "FINAL_TITLE"
+class PipelineOptions(BaseModel):
+ audio_filename: Path | None = None
+ source_language: str = "en"
+ target_language: str = "en"
+
+ on_transcript: Callable | None = None
+ on_topic: Callable | None = None
+ on_final_title: Callable | None = None
+ on_final_short_summary: Callable | None = None
+ on_final_long_summary: Callable | None = None
+
+
+class PipelineRunner(object):
+ """
+ Pipeline runner designed to be executed in an asyncio task
+ """
+
+ def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None):
+ self.pipeline = pipeline
+ self.q_cmd = asyncio.Queue()
+ self.ev_done = asyncio.Event()
+ self.status = "idle"
+ self.status_callback = status_callback
+
+ async def update_status(self, status):
+ print("update_status", status)
+ self.status = status
+ if self.status_callback:
+ try:
+ await self.status_callback(status)
+ except Exception as e:
+ logger.error("PipelineRunner status_callback error", error=e)
+
+ async def add_cmd(self, cmd: str, data):
+ await self.q_cmd.put([cmd, data])
+
+ async def push(self, data):
+ await self.add_cmd("PUSH", data)
+
+ async def flush(self):
+ await self.add_cmd("FLUSH", None)
+
+ async def run(self):
+ try:
+ await self.update_status("running")
+ while not self.ev_done.is_set():
+ cmd, data = await self.q_cmd.get()
+ func = getattr(self, f"cmd_{cmd.lower()}")
+ if func:
+ await func(data)
+ else:
+ raise Exception(f"Unknown command {cmd}")
+ except Exception as e:
+ await self.update_status("error")
+ logger.error("PipelineRunner error", error=e)
+
+ async def cmd_push(self, data):
+ if self.status == "idle":
+ await self.update_status("recording")
+ await self.pipeline.push(data)
+
+ async def cmd_flush(self, data):
+ await self.update_status("processing")
+ await self.pipeline.flush()
+ await self.update_status("ended")
+ self.ev_done.set()
+
+ def start(self):
+ print("start task")
+ asyncio.get_event_loop().create_task(self.run())
+
+
+async def pipeline_live_create(options: PipelineOptions):
+ # build the processor chain for the live pipeline,
+ # optionally writing the incoming audio to disk first
+ processors = []
+ if options.audio_filename is not None:
+ processors += [AudioFileWriterProcessor(path=options.audio_filename)]
+ processors += [
+ AudioChunkerProcessor(),
+ AudioMergeProcessor(),
+ AudioTranscriptAutoProcessor.as_threaded(),
+ TranscriptLinerProcessor(),
+ TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript),
+ TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic),
+ BroadcastProcessor(
+ processors=[
+ TranscriptFinalTitleProcessor.as_threaded(
+ callback=options.on_final_title
+ ),
+ TranscriptFinalLongSummaryProcessor.as_threaded(
+ callback=options.on_final_long_summary
+ ),
+ TranscriptFinalShortSummaryProcessor.as_threaded(
+ callback=options.on_final_short_summary
+ ),
+ ]
+ ),
+ ]
+ pipeline = Pipeline(*processors)
+ pipeline.options = options
+ pipeline.set_pref("audio:source_language", options.source_language)
+ pipeline.set_pref("audio:target_language", options.target_language)
+
+ return pipeline
+
+
async def rtc_offer_base(
params: RtcOffer,
request: Request,
@@ -211,37 +319,24 @@ async def rtc_offer_base(
data=title,
)
- # create a context for the whole rtc transaction
- # add a customised logger to the context
- processors = []
- if audio_filename is not None:
- processors += [AudioFileWriterProcessor(path=audio_filename)]
- processors += [
- AudioChunkerProcessor(),
- AudioMergeProcessor(),
- AudioTranscriptAutoProcessor.as_threaded(),
- TranscriptLinerProcessor(),
- TranscriptTranslatorProcessor.as_threaded(callback=on_transcript),
- TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
- BroadcastProcessor(
- processors=[
- TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
- TranscriptFinalLongSummaryProcessor.as_threaded(
- callback=on_final_long_summary
- ),
- TranscriptFinalShortSummaryProcessor.as_threaded(
- callback=on_final_short_summary
- ),
- ]
- ),
- ]
- ctx.pipeline = Pipeline(*processors)
- ctx.pipeline.set_pref("audio:source_language", source_language)
- ctx.pipeline.set_pref("audio:target_language", target_language)
-
# handle RTC peer connection
pc = RTCPeerConnection()
+ # create pipeline
+ options = PipelineOptions(
+ audio_filename=audio_filename,
+ source_language=source_language,
+ target_language=target_language,
+ on_transcript=on_transcript,
+ on_topic=on_topic,
+ on_final_short_summary=on_final_short_summary,
+ on_final_long_summary=on_final_long_summary,
+ on_final_title=on_final_title,
+ )
+ pipeline = await pipeline_live_create(options)
+ ctx.pipeline_runner = PipelineRunner(pipeline, update_status)
+ ctx.pipeline_runner.start()
+
async def flush_pipeline_and_quit(close=True):
# may be called twice
+ # 1. either the client asks to stop the meeting
@@ -249,12 +344,10 @@ async def rtc_offer_base(
# - when we receive the close event, we do nothing.
# 2. or the client close the connection
# and there is nothing to do because it is already closed
- await update_status("processing")
- await ctx.pipeline.flush()
+ await ctx.pipeline_runner.flush()
if close:
ctx.logger.debug("Closing peer connection")
await pc.close()
- await update_status("ended")
if pc in sessions:
sessions.remove(pc)
m_rtc_sessions.dec()
@@ -287,7 +380,6 @@ async def rtc_offer_base(
def on_track(track):
ctx.logger.info(f"Track {track.kind} received")
pc.addTrack(AudioStreamTrack(ctx, track))
- asyncio.get_event_loop().create_task(update_status("recording"))
await pc.setRemoteDescription(offer)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 9480461f..9f02eb6d 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -483,16 +483,16 @@ async def transcript_get_topics(
]
-@router.get("/transcripts/{transcript_id}/events")
-async def transcript_get_websocket_events(transcript_id: str):
- pass
-
-
# ==============================================================
# Websocket
# ==============================================================
+@router.get("/transcripts/{transcript_id}/events")
+async def transcript_get_websocket_events(transcript_id: str):
+ pass
+
+
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
transcript_id: str,
@@ -512,6 +512,13 @@ async def transcript_events_websocket(
try:
# on first connection, send all events only to the current user
for event in transcript.events:
+ # for now, do not send TRANSCRIPT or STATUS events - these are live events
+ # that need not be sent to the client; but keep the rest
+ name = event.event
+ if name == PipelineEvent.TRANSCRIPT:
+ continue
+ if name == PipelineEvent.STATUS:
+ continue
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py
index 43475c1d..1dfe9e3d 100644
--- a/server/reflector/ws_manager.py
+++ b/server/reflector/ws_manager.py
@@ -33,6 +33,8 @@ class RedisPubSubManager:
)
async def connect(self) -> None:
+ if self.redis_connection is not None:
+ return
self.redis_connection = await self.get_redis_connection()
self.pubsub = self.redis_connection.pubsub()
@@ -43,6 +45,8 @@ class RedisPubSubManager:
self.redis_connection = None
async def send_json(self, room_id: str, message: str) -> None:
+ if not self.redis_connection:
+ await self.connect()
message = json.dumps(message)
await self.redis_connection.publish(room_id, message)
@@ -94,18 +98,6 @@ class WebsocketManager:
await socket.send_json(data)
-def get_pubsub_client() -> RedisPubSubManager:
- """
- Returns the RedisPubSubManager instance for managing Redis pubsub.
- """
- from reflector.settings import settings
-
- return RedisPubSubManager(
- host=settings.REDIS_HOST,
- port=settings.REDIS_PORT,
- )
-
-
def get_ws_manager() -> WebsocketManager:
"""
Returns the WebsocketManager instance for managing websockets.
@@ -122,6 +114,11 @@ def get_ws_manager() -> WebsocketManager:
RedisConnectionError: If there is an error connecting to the Redis server.
"""
global ws_manager
- pubsub_client = get_pubsub_client()
+ from reflector.settings import settings
+
+ pubsub_client = RedisPubSubManager(
+ host=settings.REDIS_HOST,
+ port=settings.REDIS_PORT,
+ )
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
return ws_manager
From 1c42473da029bb2f88b090a7323ac4f68b818275 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 26 Oct 2023 19:00:56 +0200
Subject: [PATCH 14/41] server: refactor with clearer pipeline instantiation
and linked to model
---
server/reflector/db/__init__.py | 23 +-
server/reflector/db/transcripts.py | 284 +++++++++++++++
.../reflector/pipelines/main_live_pipeline.py | 230 ++++++++++++
server/reflector/pipelines/runner.py | 117 ++++++
server/reflector/processors/__init__.py | 8 +-
server/reflector/views/rtc_offer.py | 267 +-------------
server/reflector/views/transcripts.py | 341 +-----------------
server/reflector/ws_manager.py | 4 +-
8 files changed, 658 insertions(+), 616 deletions(-)
create mode 100644 server/reflector/db/transcripts.py
create mode 100644 server/reflector/pipelines/main_live_pipeline.py
create mode 100644 server/reflector/pipelines/runner.py
diff --git a/server/reflector/db/__init__.py b/server/reflector/db/__init__.py
index b68dfe20..9871c633 100644
--- a/server/reflector/db/__init__.py
+++ b/server/reflector/db/__init__.py
@@ -1,32 +1,13 @@
import databases
import sqlalchemy
-
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings
database = databases.Database(settings.DATABASE_URL)
metadata = sqlalchemy.MetaData()
-
-transcripts = sqlalchemy.Table(
- "transcript",
- metadata,
- sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
- sqlalchemy.Column("name", sqlalchemy.String),
- sqlalchemy.Column("status", sqlalchemy.String),
- sqlalchemy.Column("locked", sqlalchemy.Boolean),
- sqlalchemy.Column("duration", sqlalchemy.Integer),
- sqlalchemy.Column("created_at", sqlalchemy.DateTime),
- sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
- sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
- sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
- sqlalchemy.Column("topics", sqlalchemy.JSON),
- sqlalchemy.Column("events", sqlalchemy.JSON),
- sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
- sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
- # with user attached, optional
- sqlalchemy.Column("user_id", sqlalchemy.String),
-)
+# import models
+import reflector.db.transcripts # noqa
engine = sqlalchemy.create_engine(
settings.DATABASE_URL, connect_args={"check_same_thread": False}
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
new file mode 100644
index 00000000..2b9fc6b2
--- /dev/null
+++ b/server/reflector/db/transcripts.py
@@ -0,0 +1,284 @@
+import json
+from contextlib import asynccontextmanager
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+from uuid import uuid4
+
+import sqlalchemy
+from pydantic import BaseModel, Field
+from reflector.db import database, metadata
+from reflector.processors.types import Word as ProcessorWord
+from reflector.settings import settings
+from reflector.utils.audio_waveform import get_audio_waveform
+
+transcripts = sqlalchemy.Table(
+ "transcript",
+ metadata,
+ sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
+ sqlalchemy.Column("name", sqlalchemy.String),
+ sqlalchemy.Column("status", sqlalchemy.String),
+ sqlalchemy.Column("locked", sqlalchemy.Boolean),
+ sqlalchemy.Column("duration", sqlalchemy.Integer),
+ sqlalchemy.Column("created_at", sqlalchemy.DateTime),
+ sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
+ sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
+ sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
+ sqlalchemy.Column("topics", sqlalchemy.JSON),
+ sqlalchemy.Column("events", sqlalchemy.JSON),
+ sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
+ sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
+ # with user attached, optional
+ sqlalchemy.Column("user_id", sqlalchemy.String),
+)
+
+
+def generate_uuid4():
+ return str(uuid4())
+
+
+def generate_transcript_name():
+ now = datetime.utcnow()
+ return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
+
+
+class AudioWaveform(BaseModel):
+ data: list[float]
+
+
+class TranscriptText(BaseModel):
+ text: str
+ translation: str | None
+
+
+class TranscriptSegmentTopic(BaseModel):
+ speaker: int
+ text: str
+ timestamp: float
+
+
+class TranscriptTopic(BaseModel):
+ id: str = Field(default_factory=generate_uuid4)
+ title: str
+ summary: str
+ timestamp: float
+ text: str | None = None
+ words: list[ProcessorWord] = []
+
+
+class TranscriptFinalShortSummary(BaseModel):
+ short_summary: str
+
+
+class TranscriptFinalLongSummary(BaseModel):
+ long_summary: str
+
+
+class TranscriptFinalTitle(BaseModel):
+ title: str
+
+
+class TranscriptEvent(BaseModel):
+ event: str
+ data: dict
+
+
+class Transcript(BaseModel):
+ id: str = Field(default_factory=generate_uuid4)
+ user_id: str | None = None
+ name: str = Field(default_factory=generate_transcript_name)
+ status: str = "idle"
+ locked: bool = False
+ duration: float = 0
+ created_at: datetime = Field(default_factory=datetime.utcnow)
+ title: str | None = None
+ short_summary: str | None = None
+ long_summary: str | None = None
+ topics: list[TranscriptTopic] = []
+ events: list[TranscriptEvent] = []
+ source_language: str = "en"
+ target_language: str = "en"
+
+ def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
+ ev = TranscriptEvent(event=event, data=data.model_dump())
+ self.events.append(ev)
+ return ev
+
+ def upsert_topic(self, topic: TranscriptTopic):
+ existing_topic = next((t for t in self.topics if t.id == topic.id), None)
+ if existing_topic:
+ existing_topic.update_from(topic)
+ else:
+ self.topics.append(topic)
+
+ def events_dump(self, mode="json"):
+ return [event.model_dump(mode=mode) for event in self.events]
+
+ def topics_dump(self, mode="json"):
+ return [topic.model_dump(mode=mode) for topic in self.topics]
+
+ def convert_audio_to_waveform(self, segments_count=256):
+ fn = self.audio_waveform_filename
+ if fn.exists():
+ return
+ waveform = get_audio_waveform(
+ path=self.audio_mp3_filename, segments_count=segments_count
+ )
+ try:
+ with open(fn, "w") as fd:
+ json.dump(waveform, fd)
+ except Exception:
+ # remove file if anything happen during the write
+ fn.unlink(missing_ok=True)
+ raise
+ return waveform
+
+ def unlink(self):
+ self.data_path.unlink(missing_ok=True)
+
+ @property
+ def data_path(self):
+ return Path(settings.DATA_DIR) / self.id
+
+ @property
+ def audio_mp3_filename(self):
+ return self.data_path / "audio.mp3"
+
+ @property
+ def audio_waveform_filename(self):
+ return self.data_path / "audio.json"
+
+ @property
+ def audio_waveform(self):
+ try:
+ with open(self.audio_waveform_filename) as fd:
+ data = json.load(fd)
+ except json.JSONDecodeError:
+ # unlink file if it's corrupted
+ self.audio_waveform_filename.unlink(missing_ok=True)
+ return None
+
+ return AudioWaveform(data=data)
+
+
+class TranscriptController:
+ async def get_all(
+ self,
+ user_id: str | None = None,
+ order_by: str | None = None,
+ filter_empty: bool | None = True,
+ filter_recording: bool | None = True,
+ ) -> list[Transcript]:
+ """
+ Get all transcripts
+
+ If `user_id` is specified, only return transcripts that belong to the user.
+ Otherwise, return all anonymous transcripts.
+
+ Parameters:
+ - `order_by`: field to order by, e.g. "-created_at"
+ - `filter_empty`: filter out empty transcripts
+ - `filter_recording`: filter out transcripts that are currently recording
+ """
+ query = transcripts.select().where(transcripts.c.user_id == user_id)
+
+ if order_by is not None:
+ field = getattr(transcripts.c, order_by[1:])
+ if order_by.startswith("-"):
+ field = field.desc()
+ query = query.order_by(field)
+
+ if filter_empty:
+ query = query.filter(transcripts.c.status != "idle")
+
+ if filter_recording:
+ query = query.filter(transcripts.c.status != "recording")
+
+ results = await database.fetch_all(query)
+ return results
+
+ async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
+ """
+ Get a transcript by id
+ """
+ query = transcripts.select().where(transcripts.c.id == transcript_id)
+ if "user_id" in kwargs:
+ query = query.where(transcripts.c.user_id == kwargs["user_id"])
+ result = await database.fetch_one(query)
+ if not result:
+ return None
+ return Transcript(**result)
+
+ async def add(
+ self,
+ name: str,
+ source_language: str = "en",
+ target_language: str = "en",
+ user_id: str | None = None,
+ ):
+ """
+ Add a new transcript
+ """
+ transcript = Transcript(
+ name=name,
+ source_language=source_language,
+ target_language=target_language,
+ user_id=user_id,
+ )
+ query = transcripts.insert().values(**transcript.model_dump())
+ await database.execute(query)
+ return transcript
+
+ async def update(self, transcript: Transcript, values: dict):
+ """
+ Update a transcript fields with key/values in values
+ """
+ query = (
+ transcripts.update()
+ .where(transcripts.c.id == transcript.id)
+ .values(**values)
+ )
+ await database.execute(query)
+ for key, value in values.items():
+ setattr(transcript, key, value)
+
+ async def remove_by_id(
+ self,
+ transcript_id: str,
+ user_id: str | None = None,
+ ) -> None:
+ """
+ Remove a transcript by id
+ """
+ transcript = await self.get_by_id(transcript_id, user_id=user_id)
+ if not transcript:
+ return
+ if user_id is not None and transcript.user_id != user_id:
+ return
+ transcript.unlink()
+ query = transcripts.delete().where(transcripts.c.id == transcript_id)
+ await database.execute(query)
+
+ @asynccontextmanager
+ async def transaction(self):
+ """
+ A context manager for database transaction
+ """
+ async with database.transaction():
+ yield
+
+ async def append_event(
+ self,
+ transcript: Transcript,
+ event: str,
+ data: Any,
+ ) -> TranscriptEvent:
+ """
+ Append an event to a transcript
+ """
+ resp = transcript.add_event(event=event, data=data)
+ await self.update(transcript, {"events": transcript.events_dump()})
+ return resp
+
+
+transcripts_controller = TranscriptController()
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
new file mode 100644
index 00000000..30f7ead3
--- /dev/null
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -0,0 +1,230 @@
+"""
+Main reflector pipeline for live streaming
+==========================================
+
+This is the default pipeline used in the API.
+
+It is decoupled to:
+- PipelineMainLive: have limited processing during live
+- PipelineMainPost: do heavy lifting after the live
+
+It is directly linked to our data model.
+"""
+
+from pathlib import Path
+
+from reflector.db.transcripts import (
+ Transcript,
+ TranscriptFinalLongSummary,
+ TranscriptFinalShortSummary,
+ TranscriptFinalTitle,
+ TranscriptText,
+ TranscriptTopic,
+ transcripts_controller,
+)
+from reflector.pipelines.runner import PipelineRunner
+from reflector.processors import (
+ AudioChunkerProcessor,
+ AudioFileWriterProcessor,
+ AudioMergeProcessor,
+ AudioTranscriptAutoProcessor,
+ BroadcastProcessor,
+ Pipeline,
+ TranscriptFinalLongSummaryProcessor,
+ TranscriptFinalShortSummaryProcessor,
+ TranscriptFinalTitleProcessor,
+ TranscriptLinerProcessor,
+ TranscriptTopicDetectorProcessor,
+ TranscriptTranslatorProcessor,
+)
+from reflector.tasks.worker import celery
+from reflector.ws_manager import WebsocketManager, get_ws_manager
+
+
+def broadcast_to_socket(func):
+ """
+ Decorator to broadcast transcript event to websockets
+ concerning this transcript
+ """
+
+ async def wrapper(self, *args, **kwargs):
+ resp = await func(self, *args, **kwargs)
+ if resp is None:
+ return
+ await self.ws_manager.send_json(
+ room_id=self.ws_room_id,
+ message=resp.model_dump(mode="json"),
+ )
+
+ return wrapper
+
+
+class PipelineMainBase(PipelineRunner):
+ transcript_id: str
+ ws_room_id: str | None = None
+ ws_manager: WebsocketManager | None = None
+
+ def prepare(self):
+ # prepare websocket
+ self.ws_room_id = f"ts:{self.transcript_id}"
+ self.ws_manager = get_ws_manager()
+
+ async def get_transcript(self) -> Transcript:
+ # fetch the transcript
+ result = await transcripts_controller.get_by_id(
+ transcript_id=self.transcript_id
+ )
+ if not result:
+ raise Exception("Transcript not found")
+ return result
+
+
+class PipelineMainLive(PipelineMainBase):
+ audio_filename: Path | None = None
+ source_language: str = "en"
+ target_language: str = "en"
+
+ @broadcast_to_socket
+ async def on_transcript(self, data):
+ async with transcripts_controller.transaction():
+ transcript = await self.get_transcript()
+ return await transcripts_controller.append_event(
+ transcript=transcript,
+ event="TRANSCRIPT",
+ data=TranscriptText(text=data.text, translation=data.translation),
+ )
+
+ @broadcast_to_socket
+ async def on_topic(self, data):
+ topic = TranscriptTopic(
+ title=data.title,
+ summary=data.summary,
+ timestamp=data.timestamp,
+ text=data.transcript.text,
+ words=data.transcript.words,
+ )
+ async with transcripts_controller.transaction():
+ transcript = await self.get_transcript()
+ return await transcripts_controller.append_event(
+ transcript=transcript,
+ event="TOPIC",
+ data=topic,
+ )
+
+ async def create(self) -> Pipeline:
+ # create a context for the whole rtc transaction
+ # add a customised logger to the context
+ self.prepare()
+ transcript = await self.get_transcript()
+
+ processors = [
+ AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
+ AudioChunkerProcessor(),
+ AudioMergeProcessor(),
+ AudioTranscriptAutoProcessor.as_threaded(),
+ TranscriptLinerProcessor(),
+ TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
+ TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
+ ]
+ pipeline = Pipeline(*processors)
+ pipeline.options = self
+ pipeline.set_pref("audio:source_language", transcript.source_language)
+ pipeline.set_pref("audio:target_language", transcript.target_language)
+
+ # when the pipeline ends, connect to the post pipeline
+ async def on_ended():
+ task_pipeline_main_post.delay(transcript_id=self.transcript_id)
+
+        pipeline.on_ended = on_ended
+
+ return pipeline
+
+
+class PipelineMainPost(PipelineMainBase):
+ """
+ Implement the rest of the main pipeline, triggered after PipelineMainLive ended.
+ """
+
+ @broadcast_to_socket
+ async def on_final_title(self, data):
+ final_title = TranscriptFinalTitle(title=data.title)
+ async with transcripts_controller.transaction():
+ transcript = await self.get_transcript()
+ if not transcript.title:
+                await transcripts_controller.update(
+                    transcript,
+ {
+ "title": final_title.title,
+ },
+ )
+ return await transcripts_controller.append_event(
+ transcript=transcript,
+ event="FINAL_TITLE",
+ data=final_title,
+ )
+
+ @broadcast_to_socket
+ async def on_final_long_summary(self, data):
+ final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
+ async with transcripts_controller.transaction():
+ transcript = await self.get_transcript()
+ await transcripts_controller.update(
+ transcript,
+ {
+ "long_summary": final_long_summary.long_summary,
+ },
+ )
+ return await transcripts_controller.append_event(
+ transcript=transcript,
+ event="FINAL_LONG_SUMMARY",
+ data=final_long_summary,
+ )
+
+ @broadcast_to_socket
+ async def on_final_short_summary(self, data):
+ final_short_summary = TranscriptFinalShortSummary(
+ short_summary=data.short_summary
+ )
+ async with transcripts_controller.transaction():
+ transcript = await self.get_transcript()
+ await transcripts_controller.update(
+ transcript,
+ {
+ "short_summary": final_short_summary.short_summary,
+ },
+ )
+ return await transcripts_controller.append_event(
+ transcript=transcript,
+ event="FINAL_SHORT_SUMMARY",
+ data=final_short_summary,
+ )
+
+ async def create(self) -> Pipeline:
+ # create a context for the whole rtc transaction
+ # add a customised logger to the context
+ self.prepare()
+ processors = [
+ # add diarization
+ BroadcastProcessor(
+ processors=[
+ TranscriptFinalTitleProcessor.as_threaded(
+ callback=self.on_final_title
+ ),
+ TranscriptFinalLongSummaryProcessor.as_threaded(
+ callback=self.on_final_long_summary
+ ),
+ TranscriptFinalShortSummaryProcessor.as_threaded(
+ callback=self.on_final_short_summary
+ ),
+ ]
+ ),
+ ]
+ pipeline = Pipeline(*processors)
+ pipeline.options = self
+
+ return pipeline
+
+
+@celery.task
+def task_pipeline_main_post(transcript_id: str):
+ pass
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
new file mode 100644
index 00000000..ce84fec4
--- /dev/null
+++ b/server/reflector/pipelines/runner.py
@@ -0,0 +1,117 @@
+"""
+Pipeline Runner
+===============
+
+Pipeline runner designed to be executed in a asyncio task.
+
+It is meant to be subclassed, and implement a create() method
+that expose/return a Pipeline instance.
+
+During its lifecycle, it will emit the following status:
+- started: the pipeline has been started
+- push: the pipeline received at least one data
+- flush: the pipeline is flushing
+- ended: the pipeline has ended
+- error: the pipeline has ended with an error
+"""
+
+import asyncio
+from typing import Callable
+
+from pydantic import BaseModel, ConfigDict
+from reflector.logger import logger
+from reflector.processors import Pipeline
+
+
+class PipelineRunner(BaseModel):
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ status: str = "idle"
+ on_status: Callable | None = None
+ on_ended: Callable | None = None
+ pipeline: Pipeline | None = None
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self._q_cmd = asyncio.Queue()
+ self._ev_done = asyncio.Event()
+ self._is_first_push = True
+
+ def create(self) -> Pipeline:
+ """
+ Create the pipeline if not specified earlier.
+ Should be implemented in a subclass
+ """
+ raise NotImplementedError()
+
+ def start(self):
+ """
+ Start the pipeline as a coroutine task
+ """
+ asyncio.get_event_loop().create_task(self.run())
+
+ async def push(self, data):
+ """
+ Push data to the pipeline
+ """
+ await self._add_cmd("PUSH", data)
+
+ async def flush(self):
+ """
+ Flush the pipeline
+ """
+ await self._add_cmd("FLUSH", None)
+
+ async def _add_cmd(self, cmd: str, data):
+ """
+ Enqueue a command to be executed in the runner.
+ Currently supported commands: PUSH, FLUSH
+ """
+ await self._q_cmd.put([cmd, data])
+
+ async def _set_status(self, status):
+        logger.debug("PipelineRunner status changed", status=status)
+ self.status = status
+ if self.on_status:
+ try:
+ await self.on_status(status)
+ except Exception as e:
+ logger.error("PipelineRunner status_callback error", error=e)
+
+ async def run(self):
+ try:
+ # create the pipeline if not yet done
+ await self._set_status("init")
+ self._is_first_push = True
+ if not self.pipeline:
+ self.pipeline = await self.create()
+
+ # start the loop
+ await self._set_status("started")
+ while not self._ev_done.is_set():
+ cmd, data = await self._q_cmd.get()
+                func = getattr(self, f"cmd_{cmd.lower()}", None)
+ if func:
+ await func(data)
+ else:
+ raise Exception(f"Unknown command {cmd}")
+ except Exception as e:
+ logger.error("PipelineRunner error", error=e)
+ await self._set_status("error")
+ self._ev_done.set()
+ if self.on_ended:
+ await self.on_ended()
+
+ async def cmd_push(self, data):
+ if self._is_first_push:
+ await self._set_status("push")
+ self._is_first_push = False
+ await self.pipeline.push(data)
+
+ async def cmd_flush(self, data):
+ await self._set_status("flush")
+ await self.pipeline.flush()
+ await self._set_status("ended")
+ self._ev_done.set()
+ if self.on_ended:
+ await self.on_ended()
diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py
index 96a3941d..960c6a35 100644
--- a/server/reflector/processors/__init__.py
+++ b/server/reflector/processors/__init__.py
@@ -3,7 +3,13 @@ from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401
-from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor # noqa: F401
+from .base import ( # noqa: F401
+ BroadcastProcessor,
+ Pipeline,
+ PipelineEvent,
+ Processor,
+ ThreadedProcessor,
+)
from .transcript_final_long_summary import ( # noqa: F401
TranscriptFinalLongSummaryProcessor,
)
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index 48d804cc..5d10c181 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -1,8 +1,6 @@
import asyncio
from enum import StrEnum
-from json import dumps, loads
-from pathlib import Path
-from typing import Callable
+from json import loads
import av
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
@@ -11,25 +9,7 @@ from prometheus_client import Gauge
from pydantic import BaseModel
from reflector.events import subscribers_shutdown
from reflector.logger import logger
-from reflector.processors import (
- AudioChunkerProcessor,
- AudioFileWriterProcessor,
- AudioMergeProcessor,
- AudioTranscriptAutoProcessor,
- FinalLongSummary,
- FinalShortSummary,
- Pipeline,
- TitleSummary,
- Transcript,
- TranscriptFinalLongSummaryProcessor,
- TranscriptFinalShortSummaryProcessor,
- TranscriptFinalTitleProcessor,
- TranscriptLinerProcessor,
- TranscriptTopicDetectorProcessor,
- TranscriptTranslatorProcessor,
-)
-from reflector.processors.base import BroadcastProcessor
-from reflector.processors.types import FinalTitle
+from reflector.pipelines.runner import PipelineRunner
sessions = []
router = APIRouter()
@@ -85,121 +65,10 @@ class PipelineEvent(StrEnum):
FINAL_TITLE = "FINAL_TITLE"
-class PipelineOptions(BaseModel):
- audio_filename: Path | None = None
- source_language: str = "en"
- target_language: str = "en"
-
- on_transcript: Callable | None = None
- on_topic: Callable | None = None
- on_final_title: Callable | None = None
- on_final_short_summary: Callable | None = None
- on_final_long_summary: Callable | None = None
-
-
-class PipelineRunner(object):
- """
- Pipeline runner designed to be executed in a asyncio task
- """
-
- def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None):
- self.pipeline = pipeline
- self.q_cmd = asyncio.Queue()
- self.ev_done = asyncio.Event()
- self.status = "idle"
- self.status_callback = status_callback
-
- async def update_status(self, status):
- print("update_status", status)
- self.status = status
- if self.status_callback:
- try:
- await self.status_callback(status)
- except Exception as e:
- logger.error("PipelineRunner status_callback error", error=e)
-
- async def add_cmd(self, cmd: str, data):
- await self.q_cmd.put([cmd, data])
-
- async def push(self, data):
- await self.add_cmd("PUSH", data)
-
- async def flush(self):
- await self.add_cmd("FLUSH", None)
-
- async def run(self):
- try:
- await self.update_status("running")
- while not self.ev_done.is_set():
- cmd, data = await self.q_cmd.get()
- func = getattr(self, f"cmd_{cmd.lower()}")
- if func:
- await func(data)
- else:
- raise Exception(f"Unknown command {cmd}")
- except Exception as e:
- await self.update_status("error")
- logger.error("PipelineRunner error", error=e)
-
- async def cmd_push(self, data):
- if self.status == "idle":
- await self.update_status("recording")
- await self.pipeline.push(data)
-
- async def cmd_flush(self, data):
- await self.update_status("processing")
- await self.pipeline.flush()
- await self.update_status("ended")
- self.ev_done.set()
-
- def start(self):
- print("start task")
- asyncio.get_event_loop().create_task(self.run())
-
-
-async def pipeline_live_create(options: PipelineOptions):
- # create a context for the whole rtc transaction
- # add a customised logger to the context
- processors = []
- if options.audio_filename is not None:
- processors += [AudioFileWriterProcessor(path=options.audio_filename)]
- processors += [
- AudioChunkerProcessor(),
- AudioMergeProcessor(),
- AudioTranscriptAutoProcessor.as_threaded(),
- TranscriptLinerProcessor(),
- TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript),
- TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic),
- BroadcastProcessor(
- processors=[
- TranscriptFinalTitleProcessor.as_threaded(
- callback=options.on_final_title
- ),
- TranscriptFinalLongSummaryProcessor.as_threaded(
- callback=options.on_final_long_summary
- ),
- TranscriptFinalShortSummaryProcessor.as_threaded(
- callback=options.on_final_short_summary
- ),
- ]
- ),
- ]
- pipeline = Pipeline(*processors)
- pipeline.options = options
- pipeline.set_pref("audio:source_language", options.source_language)
- pipeline.set_pref("audio:target_language", options.target_language)
-
- return pipeline
-
-
async def rtc_offer_base(
params: RtcOffer,
request: Request,
- event_callback=None,
- event_callback_args=None,
- audio_filename: Path | None = None,
- source_language: str = "en",
- target_language: str = "en",
+ pipeline_runner: PipelineRunner,
):
# build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -209,132 +78,9 @@ async def rtc_offer_base(
clientid = f"{peername[0]}:{peername[1]}"
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
- async def update_status(status: str):
- changed = ctx.status != status
- if changed:
- ctx.status = status
- if event_callback:
- await event_callback(
- event=PipelineEvent.STATUS,
- args=event_callback_args,
- data=StrValue(value=status),
- )
-
- # build pipeline callback
- async def on_transcript(transcript: Transcript):
- ctx.logger.info("Transcript", transcript=transcript)
-
- # send to RTC
- if ctx.data_channel.readyState == "open":
- result = {
- "cmd": "SHOW_TRANSCRIPTION",
- "text": transcript.text,
- }
- ctx.data_channel.send(dumps(result))
-
- # send to callback (eg. websocket)
- if event_callback:
- await event_callback(
- event=PipelineEvent.TRANSCRIPT,
- args=event_callback_args,
- data=transcript,
- )
-
- async def on_topic(topic: TitleSummary):
- # FIXME: make it incremental with the frontend, not send everything
- ctx.logger.info("Topic", topic=topic)
- ctx.topics.append(
- {
- "title": topic.title,
- "timestamp": topic.timestamp,
- "transcript": topic.transcript.text,
- "desc": topic.summary,
- }
- )
-
- # send to RTC
- if ctx.data_channel.readyState == "open":
- result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics}
- ctx.data_channel.send(dumps(result))
-
- # send to callback (eg. websocket)
- if event_callback:
- await event_callback(
- event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
- )
-
- async def on_final_short_summary(summary: FinalShortSummary):
- ctx.logger.info("FinalShortSummary", final_short_summary=summary)
-
- # send to RTC
- if ctx.data_channel.readyState == "open":
- result = {
- "cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
- "summary": summary.short_summary,
- "duration": summary.duration,
- }
- ctx.data_channel.send(dumps(result))
-
- # send to callback (eg. websocket)
- if event_callback:
- await event_callback(
- event=PipelineEvent.FINAL_SHORT_SUMMARY,
- args=event_callback_args,
- data=summary,
- )
-
- async def on_final_long_summary(summary: FinalLongSummary):
- ctx.logger.info("FinalLongSummary", final_summary=summary)
-
- # send to RTC
- if ctx.data_channel.readyState == "open":
- result = {
- "cmd": "DISPLAY_FINAL_LONG_SUMMARY",
- "summary": summary.long_summary,
- "duration": summary.duration,
- }
- ctx.data_channel.send(dumps(result))
-
- # send to callback (eg. websocket)
- if event_callback:
- await event_callback(
- event=PipelineEvent.FINAL_LONG_SUMMARY,
- args=event_callback_args,
- data=summary,
- )
-
- async def on_final_title(title: FinalTitle):
- ctx.logger.info("FinalTitle", final_title=title)
-
- # send to RTC
- if ctx.data_channel.readyState == "open":
- result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
- ctx.data_channel.send(dumps(result))
-
- # send to callback (eg. websocket)
- if event_callback:
- await event_callback(
- event=PipelineEvent.FINAL_TITLE,
- args=event_callback_args,
- data=title,
- )
-
# handle RTC peer connection
pc = RTCPeerConnection()
-
- # create pipeline
- options = PipelineOptions(
- audio_filename=audio_filename,
- source_language=source_language,
- target_language=target_language,
- on_transcript=on_transcript,
- on_topic=on_topic,
- on_final_short_summary=on_final_short_summary,
- on_final_long_summary=on_final_long_summary,
- on_final_title=on_final_title,
- )
- pipeline = await pipeline_live_create(options)
- ctx.pipeline_runner = PipelineRunner(pipeline, update_status)
+ ctx.pipeline_runner = pipeline_runner
ctx.pipeline_runner.start()
async def flush_pipeline_and_quit(close=True):
@@ -400,8 +146,3 @@ async def rtc_clean_sessions(_):
logger.debug(f"Closing session {pc}")
await pc.close()
sessions.clear()
-
-
-@router.post("/offer")
-async def rtc_offer(params: RtcOffer, request: Request):
- return await rtc_offer_base(params, request)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 9f02eb6d..e949d645 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -1,8 +1,5 @@
-import json
from datetime import datetime
-from pathlib import Path
from typing import Annotated, Optional
-from uuid import uuid4
import reflector.auth as auth
from fastapi import (
@@ -15,12 +12,13 @@ from fastapi import (
)
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
-from reflector.db import database, transcripts
-from reflector.logger import logger
+from reflector.db.transcripts import (
+ AudioWaveform,
+ TranscriptTopic,
+ transcripts_controller,
+)
from reflector.processors.types import Transcript as ProcessorTranscript
-from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
-from reflector.utils.audio_waveform import get_audio_waveform
from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
@@ -30,216 +28,6 @@ from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
router = APIRouter()
ws_manager = get_ws_manager()
-# ==============================================================
-# Models to move to a database, but required for the API to work
-# ==============================================================
-
-
-def generate_uuid4():
- return str(uuid4())
-
-
-def generate_transcript_name():
- now = datetime.utcnow()
- return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
-
-
-class AudioWaveform(BaseModel):
- data: list[float]
-
-
-class TranscriptText(BaseModel):
- text: str
- translation: str | None
-
-
-class TranscriptSegmentTopic(BaseModel):
- speaker: int
- text: str
- timestamp: float
-
-
-class TranscriptTopic(BaseModel):
- id: str = Field(default_factory=generate_uuid4)
- title: str
- summary: str
- timestamp: float
- text: str | None = None
- words: list[ProcessorWord] = []
-
-
-class TranscriptFinalShortSummary(BaseModel):
- short_summary: str
-
-
-class TranscriptFinalLongSummary(BaseModel):
- long_summary: str
-
-
-class TranscriptFinalTitle(BaseModel):
- title: str
-
-
-class TranscriptEvent(BaseModel):
- event: str
- data: dict
-
-
-class Transcript(BaseModel):
- id: str = Field(default_factory=generate_uuid4)
- user_id: str | None = None
- name: str = Field(default_factory=generate_transcript_name)
- status: str = "idle"
- locked: bool = False
- duration: float = 0
- created_at: datetime = Field(default_factory=datetime.utcnow)
- title: str | None = None
- short_summary: str | None = None
- long_summary: str | None = None
- topics: list[TranscriptTopic] = []
- events: list[TranscriptEvent] = []
- source_language: str = "en"
- target_language: str = "en"
-
- def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
- ev = TranscriptEvent(event=event, data=data.model_dump())
- self.events.append(ev)
- return ev
-
- def upsert_topic(self, topic: TranscriptTopic):
- existing_topic = next((t for t in self.topics if t.id == topic.id), None)
- if existing_topic:
- existing_topic.update_from(topic)
- else:
- self.topics.append(topic)
-
- def events_dump(self, mode="json"):
- return [event.model_dump(mode=mode) for event in self.events]
-
- def topics_dump(self, mode="json"):
- return [topic.model_dump(mode=mode) for topic in self.topics]
-
- def convert_audio_to_waveform(self, segments_count=256):
- fn = self.audio_waveform_filename
- if fn.exists():
- return
- waveform = get_audio_waveform(
- path=self.audio_mp3_filename, segments_count=segments_count
- )
- try:
- with open(fn, "w") as fd:
- json.dump(waveform, fd)
- except Exception:
- # remove file if anything happen during the write
- fn.unlink(missing_ok=True)
- raise
- return waveform
-
- def unlink(self):
- self.data_path.unlink(missing_ok=True)
-
- @property
- def data_path(self):
- return Path(settings.DATA_DIR) / self.id
-
- @property
- def audio_mp3_filename(self):
- return self.data_path / "audio.mp3"
-
- @property
- def audio_waveform_filename(self):
- return self.data_path / "audio.json"
-
- @property
- def audio_waveform(self):
- try:
- with open(self.audio_waveform_filename) as fd:
- data = json.load(fd)
- except json.JSONDecodeError:
- # unlink file if it's corrupted
- self.audio_waveform_filename.unlink(missing_ok=True)
- return None
-
- return AudioWaveform(data=data)
-
-
-class TranscriptController:
- async def get_all(
- self,
- user_id: str | None = None,
- order_by: str | None = None,
- filter_empty: bool | None = False,
- filter_recording: bool | None = False,
- ) -> list[Transcript]:
- query = transcripts.select().where(transcripts.c.user_id == user_id)
-
- if order_by is not None:
- field = getattr(transcripts.c, order_by[1:])
- if order_by.startswith("-"):
- field = field.desc()
- query = query.order_by(field)
-
- if filter_empty:
- query = query.filter(transcripts.c.status != "idle")
-
- if filter_recording:
- query = query.filter(transcripts.c.status != "recording")
-
- results = await database.fetch_all(query)
- return results
-
- async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
- query = transcripts.select().where(transcripts.c.id == transcript_id)
- if "user_id" in kwargs:
- query = query.where(transcripts.c.user_id == kwargs["user_id"])
- result = await database.fetch_one(query)
- if not result:
- return None
- return Transcript(**result)
-
- async def add(
- self,
- name: str,
- source_language: str = "en",
- target_language: str = "en",
- user_id: str | None = None,
- ):
- transcript = Transcript(
- name=name,
- source_language=source_language,
- target_language=target_language,
- user_id=user_id,
- )
- query = transcripts.insert().values(**transcript.model_dump())
- await database.execute(query)
- return transcript
-
- async def update(self, transcript: Transcript, values: dict):
- query = (
- transcripts.update()
- .where(transcripts.c.id == transcript.id)
- .values(**values)
- )
- await database.execute(query)
- for key, value in values.items():
- setattr(transcript, key, value)
-
- async def remove_by_id(
- self, transcript_id: str, user_id: str | None = None
- ) -> None:
- transcript = await self.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- return
- if user_id is not None and transcript.user_id != user_id:
- return
- transcript.unlink()
- query = transcripts.delete().where(transcripts.c.id == transcript_id)
- await database.execute(query)
-
-
-transcripts_controller = TranscriptController()
-
-
# ==============================================================
# Transcripts list
# ==============================================================
@@ -537,114 +325,6 @@ async def transcript_events_websocket(
# ==============================================================
-async def handle_rtc_event(event: PipelineEvent, args, data):
- try:
- return await handle_rtc_event_once(event, args, data)
- except Exception:
- logger.exception("Error handling RTC event")
-
-
-async def handle_rtc_event_once(event: PipelineEvent, args, data):
- # OFC the current implementation is not good,
- # but it's just a POC before persistence. It won't query the
- # transcript from the database for each event.
- # print(f"Event: {event}", args, data)
- transcript_id = args
- transcript = await transcripts_controller.get_by_id(transcript_id)
- if not transcript:
- return
-
- # event send to websocket clients may not be the same as the event
- # received from the pipeline. For example, the pipeline will send
- # a TRANSCRIPT event with all words, but this is not what we want
- # to send to the websocket client.
-
- # FIXME don't do copy
- if event == PipelineEvent.TRANSCRIPT:
- resp = transcript.add_event(
- event=event,
- data=TranscriptText(text=data.text, translation=data.translation),
- )
- await transcripts_controller.update(
- transcript,
- {
- "events": transcript.events_dump(),
- },
- )
-
- elif event == PipelineEvent.TOPIC:
- topic = TranscriptTopic(
- title=data.title,
- summary=data.summary,
- timestamp=data.timestamp,
- text=data.transcript.text,
- words=data.transcript.words,
- )
- resp = transcript.add_event(event=event, data=topic)
- transcript.upsert_topic(topic)
-
- await transcripts_controller.update(
- transcript,
- {
- "events": transcript.events_dump(),
- "topics": transcript.topics_dump(),
- },
- )
-
- elif event == PipelineEvent.FINAL_TITLE:
- final_title = TranscriptFinalTitle(title=data.title)
- resp = transcript.add_event(event=event, data=final_title)
- await transcripts_controller.update(
- transcript,
- {
- "events": transcript.events_dump(),
- "title": final_title.title,
- },
- )
-
- elif event == PipelineEvent.FINAL_LONG_SUMMARY:
- final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
- resp = transcript.add_event(event=event, data=final_long_summary)
- await transcripts_controller.update(
- transcript,
- {
- "events": transcript.events_dump(),
- "long_summary": final_long_summary.long_summary,
- },
- )
-
- elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
- final_short_summary = TranscriptFinalShortSummary(
- short_summary=data.short_summary
- )
- resp = transcript.add_event(event=event, data=final_short_summary)
- await transcripts_controller.update(
- transcript,
- {
- "events": transcript.events_dump(),
- "short_summary": final_short_summary.short_summary,
- },
- )
-
- elif event == PipelineEvent.STATUS:
- resp = transcript.add_event(event=event, data=data)
- await transcripts_controller.update(
- transcript,
- {
- "events": transcript.events_dump(),
- "status": data.value,
- },
- )
-
- else:
- logger.warning(f"Unknown event: {event}")
- return
-
- # transmit to websocket clients
- room_id = f"ts:{transcript_id}"
- await ws_manager.send_json(room_id, resp.model_dump(mode="json"))
-
-
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str,
@@ -660,13 +340,14 @@ async def transcript_record_webrtc(
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
+ # create a pipeline runner
+ from reflector.pipelines.main_live_pipeline import PipelineMainLive
+
+ pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
+
# FIXME do not allow multiple recording at the same time
return await rtc_offer_base(
params,
request,
- event_callback=handle_rtc_event,
- event_callback_args=transcript_id,
- audio_filename=transcript.audio_mp3_filename,
- source_language=transcript.source_language,
- target_language=transcript.target_language,
+ pipeline_runner=pipeline_runner,
)
diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py
index 1dfe9e3d..7650807b 100644
--- a/server/reflector/ws_manager.py
+++ b/server/reflector/ws_manager.py
@@ -14,6 +14,7 @@ import json
import redis.asyncio as redis
from fastapi import WebSocket
+from reflector.settings import settings
ws_manager = None
@@ -114,7 +115,8 @@ def get_ws_manager() -> WebsocketManager:
RedisConnectionError: If there is an error connecting to the Redis server.
"""
global ws_manager
- from reflector.settings import settings
+ if ws_manager:
+ return ws_manager
pubsub_client = RedisPubSubManager(
host=settings.REDIS_HOST,
From 07c4d080c22fb896982be1da9928fc16ffd0b6c8 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 27 Oct 2023 15:59:27 +0200
Subject: [PATCH 15/41] server: refactor with diarization, logic works
---
server/poetry.lock | 16 +-
server/pyproject.toml | 1 +
server/reflector/db/transcripts.py | 14 +-
.../reflector/pipelines/main_live_pipeline.py | 256 ++++++++++++------
server/reflector/pipelines/runner.py | 47 +++-
server/reflector/processors/__init__.py | 1 +
.../reflector/processors/audio_diarization.py | 65 +++++
server/reflector/processors/types.py | 5 +
server/reflector/tasks/boot.py | 2 -
server/reflector/tasks/worker.py | 6 -
server/reflector/views/rtc_offer.py | 18 +-
server/reflector/views/transcripts.py | 25 +-
server/reflector/worker/app.py | 11 +
.../{tasks => worker}/post_transcript.py | 0
server/reflector/ws_manager.py | 10 +-
server/tests/conftest.py | 25 +-
server/tests/test_transcripts_rtc_ws.py | 54 +++-
17 files changed, 387 insertions(+), 169 deletions(-)
create mode 100644 server/reflector/processors/audio_diarization.py
delete mode 100644 server/reflector/tasks/boot.py
delete mode 100644 server/reflector/tasks/worker.py
create mode 100644 server/reflector/worker/app.py
rename server/reflector/{tasks => worker}/post_transcript.py (100%)
diff --git a/server/poetry.lock b/server/poetry.lock
index 35d98382..8783625b 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -2676,6 +2676,20 @@ pytest = ">=7.0.0"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
+[[package]]
+name = "pytest-celery"
+version = "0.0.0"
+description = "pytest-celery a shim pytest plugin to enable celery.contrib.pytest"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pytest-celery-0.0.0.tar.gz", hash = "sha256:cfd060fc32676afa1e4f51b2938f903f7f75d952186b8c6cf631628c4088f406"},
+ {file = "pytest_celery-0.0.0-py2.py3-none-any.whl", hash = "sha256:63dec132df3a839226ecb003ffdbb0c2cb88dd328550957e979c942766578060"},
+]
+
+[package.dependencies]
+celery = ">=4.4.0"
+
[[package]]
name = "pytest-cov"
version = "4.1.0"
@@ -4064,4 +4078,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63"
+content-hash = "07e42e7512fd5d51b656207a05092c53905c15e6a5ce548e015cdc05bd1baa7d"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index ed231a4f..c8614006 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -49,6 +49,7 @@ pytest-asyncio = "^0.21.1"
pytest = "^7.4.0"
httpx-ws = "^0.4.1"
pytest-httpx = "^0.23.1"
+pytest-celery = "^0.0.0"
[tool.poetry.group.aws.dependencies]
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 2b9fc6b2..61a2c380 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -62,6 +62,7 @@ class TranscriptTopic(BaseModel):
title: str
summary: str
timestamp: float
+ duration: float | None = 0
text: str | None = None
words: list[ProcessorWord] = []
@@ -264,7 +265,7 @@ class TranscriptController:
"""
A context manager for database transaction
"""
- async with database.transaction():
+ async with database.transaction(isolation="serializable"):
yield
async def append_event(
@@ -280,5 +281,16 @@ class TranscriptController:
await self.update(transcript, {"events": transcript.events_dump()})
return resp
+ async def upsert_topic(
+ self,
+ transcript: Transcript,
+ topic: TranscriptTopic,
+ ) -> TranscriptEvent:
+ """
+        Upsert a topic in a transcript and persist the updated topics list
+ """
+ transcript.upsert_topic(topic)
+ await self.update(transcript, {"topics": transcript.topics_dump()})
+
transcripts_controller = TranscriptController()
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 30f7ead3..4159c889 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -11,8 +11,12 @@ It is decoupled to:
It is directly linked to our data model.
"""
+import asyncio
+from contextlib import asynccontextmanager
from pathlib import Path
+from celery import shared_task
+from pydantic import BaseModel
from reflector.db.transcripts import (
Transcript,
TranscriptFinalLongSummary,
@@ -25,6 +29,7 @@ from reflector.db.transcripts import (
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
+ AudioDiarizationProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
@@ -37,11 +42,13 @@ from reflector.processors import (
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
-from reflector.tasks.worker import celery
+from reflector.processors.types import AudioDiarizationInput
+from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
+from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.ws_manager import WebsocketManager, get_ws_manager
-def broadcast_to_socket(func):
+def broadcast_to_sockets(func):
"""
Decorator to broadcast transcript event to websockets
concerning this transcript
@@ -59,6 +66,10 @@ def broadcast_to_socket(func):
return wrapper
+class StrValue(BaseModel):
+ value: str
+
+
class PipelineMainBase(PipelineRunner):
transcript_id: str
ws_room_id: str | None = None
@@ -66,6 +77,7 @@ class PipelineMainBase(PipelineRunner):
def prepare(self):
# prepare websocket
+ self._lock = asyncio.Lock()
self.ws_room_id = f"ts:{self.transcript_id}"
self.ws_manager = get_ws_manager()
@@ -78,15 +90,59 @@ class PipelineMainBase(PipelineRunner):
raise Exception("Transcript not found")
return result
+ @asynccontextmanager
+ async def transaction(self):
+ async with self._lock:
+ async with transcripts_controller.transaction():
+ yield
-class PipelineMainLive(PipelineMainBase):
- audio_filename: Path | None = None
- source_language: str = "en"
- target_language: str = "en"
+ @broadcast_to_sockets
+ async def on_status(self, status):
+ # if it's the first part, update the status of the transcript
+ # but do not set the ended status yet.
+ if isinstance(self, PipelineMainLive):
+ status_mapping = {
+ "started": "recording",
+ "push": "recording",
+ "flush": "processing",
+ "error": "error",
+ }
+ elif isinstance(self, PipelineMainDiarization):
+ status_mapping = {
+ "push": "processing",
+ "flush": "processing",
+ "error": "error",
+ "ended": "ended",
+ }
+ else:
+ raise Exception(f"Runner {self.__class__} is missing status mapping")
- @broadcast_to_socket
+ # mutate to model status
+ status = status_mapping.get(status)
+ if not status:
+ return
+
+ # when the status of the pipeline changes, update the transcript
+ async with self.transaction():
+ transcript = await self.get_transcript()
+ if status == transcript.status:
+ return
+ resp = await transcripts_controller.append_event(
+ transcript=transcript,
+ event="STATUS",
+ data=StrValue(value=status),
+ )
+ await transcripts_controller.update(
+ transcript,
+ {
+ "status": status,
+ },
+ )
+ return resp
+
+ @broadcast_to_sockets
async def on_transcript(self, data):
- async with transcripts_controller.transaction():
+ async with self.transaction():
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript,
@@ -94,7 +150,7 @@ class PipelineMainLive(PipelineMainBase):
data=TranscriptText(text=data.text, translation=data.translation),
)
- @broadcast_to_socket
+ @broadcast_to_sockets
async def on_topic(self, data):
topic = TranscriptTopic(
title=data.title,
@@ -103,14 +159,75 @@ class PipelineMainLive(PipelineMainBase):
text=data.transcript.text,
words=data.transcript.words,
)
- async with transcripts_controller.transaction():
+ async with self.transaction():
transcript = await self.get_transcript()
+ await transcripts_controller.upsert_topic(transcript, topic)
return await transcripts_controller.append_event(
transcript=transcript,
event="TOPIC",
data=topic,
)
+ @broadcast_to_sockets
+ async def on_title(self, data):
+ final_title = TranscriptFinalTitle(title=data.title)
+ async with self.transaction():
+ transcript = await self.get_transcript()
+ if not transcript.title:
+ transcripts_controller.update(
+ transcript,
+ {
+ "title": final_title.title,
+ },
+ )
+ return await transcripts_controller.append_event(
+ transcript=transcript,
+ event="FINAL_TITLE",
+ data=final_title,
+ )
+
+ @broadcast_to_sockets
+ async def on_long_summary(self, data):
+ final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
+ async with self.transaction():
+ transcript = await self.get_transcript()
+ await transcripts_controller.update(
+ transcript,
+ {
+ "long_summary": final_long_summary.long_summary,
+ },
+ )
+ return await transcripts_controller.append_event(
+ transcript=transcript,
+ event="FINAL_LONG_SUMMARY",
+ data=final_long_summary,
+ )
+
+ @broadcast_to_sockets
+ async def on_short_summary(self, data):
+ final_short_summary = TranscriptFinalShortSummary(
+ short_summary=data.short_summary
+ )
+ async with self.transaction():
+ transcript = await self.get_transcript()
+ await transcripts_controller.update(
+ transcript,
+ {
+ "short_summary": final_short_summary.short_summary,
+ },
+ )
+ return await transcripts_controller.append_event(
+ transcript=transcript,
+ event="FINAL_SHORT_SUMMARY",
+ data=final_short_summary,
+ )
+
+
+class PipelineMainLive(PipelineMainBase):
+ audio_filename: Path | None = None
+ source_language: str = "en"
+ target_language: str = "en"
+
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
@@ -125,96 +242,49 @@ class PipelineMainLive(PipelineMainBase):
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
+ BroadcastProcessor(
+ processors=[
+ TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
+ TranscriptFinalLongSummaryProcessor.as_threaded(
+ callback=self.on_long_summary
+ ),
+ TranscriptFinalShortSummaryProcessor.as_threaded(
+ callback=self.on_short_summary
+ ),
+ ]
+ ),
]
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
- # when the pipeline ends, connect to the post pipeline
- async def on_ended():
- task_pipeline_main_post.delay(transcript_id=self.transcript_id)
-
- pipeline.on_ended = self
-
return pipeline
+ async def on_ended(self):
+ # when the pipeline ends, connect to the post pipeline
+ task_pipeline_main_post.delay(transcript_id=self.transcript_id)
-class PipelineMainPost(PipelineMainBase):
+
+class PipelineMainDiarization(PipelineMainBase):
"""
- Implement the rest of the main pipeline, triggered after PipelineMainLive ended.
+    Diarization is a long-running process, so it is done in a separate pipeline.
+    When done, re-generate the final long and short summaries.
"""
- @broadcast_to_socket
- async def on_final_title(self, data):
- final_title = TranscriptFinalTitle(title=data.title)
- async with transcripts_controller.transaction():
- transcript = await self.get_transcript()
- if not transcript.title:
- transcripts_controller.update(
- self.transcript,
- {
- "title": final_title.title,
- },
- )
- return await transcripts_controller.append_event(
- transcript=transcript,
- event="FINAL_TITLE",
- data=final_title,
- )
-
- @broadcast_to_socket
- async def on_final_long_summary(self, data):
- final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
- async with transcripts_controller.transaction():
- transcript = await self.get_transcript()
- await transcripts_controller.update(
- transcript,
- {
- "long_summary": final_long_summary.long_summary,
- },
- )
- return await transcripts_controller.append_event(
- transcript=transcript,
- event="FINAL_LONG_SUMMARY",
- data=final_long_summary,
- )
-
- @broadcast_to_socket
- async def on_final_short_summary(self, data):
- final_short_summary = TranscriptFinalShortSummary(
- short_summary=data.short_summary
- )
- async with transcripts_controller.transaction():
- transcript = await self.get_transcript()
- await transcripts_controller.update(
- transcript,
- {
- "short_summary": final_short_summary.short_summary,
- },
- )
- return await transcripts_controller.append_event(
- transcript=transcript,
- event="FINAL_SHORT_SUMMARY",
- data=final_short_summary,
- )
-
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
processors = [
- # add diarization
+ AudioDiarizationProcessor(),
BroadcastProcessor(
processors=[
- TranscriptFinalTitleProcessor.as_threaded(
- callback=self.on_final_title
- ),
TranscriptFinalLongSummaryProcessor.as_threaded(
- callback=self.on_final_long_summary
+ callback=self.on_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
- callback=self.on_final_short_summary
+ callback=self.on_short_summary
),
]
),
@@ -222,9 +292,35 @@ class PipelineMainPost(PipelineMainBase):
pipeline = Pipeline(*processors)
pipeline.options = self
+ # now let's start the pipeline by pushing information to the
+ # first processor diarization processor
+ # XXX translation is lost when converting our data model to the processor model
+ transcript = await self.get_transcript()
+ topics = [
+ TitleSummaryProcessorType(
+ title=topic.title,
+ summary=topic.summary,
+ timestamp=topic.timestamp,
+ duration=topic.duration,
+ transcript=TranscriptProcessorType(words=topic.words),
+ )
+ for topic in transcript.topics
+ ]
+
+ audio_diarization_input = AudioDiarizationInput(
+ audio_filename=transcript.audio_mp3_filename,
+ topics=topics,
+ )
+
+        # as tempting as it is to use pipeline.push directly, prefer the
+        # runner's push/flush so that start just does one job.
+ self.push(audio_diarization_input)
+ self.flush()
+
return pipeline
-@celery.task
+@shared_task
def task_pipeline_main_post(transcript_id: str):
- pass
+ runner = PipelineMainDiarization(transcript_id=transcript_id)
+ runner.start_sync()
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index ce84fec4..0575cf96 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -16,7 +16,6 @@ During its lifecycle, it will emit the following status:
"""
import asyncio
-from typing import Callable
from pydantic import BaseModel, ConfigDict
from reflector.logger import logger
@@ -27,8 +26,6 @@ class PipelineRunner(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
status: str = "idle"
- on_status: Callable | None = None
- on_ended: Callable | None = None
pipeline: Pipeline | None = None
def __init__(self, **kwargs):
@@ -36,6 +33,10 @@ class PipelineRunner(BaseModel):
self._q_cmd = asyncio.Queue()
self._ev_done = asyncio.Event()
self._is_first_push = True
+ self._logger = logger.bind(
+ runner=id(self),
+ runner_cls=self.__class__.__name__,
+ )
def create(self) -> Pipeline:
"""
@@ -50,33 +51,51 @@ class PipelineRunner(BaseModel):
"""
asyncio.get_event_loop().create_task(self.run())
- async def push(self, data):
+ def start_sync(self):
+ """
+ Start the pipeline synchronously (for non-asyncio apps)
+ """
+ asyncio.run(self.run())
+
+ def push(self, data):
"""
Push data to the pipeline
"""
- await self._add_cmd("PUSH", data)
+ self._add_cmd("PUSH", data)
- async def flush(self):
+ def flush(self):
"""
Flush the pipeline
"""
- await self._add_cmd("FLUSH", None)
+ self._add_cmd("FLUSH", None)
- async def _add_cmd(self, cmd: str, data):
+ async def on_status(self, status):
+ """
+ Called when the status of the pipeline changes
+ """
+ pass
+
+ async def on_ended(self):
+ """
+ Called when the pipeline ends
+ """
+ pass
+
+ def _add_cmd(self, cmd: str, data):
"""
Enqueue a command to be executed in the runner.
Currently supported commands: PUSH, FLUSH
"""
- await self._q_cmd.put([cmd, data])
+ self._q_cmd.put_nowait([cmd, data])
async def _set_status(self, status):
- print("set_status", status)
+ self._logger.debug("Runner status updated", status=status)
self.status = status
if self.on_status:
try:
await self.on_status(status)
- except Exception as e:
- logger.error("PipelineRunner status_callback error", error=e)
+ except Exception:
+                self._logger.exception("Runner error while setting status")
async def run(self):
try:
@@ -95,8 +114,8 @@ class PipelineRunner(BaseModel):
await func(data)
else:
raise Exception(f"Unknown command {cmd}")
- except Exception as e:
- logger.error("PipelineRunner error", error=e)
+ except Exception:
+ self._logger.exception("Runner error")
await self._set_status("error")
self._ev_done.set()
if self.on_ended:
diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py
index 960c6a35..01a3a174 100644
--- a/server/reflector/processors/__init__.py
+++ b/server/reflector/processors/__init__.py
@@ -1,4 +1,5 @@
from .audio_chunker import AudioChunkerProcessor # noqa: F401
+from .audio_diarization import AudioDiarizationProcessor # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py
new file mode 100644
index 00000000..8db8e8e5
--- /dev/null
+++ b/server/reflector/processors/audio_diarization.py
@@ -0,0 +1,65 @@
+from reflector.processors.base import Processor
+from reflector.processors.types import AudioDiarizationInput, TitleSummary
+
+
+class AudioDiarizationProcessor(Processor):
+ INPUT_TYPE = AudioDiarizationInput
+ OUTPUT_TYPE = TitleSummary
+
+ async def _push(self, data: AudioDiarizationInput):
+ # Gather diarization data
+ diarization = [
+ {"start": 0.0, "stop": 4.9, "speaker": 2},
+ {"start": 5.6, "stop": 6.7, "speaker": 2},
+ {"start": 7.3, "stop": 8.9, "speaker": 2},
+ {"start": 7.3, "stop": 7.9, "speaker": 0},
+ {"start": 9.4, "stop": 11.2, "speaker": 2},
+ {"start": 9.7, "stop": 10.0, "speaker": 0},
+ {"start": 10.0, "stop": 10.1, "speaker": 0},
+ {"start": 11.7, "stop": 16.1, "speaker": 2},
+ {"start": 11.8, "stop": 12.1, "speaker": 1},
+ {"start": 16.4, "stop": 21.0, "speaker": 2},
+ {"start": 21.1, "stop": 22.6, "speaker": 2},
+ {"start": 24.7, "stop": 31.9, "speaker": 2},
+ {"start": 32.0, "stop": 32.8, "speaker": 1},
+ {"start": 33.4, "stop": 37.8, "speaker": 2},
+ {"start": 37.9, "stop": 40.3, "speaker": 0},
+ {"start": 39.2, "stop": 40.4, "speaker": 2},
+ {"start": 40.7, "stop": 41.4, "speaker": 0},
+ {"start": 41.6, "stop": 45.7, "speaker": 2},
+ {"start": 46.4, "stop": 53.1, "speaker": 2},
+ {"start": 53.6, "stop": 56.5, "speaker": 2},
+ {"start": 54.9, "stop": 75.4, "speaker": 1},
+ {"start": 57.3, "stop": 58.0, "speaker": 2},
+ {"start": 65.7, "stop": 66.0, "speaker": 2},
+ {"start": 75.8, "stop": 78.8, "speaker": 1},
+ {"start": 79.0, "stop": 82.6, "speaker": 1},
+ {"start": 83.2, "stop": 83.3, "speaker": 1},
+ {"start": 84.5, "stop": 94.3, "speaker": 1},
+ {"start": 95.1, "stop": 100.7, "speaker": 1},
+ {"start": 100.7, "stop": 102.0, "speaker": 0},
+ {"start": 100.7, "stop": 101.8, "speaker": 1},
+ {"start": 102.0, "stop": 103.0, "speaker": 1},
+ {"start": 103.0, "stop": 103.7, "speaker": 0},
+ {"start": 103.7, "stop": 103.8, "speaker": 1},
+ {"start": 103.8, "stop": 113.9, "speaker": 0},
+ {"start": 114.7, "stop": 117.0, "speaker": 0},
+ {"start": 117.0, "stop": 117.4, "speaker": 1},
+ ]
+
+ # now reapply speaker to topics (if any)
+ # topics is a list[BaseModel] with an attribute words
+ # words is a list[BaseModel] with text, start and speaker attribute
+
+ print("IN DIARIZATION PROCESSOR", data)
+
+ # mutate in place
+ for topic in data.topics:
+ for word in topic.transcript.words:
+ for d in diarization:
+ if d["start"] <= word.start <= d["stop"]:
+ word.speaker = d["speaker"]
+
+ # emit them
+ for topic in data.topics:
+ await self.emit(topic)
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index d2c32d17..3ec21491 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -382,3 +382,8 @@ class TranslationLanguages(BaseModel):
def is_supported(self, lang_id: str) -> bool:
return lang_id in self.supported_languages
+
+
+class AudioDiarizationInput(BaseModel):
+ audio_filename: Path
+ topics: list[TitleSummary]
diff --git a/server/reflector/tasks/boot.py b/server/reflector/tasks/boot.py
deleted file mode 100644
index 88cc2d6f..00000000
--- a/server/reflector/tasks/boot.py
+++ /dev/null
@@ -1,2 +0,0 @@
-import reflector.tasks.post_transcript # noqa
-import reflector.tasks.worker # noqa
diff --git a/server/reflector/tasks/worker.py b/server/reflector/tasks/worker.py
deleted file mode 100644
index 4379a1b7..00000000
--- a/server/reflector/tasks/worker.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from celery import Celery
-from reflector.settings import settings
-
-celery = Celery(__name__)
-celery.conf.broker_url = settings.CELERY_BROKER_URL
-celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index 5d10c181..386ada9c 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -1,5 +1,4 @@
import asyncio
-from enum import StrEnum
from json import loads
import av
@@ -41,7 +40,7 @@ class AudioStreamTrack(MediaStreamTrack):
ctx = self.ctx
frame = await self.track.recv()
try:
- await ctx.pipeline_runner.push(frame)
+ ctx.pipeline_runner.push(frame)
except Exception as e:
ctx.logger.error("Pipeline error", error=e)
return frame
@@ -52,19 +51,6 @@ class RtcOffer(BaseModel):
type: str
-class StrValue(BaseModel):
- value: str
-
-
-class PipelineEvent(StrEnum):
- TRANSCRIPT = "TRANSCRIPT"
- TOPIC = "TOPIC"
- FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
- STATUS = "STATUS"
- FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
- FINAL_TITLE = "FINAL_TITLE"
-
-
async def rtc_offer_base(
params: RtcOffer,
request: Request,
@@ -90,7 +76,7 @@ async def rtc_offer_base(
# - when we receive the close event, we do nothing.
# 2. or the client close the connection
# and there is nothing to do because it is already closed
- await ctx.pipeline_runner.flush()
+ ctx.pipeline_runner.flush()
if close:
ctx.logger.debug("Closing peer connection")
await pc.close()
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index e949d645..31cbe28e 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -23,10 +23,9 @@ from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
-from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
+from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter()
-ws_manager = get_ws_manager()
# ==============================================================
# Transcripts list
@@ -166,32 +165,17 @@ async def transcript_update(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
- values = {"events": []}
+ values = {}
if info.name is not None:
values["name"] = info.name
if info.locked is not None:
values["locked"] = info.locked
if info.long_summary is not None:
values["long_summary"] = info.long_summary
- for transcript_event in transcript.events:
- if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
- transcript_event["long_summary"] = info.long_summary
- break
- values["events"].extend(transcript.events)
if info.short_summary is not None:
values["short_summary"] = info.short_summary
- for transcript_event in transcript.events:
- if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
- transcript_event["short_summary"] = info.short_summary
- break
- values["events"].extend(transcript.events)
if info.title is not None:
values["title"] = info.title
- for transcript_event in transcript.events:
- if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
- transcript_event["title"] = info.title
- break
- values["events"].extend(transcript.events)
await transcripts_controller.update(transcript, values)
return transcript
@@ -295,6 +279,7 @@ async def transcript_events_websocket(
# connect to websocket manager
# use ts:transcript_id as room id
room_id = f"ts:{transcript_id}"
+ ws_manager = get_ws_manager()
await ws_manager.add_user_to_room(room_id, websocket)
try:
@@ -303,9 +288,7 @@ async def transcript_events_websocket(
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
# not necessary to be sent to the client; but keep the rest
name = event.event
- if name == PipelineEvent.TRANSCRIPT:
- continue
- if name == PipelineEvent.STATUS:
+ if name in ("TRANSCRIPT", "STATUS"):
continue
await websocket.send_json(event.model_dump(mode="json"))
diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py
new file mode 100644
index 00000000..3714a64d
--- /dev/null
+++ b/server/reflector/worker/app.py
@@ -0,0 +1,11 @@
+from celery import Celery
+from reflector.settings import settings
+
+app = Celery(__name__)
+app.conf.broker_url = settings.CELERY_BROKER_URL
+app.conf.result_backend = settings.CELERY_RESULT_BACKEND
+app.autodiscover_tasks(
+ [
+ "reflector.pipelines.main_live_pipeline",
+ ]
+)
diff --git a/server/reflector/tasks/post_transcript.py b/server/reflector/worker/post_transcript.py
similarity index 100%
rename from server/reflector/tasks/post_transcript.py
rename to server/reflector/worker/post_transcript.py
diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py
index 7650807b..a84e3361 100644
--- a/server/reflector/ws_manager.py
+++ b/server/reflector/ws_manager.py
@@ -11,13 +11,12 @@ broadcast messages to all connected websockets.
import asyncio
import json
+import threading
import redis.asyncio as redis
from fastapi import WebSocket
from reflector.settings import settings
-ws_manager = None
-
class RedisPubSubManager:
def __init__(self, host="localhost", port=6379):
@@ -114,13 +113,14 @@ def get_ws_manager() -> WebsocketManager:
ImportError: If the 'reflector.settings' module cannot be imported.
RedisConnectionError: If there is an error connecting to the Redis server.
"""
- global ws_manager
- if ws_manager:
- return ws_manager
+ local = threading.local()
+ if hasattr(local, "ws_manager"):
+ return local.ws_manager
pubsub_client = RedisPubSubManager(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
)
ws_manager = WebsocketManager(pubsub_client=pubsub_client)
+ local.ws_manager = ws_manager
return ws_manager
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index 76b56abf..d5f5f0b9 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -45,17 +45,16 @@ async def dummy_transcript():
from reflector.processors.types import AudioFile, Transcript, Word
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
- async def _transcript(self, data: AudioFile):
- source_language = self.get_pref("audio:source_language", "en")
- print("transcripting", source_language)
- print("pipeline", self.pipeline)
- print("prefs", self.pipeline.prefs)
+ _time_idx = 0
+ async def _transcript(self, data: AudioFile):
+ i = self._time_idx
+ self._time_idx += 2
return Transcript(
text="Hello world.",
words=[
- Word(start=0.0, end=1.0, text="Hello"),
- Word(start=1.0, end=2.0, text=" world."),
+ Word(start=i, end=i + 1, text="Hello", speaker=0),
+ Word(start=i + 1, end=i + 2, text=" world.", speaker=0),
],
)
@@ -98,7 +97,17 @@ def ensure_casing():
@pytest.fixture
def sentence_tokenize():
with patch(
- "reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize"
+ "reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize"
) as mock_sent_tokenize:
mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
yield
+
+
+@pytest.fixture(scope="session")
+def celery_enable_logging():
+ return True
+
+
+@pytest.fixture(scope="session")
+def celery_config():
+ return {"broker_url": "memory://", "result_backend": "rpc"}
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 50e74231..e2bfee32 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -32,7 +32,7 @@ class ThreadedUvicorn:
@pytest.fixture
-async def appserver(tmpdir):
+async def appserver(tmpdir, celery_session_app, celery_session_worker):
from reflector.settings import settings
from reflector.app import app
@@ -52,6 +52,13 @@ async def appserver(tmpdir):
settings.DATA_DIR = DATA_DIR
+@pytest.fixture(scope="session")
+def celery_includes():
+ return ["reflector.pipelines.main_live_pipeline"]
+
+
+@pytest.mark.usefixtures("celery_session_app")
+@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(
tmpdir,
@@ -121,14 +128,20 @@ async def test_transcript_rtc_and_websocket(
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
-
- # wait the processing to finish
- await asyncio.sleep(2)
-
await client.stop()
# wait the processing to finish
- await asyncio.sleep(2)
+ timeout = 20
+ while True:
+ # fetch the transcript and check if it is ended
+ resp = await ac.get(f"/transcripts/{tid}")
+ assert resp.status_code == 200
+ if resp.json()["status"] in ("ended", "error"):
+ break
+ await asyncio.sleep(1)
+
+ if resp.json()["status"] != "ended":
+ raise TimeoutError("Timeout while waiting for transcript to be ended")
# stop websocket task
websocket_task.cancel()
@@ -152,7 +165,7 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
- assert ev["data"]["transcript"].startswith("Hello world.")
+ assert ev["data"]["text"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames
@@ -169,23 +182,21 @@ async def test_transcript_rtc_and_websocket(
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
- assert statuses == ["recording", "processing", "ended"]
+ assert statuses.index("recording") < statuses.index("processing")
+ assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
- # check that transcript status in model is updated
- resp = await ac.get(f"/transcripts/{tid}")
- assert resp.status_code == 200
- assert resp.json()["status"] == "ended"
-
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mpeg"
+@pytest.mark.usefixtures("celery_session_app")
+@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(
tmpdir,
@@ -265,6 +276,18 @@ async def test_transcript_rtc_and_websocket_and_fr(
await client.stop()
# wait the processing to finish
+ timeout = 20
+ while True:
+ # fetch the transcript and check if it is ended
+ resp = await ac.get(f"/transcripts/{tid}")
+ assert resp.status_code == 200
+ if resp.json()["status"] == "ended":
+ break
+ await asyncio.sleep(1)
+
+ if resp.json()["status"] != "ended":
+ raise TimeoutError("Timeout while waiting for transcript to be ended")
+
await asyncio.sleep(2)
# stop websocket task
@@ -289,7 +312,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
- assert ev["data"]["transcript"].startswith("Hello world.")
+ assert ev["data"]["text"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames
@@ -306,7 +329,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
- assert statuses == ["recording", "processing", "ended"]
+ assert statuses.index("recording") < statuses.index("processing")
+ assert statuses.index("processing") < statuses.index("ended")
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
From d8a842f099091ad1ad19b934a4ff1fadb3003a95 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 27 Oct 2023 20:00:07 +0200
Subject: [PATCH 16/41] server: full diarization processor implementation based
on gokul app
---
server/reflector/app.py | 3 +
.../reflector/pipelines/main_live_pipeline.py | 30 +++++----
server/reflector/pipelines/runner.py | 6 +-
server/reflector/processors/__init__.py | 2 +-
.../reflector/processors/audio_diarization.py | 65 -------------------
.../processors/audio_diarization_auto.py | 34 ++++++++++
.../processors/audio_diarization_base.py | 28 ++++++++
.../processors/audio_diarization_modal.py | 36 ++++++++++
.../processors/audio_transcript_auto.py | 34 ++--------
server/reflector/processors/types.py | 2 +-
server/reflector/settings.py | 10 +++
.../tools/start_post_main_live_pipeline.py | 14 ++++
server/reflector/views/transcripts.py | 29 ++++++++-
server/reflector/worker/app.py | 1 +
server/tests/test_transcripts_rtc_ws.py | 2 +
15 files changed, 186 insertions(+), 110 deletions(-)
delete mode 100644 server/reflector/processors/audio_diarization.py
create mode 100644 server/reflector/processors/audio_diarization_auto.py
create mode 100644 server/reflector/processors/audio_diarization_base.py
create mode 100644 server/reflector/processors/audio_diarization_modal.py
create mode 100644 server/reflector/tools/start_post_main_live_pipeline.py
diff --git a/server/reflector/app.py b/server/reflector/app.py
index 758faf69..c2e3bf7e 100644
--- a/server/reflector/app.py
+++ b/server/reflector/app.py
@@ -64,6 +64,9 @@ app.include_router(transcripts_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")
add_pagination(app)
+# prepare celery
+from reflector.worker import app as celery_app # noqa
+
# simpler openapi id
def use_route_names_as_operation_ids(app: FastAPI) -> None:
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 4159c889..87e2ff46 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -13,10 +13,12 @@ It is directly linked to our data model.
import asyncio
from contextlib import asynccontextmanager
+from datetime import timedelta
from pathlib import Path
from celery import shared_task
from pydantic import BaseModel
+from reflector.app import app
from reflector.db.transcripts import (
Transcript,
TranscriptFinalLongSummary,
@@ -29,7 +31,7 @@ from reflector.db.transcripts import (
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
- AudioDiarizationProcessor,
+ AudioDiarizationAutoProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
@@ -45,6 +47,7 @@ from reflector.processors import (
from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
from reflector.processors.types import Transcript as TranscriptProcessorType
+from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
@@ -174,7 +177,7 @@ class PipelineMainBase(PipelineRunner):
async with self.transaction():
transcript = await self.get_transcript()
if not transcript.title:
- transcripts_controller.update(
+ await transcripts_controller.update(
transcript,
{
"title": final_title.title,
@@ -238,19 +241,13 @@ class PipelineMainLive(PipelineMainBase):
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioChunkerProcessor(),
AudioMergeProcessor(),
- AudioTranscriptAutoProcessor.as_threaded(),
+ AudioTranscriptAutoProcessor.get_instance().as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
- TranscriptFinalLongSummaryProcessor.as_threaded(
- callback=self.on_long_summary
- ),
- TranscriptFinalShortSummaryProcessor.as_threaded(
- callback=self.on_short_summary
- ),
]
),
]
@@ -277,7 +274,7 @@ class PipelineMainDiarization(PipelineMainBase):
# add a customised logger to the context
self.prepare()
processors = [
- AudioDiarizationProcessor(),
+ AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -307,8 +304,19 @@ class PipelineMainDiarization(PipelineMainBase):
for topic in transcript.topics
]
+ # we need to create an url to be used for diarization
+ # we can't use the audio_mp3_filename because it's not accessible
+ # from the diarization processor
+ from reflector.views.transcripts import create_access_token
+
+ token = create_access_token(
+ {"sub": transcript.user_id},
+ expires_delta=timedelta(minutes=15),
+ )
+ path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id)
+ url = f"{settings.BASE_URL}{path}?token={token}"
audio_diarization_input = AudioDiarizationInput(
- audio_filename=transcript.audio_mp3_filename,
+ audio_url=url,
topics=topics,
)
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index 0575cf96..583cdcb6 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -55,7 +55,11 @@ class PipelineRunner(BaseModel):
"""
Start the pipeline synchronously (for non-asyncio apps)
"""
- asyncio.run(self.run())
+ loop = asyncio.get_event_loop()
+ if not loop:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ loop.run_until_complete(self.run())
def push(self, data):
"""
diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py
index 01a3a174..1c88d6c5 100644
--- a/server/reflector/processors/__init__.py
+++ b/server/reflector/processors/__init__.py
@@ -1,5 +1,5 @@
from .audio_chunker import AudioChunkerProcessor # noqa: F401
-from .audio_diarization import AudioDiarizationProcessor # noqa: F401
+from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py
deleted file mode 100644
index 8db8e8e5..00000000
--- a/server/reflector/processors/audio_diarization.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import AudioDiarizationInput, TitleSummary
-
-
-class AudioDiarizationProcessor(Processor):
- INPUT_TYPE = AudioDiarizationInput
- OUTPUT_TYPE = TitleSummary
-
- async def _push(self, data: AudioDiarizationInput):
- # Gather diarization data
- diarization = [
- {"start": 0.0, "stop": 4.9, "speaker": 2},
- {"start": 5.6, "stop": 6.7, "speaker": 2},
- {"start": 7.3, "stop": 8.9, "speaker": 2},
- {"start": 7.3, "stop": 7.9, "speaker": 0},
- {"start": 9.4, "stop": 11.2, "speaker": 2},
- {"start": 9.7, "stop": 10.0, "speaker": 0},
- {"start": 10.0, "stop": 10.1, "speaker": 0},
- {"start": 11.7, "stop": 16.1, "speaker": 2},
- {"start": 11.8, "stop": 12.1, "speaker": 1},
- {"start": 16.4, "stop": 21.0, "speaker": 2},
- {"start": 21.1, "stop": 22.6, "speaker": 2},
- {"start": 24.7, "stop": 31.9, "speaker": 2},
- {"start": 32.0, "stop": 32.8, "speaker": 1},
- {"start": 33.4, "stop": 37.8, "speaker": 2},
- {"start": 37.9, "stop": 40.3, "speaker": 0},
- {"start": 39.2, "stop": 40.4, "speaker": 2},
- {"start": 40.7, "stop": 41.4, "speaker": 0},
- {"start": 41.6, "stop": 45.7, "speaker": 2},
- {"start": 46.4, "stop": 53.1, "speaker": 2},
- {"start": 53.6, "stop": 56.5, "speaker": 2},
- {"start": 54.9, "stop": 75.4, "speaker": 1},
- {"start": 57.3, "stop": 58.0, "speaker": 2},
- {"start": 65.7, "stop": 66.0, "speaker": 2},
- {"start": 75.8, "stop": 78.8, "speaker": 1},
- {"start": 79.0, "stop": 82.6, "speaker": 1},
- {"start": 83.2, "stop": 83.3, "speaker": 1},
- {"start": 84.5, "stop": 94.3, "speaker": 1},
- {"start": 95.1, "stop": 100.7, "speaker": 1},
- {"start": 100.7, "stop": 102.0, "speaker": 0},
- {"start": 100.7, "stop": 101.8, "speaker": 1},
- {"start": 102.0, "stop": 103.0, "speaker": 1},
- {"start": 103.0, "stop": 103.7, "speaker": 0},
- {"start": 103.7, "stop": 103.8, "speaker": 1},
- {"start": 103.8, "stop": 113.9, "speaker": 0},
- {"start": 114.7, "stop": 117.0, "speaker": 0},
- {"start": 117.0, "stop": 117.4, "speaker": 1},
- ]
-
- # now reapply speaker to topics (if any)
- # topics is a list[BaseModel] with an attribute words
- # words is a list[BaseModel] with text, start and speaker attribute
-
- print("IN DIARIZATION PROCESSOR", data)
-
- # mutate in place
- for topic in data.topics:
- for word in topic.transcript.words:
- for d in diarization:
- if d["start"] <= word.start <= d["stop"]:
- word.speaker = d["speaker"]
-
- # emit them
- for topic in data.topics:
- await self.emit(topic)
diff --git a/server/reflector/processors/audio_diarization_auto.py b/server/reflector/processors/audio_diarization_auto.py
new file mode 100644
index 00000000..1de19b45
--- /dev/null
+++ b/server/reflector/processors/audio_diarization_auto.py
@@ -0,0 +1,34 @@
+import importlib
+
+from reflector.processors.base import Processor
+from reflector.settings import settings
+
+
+class AudioDiarizationAutoProcessor(Processor):
+ _registry = {}
+
+ @classmethod
+ def register(cls, name, kclass):
+ cls._registry[name] = kclass
+
+ @classmethod
+ def get_instance(cls, name: str | None = None, **kwargs):
+ if name is None:
+ name = settings.DIARIZATION_BACKEND
+
+ if name not in cls._registry:
+ module_name = f"reflector.processors.audio_diarization_{name}"
+ importlib.import_module(module_name)
+
+ # gather specific configuration for the processor
+ # search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
+ config = {}
+ name_upper = name.upper()
+ settings_prefix = "DIARIZATION_"
+ config_prefix = f"{settings_prefix}{name_upper}_"
+ for key, value in settings:
+ if key.startswith(config_prefix):
+ config_name = key[len(settings_prefix) :].lower()
+ config[config_name] = value
+
+ return cls._registry[name](**config | kwargs)
diff --git a/server/reflector/processors/audio_diarization_base.py b/server/reflector/processors/audio_diarization_base.py
new file mode 100644
index 00000000..2ad7e4bf
--- /dev/null
+++ b/server/reflector/processors/audio_diarization_base.py
@@ -0,0 +1,28 @@
+from reflector.processors.base import Processor
+from reflector.processors.types import AudioDiarizationInput, TitleSummary
+
+
+class AudioDiarizationBaseProcessor(Processor):
+ INPUT_TYPE = AudioDiarizationInput
+ OUTPUT_TYPE = TitleSummary
+
+ async def _push(self, data: AudioDiarizationInput):
+ diarization = await self._diarize(data)
+
+ # now reapply speaker to topics (if any)
+ # topics is a list[BaseModel] with an attribute words
+ # words is a list[BaseModel] with text, start and speaker attribute
+
+ # mutate in place
+ for topic in data.topics:
+ for word in topic.transcript.words:
+ for d in diarization:
+ if d["start"] <= word.start <= d["end"]:
+ word.speaker = d["speaker"]
+
+ # emit them
+ for topic in data.topics:
+ await self.emit(topic)
+
+ async def _diarize(self, data: AudioDiarizationInput):
+ raise NotImplementedError
diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py
new file mode 100644
index 00000000..b71dbcc9
--- /dev/null
+++ b/server/reflector/processors/audio_diarization_modal.py
@@ -0,0 +1,36 @@
+import httpx
+from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
+from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor
+from reflector.processors.types import AudioDiarizationInput, TitleSummary
+from reflector.settings import settings
+
+
+class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor):
+ INPUT_TYPE = AudioDiarizationInput
+ OUTPUT_TYPE = TitleSummary
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.diarization_url = settings.DIARIZATION_URL + "/diarize"
+ self.headers = {
+ "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
+ }
+
+ async def _diarize(self, data: AudioDiarizationInput):
+ # Gather diarization data
+ params = {
+ "audio_file_url": data.audio_url,
+ "timestamp": 0,
+ }
+ async with httpx.AsyncClient() as client:
+ response = await client.post(
+ self.diarization_url,
+ headers=self.headers,
+ params=params,
+ timeout=None,
+ )
+ response.raise_for_status()
+ return response.json()["text"]
+
+
+AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)
diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py
index f223a52d..fc1f0b5e 100644
--- a/server/reflector/processors/audio_transcript_auto.py
+++ b/server/reflector/processors/audio_transcript_auto.py
@@ -1,8 +1,6 @@
import importlib
from reflector.processors.audio_transcript import AudioTranscriptProcessor
-from reflector.processors.base import Pipeline, Processor
-from reflector.processors.types import AudioFile
from reflector.settings import settings
@@ -14,7 +12,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
cls._registry[name] = kclass
@classmethod
- def get_instance(cls, name):
+ def get_instance(cls, name: str | None = None, **kwargs):
+ if name is None:
+ name = settings.TRANSCRIPT_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.audio_transcript_{name}"
importlib.import_module(module_name)
@@ -30,30 +30,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
- return cls._registry[name](**config)
-
- def __init__(self, **kwargs):
- self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
- super().__init__(**kwargs)
-
- def set_pipeline(self, pipeline: Pipeline):
- super().set_pipeline(pipeline)
- self.processor.set_pipeline(pipeline)
-
- def connect(self, processor: Processor):
- self.processor.connect(processor)
-
- def disconnect(self, processor: Processor):
- self.processor.disconnect(processor)
-
- def on(self, callback):
- self.processor.on(callback)
-
- def off(self, callback):
- self.processor.off(callback)
-
- async def _push(self, data: AudioFile):
- return await self.processor._push(data)
-
- async def _flush(self):
- return await self.processor._flush()
+ return cls._registry[name](**config | kwargs)
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index 3ec21491..b67f84b9 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -385,5 +385,5 @@ class TranslationLanguages(BaseModel):
class AudioDiarizationInput(BaseModel):
- audio_filename: Path
+ audio_url: str
topics: list[TitleSummary]
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index d7cc2c33..021d509f 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -89,6 +89,10 @@ class Settings(BaseSettings):
# LLM Modal configuration
LLM_MODAL_API_KEY: str | None = None
+ # Diarization
+ DIARIZATION_BACKEND: str = "modal"
+ DIARIZATION_URL: str | None = None
+
# Sentry
SENTRY_DSN: str | None = None
@@ -121,5 +125,11 @@ class Settings(BaseSettings):
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
+ # Secret key
+ SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
+
+ # Current hosting/domain
+ BASE_URL: str = "http://localhost:1250"
+
settings = Settings()
diff --git a/server/reflector/tools/start_post_main_live_pipeline.py b/server/reflector/tools/start_post_main_live_pipeline.py
new file mode 100644
index 00000000..859f03a4
--- /dev/null
+++ b/server/reflector/tools/start_post_main_live_pipeline.py
@@ -0,0 +1,14 @@
+import argparse
+
+from reflector.app import celery_app # noqa
+from reflector.pipelines.main_live_pipeline import task_pipeline_main_post
+
+parser = argparse.ArgumentParser()
+parser.add_argument("transcript_id", type=str)
+parser.add_argument("--delay", action="store_true")
+args = parser.parse_args()
+
+if args.delay:
+ task_pipeline_main_post.delay(args.transcript_id)
+else:
+ task_pipeline_main_post(args.transcript_id)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 31cbe28e..f83bc6de 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -1,4 +1,4 @@
-from datetime import datetime
+from datetime import datetime, timedelta
from typing import Annotated, Optional
import reflector.auth as auth
@@ -9,8 +9,10 @@ from fastapi import (
Request,
WebSocket,
WebSocketDisconnect,
+ status,
)
from fastapi_pagination import Page, paginate
+from jose import jwt
from pydantic import BaseModel, Field
from reflector.db.transcripts import (
AudioWaveform,
@@ -27,6 +29,18 @@ from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter()
+ALGORITHM = "HS256"
+DOWNLOAD_EXPIRE_MINUTES = 60
+
+
+def create_access_token(data: dict, expires_delta: timedelta):
+ to_encode = data.copy()
+ expire = datetime.utcnow() + expires_delta
+ to_encode.update({"exp": expire})
+ encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
+ return encoded_jwt
+
+
# ==============================================================
# Transcripts list
# ==============================================================
@@ -198,8 +212,21 @@ async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+ token: str | None = None,
):
user_id = user["sub"] if user else None
+ if not user_id and token:
+ unauthorized_exception = HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or expired token",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ try:
+ payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
+ user_id: str = payload.get("sub")
+ except jwt.JWTError:
+ raise unauthorized_exception
+
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py
index 3714a64d..e1000364 100644
--- a/server/reflector/worker/app.py
+++ b/server/reflector/worker/app.py
@@ -4,6 +4,7 @@ from reflector.settings import settings
app = Celery(__name__)
app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
+app.conf.broker_connection_retry_on_startup = True
app.autodiscover_tasks(
[
"reflector.pipelines.main_live_pipeline",
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index e2bfee32..5a9a404b 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -102,6 +102,7 @@ async def test_transcript_rtc_and_websocket(
print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
+ print("Test websocket: TASK CREATED", websocket_task)
# create stream client
import argparse
@@ -243,6 +244,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
+ print("Test websocket: TASK CREATED", websocket_task)
# create stream client
import argparse
From e405ccb8f38c6033cce17119047cc77ae84a8eac Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 27 Oct 2023 20:08:00 +0200
Subject: [PATCH 17/41] server: started updating documentation
---
README.md | 11 +++++++++--
docker-compose.yml | 14 +++++++++++---
server/docker-compose.yml | 28 ++++++++++++++++++----------
server/runserver.sh | 9 ++++++++-
4 files changed, 46 insertions(+), 16 deletions(-)
diff --git a/README.md b/README.md
index 22651cc6..150c4575 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ The project architecture consists of three primary components:
* **Front-End**: NextJS React project hosted on Vercel, located in `www/`.
* **Back-End**: Python server that offers an API and data persistence, found in `server/`.
-* **AI Models**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
+* **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations.
It also uses https://github.com/fief-dev for authentication, and Vercel for deployment and configuration of the front-end.
@@ -120,6 +120,9 @@ TRANSCRIPT_MODAL_API_KEY=
LLM_BACKEND=modal
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
LLM_MODAL_API_KEY=
+TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
+ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
+DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
AUTH_BACKEND=fief
AUTH_FIEF_URL=https://auth.reflector.media/reflector-local
@@ -138,6 +141,10 @@ Use:
poetry run python3 -m reflector.app
```
+And start the background worker
+
+celery -A reflector.worker.app worker --loglevel=info
+
#### Using docker
Use:
@@ -161,4 +168,4 @@ poetry run python -m reflector.tools.process path/to/audio.wav
## AI Models
-*(Documentation for this section is pending.)*
\ No newline at end of file
+*(Documentation for this section is pending.)*
diff --git a/docker-compose.yml b/docker-compose.yml
index 934baaac..9e6519af 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -5,10 +5,19 @@ services:
context: server
ports:
- 1250:1250
- environment:
- LLM_URL: "${LLM_URL}"
volumes:
- model-cache:/root/.cache
+ environment: ENTRYPOINT=server
+ worker:
+ build:
+ context: server
+ volumes:
+ - model-cache:/root/.cache
+ environment: ENTRYPOINT=worker
+ redis:
+ image: redis:7.2
+ ports:
+ - 6379:6379
web:
build:
context: www
@@ -17,4 +26,3 @@ services:
volumes:
model-cache:
-
diff --git a/server/docker-compose.yml b/server/docker-compose.yml
index 4e5a21e8..c8432816 100644
--- a/server/docker-compose.yml
+++ b/server/docker-compose.yml
@@ -1,15 +1,23 @@
version: "3.9"
services:
- # server:
- # build:
- # context: .
- # ports:
- # - 1250:1250
- # environment:
- # LLM_URL: "${LLM_URL}"
- # MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}"
- # volumes:
- # - model-cache:/root/.cache
+ server:
+ build:
+ context: .
+ ports:
+ - 1250:1250
+ volumes:
+ - model-cache:/root/.cache
+ environment:
+ ENTRYPOINT: server
+ REDIS_HOST: redis
+ worker:
+ build:
+ context: .
+ volumes:
+ - model-cache:/root/.cache
+ environment:
+ ENTRYPOINT: worker
+ REDIS_HOST: redis
redis:
image: redis:7.2
ports:
diff --git a/server/runserver.sh b/server/runserver.sh
index 38eafe09..b0c3f138 100755
--- a/server/runserver.sh
+++ b/server/runserver.sh
@@ -4,4 +4,11 @@ if [ -f "/venv/bin/activate" ]; then
source /venv/bin/activate
fi
alembic upgrade head
-python -m reflector.app
+
+if [ "${ENTRYPOINT}" = "server" ]; then
+ python -m reflector.app
+elif [ "${ENTRYPOINT}" = "worker" ]; then
+ celery -A reflector.worker.app worker --loglevel=info
+else
+ echo "Unknown command"
+fi
From d0057ae2c4cc29167cca12ff65ed0415d7e64ee0 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 1 Nov 2023 10:28:15 +0100
Subject: [PATCH 18/41] server: add missing python-jose
---
server/poetry.lock | 67 ++++++++++++++++++++++++++++++++++++++++++-
server/pyproject.toml | 1 +
2 files changed, 67 insertions(+), 1 deletion(-)
diff --git a/server/poetry.lock b/server/poetry.lock
index 8783625b..e72ade57 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -1110,6 +1110,24 @@ idna = ["idna (>=2.1,<4.0)"]
trio = ["trio (>=0.14,<0.23)"]
wmi = ["wmi (>=1.5.1,<2.0.0)"]
+[[package]]
+name = "ecdsa"
+version = "0.18.0"
+description = "ECDSA cryptographic signature library (pure python)"
+optional = false
+python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
+files = [
+ {file = "ecdsa-0.18.0-py2.py3-none-any.whl", hash = "sha256:80600258e7ed2f16b9aa1d7c295bd70194109ad5a30fdee0eaeefef1d4c559dd"},
+ {file = "ecdsa-0.18.0.tar.gz", hash = "sha256:190348041559e21b22a1d65cee485282ca11a6f81d503fddb84d5017e9ed1e49"},
+]
+
+[package.dependencies]
+six = ">=1.9.0"
+
+[package.extras]
+gmpy = ["gmpy"]
+gmpy2 = ["gmpy2"]
+
[[package]]
name = "fastapi"
version = "0.100.1"
@@ -2348,6 +2366,17 @@ files = [
{file = "protobuf-4.24.4.tar.gz", hash = "sha256:5a70731910cd9104762161719c3d883c960151eea077134458503723b60e3667"},
]
+[[package]]
+name = "pyasn1"
+version = "0.5.0"
+description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+files = [
+ {file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"},
+ {file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"},
+]
+
[[package]]
name = "pycparser"
version = "2.21"
@@ -2754,6 +2783,28 @@ files = [
[package.extras]
cli = ["click (>=5.0)"]
+[[package]]
+name = "python-jose"
+version = "3.3.0"
+description = "JOSE implementation in Python"
+optional = false
+python-versions = "*"
+files = [
+ {file = "python-jose-3.3.0.tar.gz", hash = "sha256:55779b5e6ad599c6336191246e95eb2293a9ddebd555f796a65f838f07e5d78a"},
+ {file = "python_jose-3.3.0-py2.py3-none-any.whl", hash = "sha256:9b1376b023f8b298536eedd47ae1089bcdb848f1535ab30555cd92002d78923a"},
+]
+
+[package.dependencies]
+cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"cryptography\""}
+ecdsa = "!=0.15"
+pyasn1 = "*"
+rsa = "*"
+
+[package.extras]
+cryptography = ["cryptography (>=3.4.0)"]
+pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"]
+pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"]
+
[[package]]
name = "pyyaml"
version = "6.0.1"
@@ -3069,6 +3120,20 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+[[package]]
+name = "rsa"
+version = "4.9"
+description = "Pure-Python RSA implementation"
+optional = false
+python-versions = ">=3.6,<4"
+files = [
+ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
+ {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.1.3"
+
[[package]]
name = "s3transfer"
version = "0.6.2"
@@ -4078,4 +4143,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "07e42e7512fd5d51b656207a05092c53905c15e6a5ce548e015cdc05bd1baa7d"
+content-hash = "cfefbd402bde7585caa42c1a889be0496d956e285bb05db9e1e7ae5e485e91fe"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index c8614006..7681af39 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -35,6 +35,7 @@ protobuf = "^4.24.3"
profanityfilter = "^2.0.6"
celery = "^5.3.4"
redis = "^5.0.1"
+python-jose = {extras = ["cryptography"], version = "^3.3.0"}
[tool.poetry.group.dev.dependencies]
From 4da890b95fc2952e3ad8b679036e9a88632ea09a Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 1 Nov 2023 11:55:46 +0100
Subject: [PATCH 19/41] server: add dummy diarization and fix instantiation
---
.../reflector/pipelines/main_live_pipeline.py | 25 ++++++++++++++++---
server/reflector/pipelines/runner.py | 7 ++----
...arization_base.py => audio_diarization.py} | 2 +-
.../processors/audio_diarization_auto.py | 7 +++---
.../processors/audio_diarization_modal.py | 4 +--
.../processors/audio_transcript_auto.py | 3 +--
server/tests/conftest.py | 25 ++++++++++++++++++-
server/tests/test_transcripts_rtc_ws.py | 2 ++
8 files changed, 57 insertions(+), 18 deletions(-)
rename server/reflector/processors/{audio_diarization_base.py => audio_diarization.py} (95%)
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 87e2ff46..88e1bffd 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -28,6 +28,7 @@ from reflector.db.transcripts import (
TranscriptTopic,
transcripts_controller,
)
+from reflector.logger import logger
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
@@ -241,7 +242,7 @@ class PipelineMainLive(PipelineMainBase):
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioChunkerProcessor(),
AudioMergeProcessor(),
- AudioTranscriptAutoProcessor.get_instance().as_threaded(),
+ AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
@@ -255,11 +256,18 @@ class PipelineMainLive(PipelineMainBase):
pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
+ pipeline.logger.bind(transcript_id=transcript.id)
+ pipeline.logger.info(
+ "Pipeline main live created",
+ transcript_id=self.transcript_id,
+ )
return pipeline
async def on_ended(self):
# when the pipeline ends, connect to the post pipeline
+ logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
+ logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
task_pipeline_main_post.delay(transcript_id=self.transcript_id)
@@ -274,7 +282,7 @@ class PipelineMainDiarization(PipelineMainBase):
# add a customised logger to the context
self.prepare()
processors = [
- AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic),
+ AudioDiarizationAutoProcessor(callback=self.on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -313,7 +321,10 @@ class PipelineMainDiarization(PipelineMainBase):
{"sub": transcript.user_id},
expires_delta=timedelta(minutes=15),
)
- path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id)
+ path = app.url_path_for(
+ "transcript_get_audio_mp3",
+ transcript_id=transcript.id,
+ )
url = f"{settings.BASE_URL}{path}?token={token}"
audio_diarization_input = AudioDiarizationInput(
audio_url=url,
@@ -322,6 +333,10 @@ class PipelineMainDiarization(PipelineMainBase):
# as tempting to use pipeline.push, prefer to use the runner
# to let the start just do one job.
+ pipeline.logger.bind(transcript_id=transcript.id)
+ pipeline.logger.info(
+ "Pipeline main post created", transcript_id=self.transcript_id
+ )
self.push(audio_diarization_input)
self.flush()
@@ -330,5 +345,9 @@ class PipelineMainDiarization(PipelineMainBase):
@shared_task
def task_pipeline_main_post(transcript_id: str):
+ logger.info(
+ "Starting main post pipeline",
+ transcript_id=transcript_id,
+ )
runner = PipelineMainDiarization(transcript_id=transcript_id)
runner.start_sync()
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index 583cdcb6..a1e137a7 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -55,11 +55,8 @@ class PipelineRunner(BaseModel):
"""
Start the pipeline synchronously (for non-asyncio apps)
"""
- loop = asyncio.get_event_loop()
- if not loop:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- loop.run_until_complete(self.run())
+ coro = self.run()
+ asyncio.run(coro)
def push(self, data):
"""
diff --git a/server/reflector/processors/audio_diarization_base.py b/server/reflector/processors/audio_diarization.py
similarity index 95%
rename from server/reflector/processors/audio_diarization_base.py
rename to server/reflector/processors/audio_diarization.py
index 2ad7e4bf..d69f4b80 100644
--- a/server/reflector/processors/audio_diarization_base.py
+++ b/server/reflector/processors/audio_diarization.py
@@ -2,7 +2,7 @@ from reflector.processors.base import Processor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
-class AudioDiarizationBaseProcessor(Processor):
+class AudioDiarizationProcessor(Processor):
INPUT_TYPE = AudioDiarizationInput
OUTPUT_TYPE = TitleSummary
diff --git a/server/reflector/processors/audio_diarization_auto.py b/server/reflector/processors/audio_diarization_auto.py
index 1de19b45..0e7bfc5c 100644
--- a/server/reflector/processors/audio_diarization_auto.py
+++ b/server/reflector/processors/audio_diarization_auto.py
@@ -1,18 +1,17 @@
import importlib
-from reflector.processors.base import Processor
+from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.settings import settings
-class AudioDiarizationAutoProcessor(Processor):
+class AudioDiarizationAutoProcessor(AudioDiarizationProcessor):
_registry = {}
@classmethod
def register(cls, name, kclass):
cls._registry[name] = kclass
- @classmethod
- def get_instance(cls, name: str | None = None, **kwargs):
+ def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.DIARIZATION_BACKEND
diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py
index b71dbcc9..52be7c5d 100644
--- a/server/reflector/processors/audio_diarization_modal.py
+++ b/server/reflector/processors/audio_diarization_modal.py
@@ -1,11 +1,11 @@
import httpx
+from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
-from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
from reflector.settings import settings
-class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor):
+class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
INPUT_TYPE = AudioDiarizationInput
OUTPUT_TYPE = TitleSummary
diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py
index fc1f0b5e..ac79ced0 100644
--- a/server/reflector/processors/audio_transcript_auto.py
+++ b/server/reflector/processors/audio_transcript_auto.py
@@ -11,8 +11,7 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
def register(cls, name, kclass):
cls._registry[name] = kclass
- @classmethod
- def get_instance(cls, name: str | None = None, **kwargs):
+ def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.TRANSCRIPT_BACKEND
if name not in cls._registry:
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index d5f5f0b9..aafca9fd 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -60,12 +60,35 @@ async def dummy_transcript():
with patch(
"reflector.processors.audio_transcript_auto"
- ".AudioTranscriptAutoProcessor.get_instance"
+ ".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = TestAudioTranscriptProcessor()
yield
+@pytest.fixture
+async def dummy_diarization():
+ from reflector.processors.audio_diarization import AudioDiarizationProcessor
+
+ class TestAudioDiarizationProcessor(AudioDiarizationProcessor):
+ _time_idx = 0
+
+ async def _diarize(self, data):
+ i = self._time_idx
+ self._time_idx += 2
+ return [
+ {"start": i, "end": i + 1, "speaker": 0},
+ {"start": i + 1, "end": i + 2, "speaker": 1},
+ ]
+
+ with patch(
+ "reflector.processors.audio_diarization_auto"
+ ".AudioDiarizationAutoProcessor.__new__"
+ ) as mock_audio:
+ mock_audio.return_value = TestAudioDiarizationProcessor()
+ yield
+
+
@pytest.fixture
async def dummy_llm():
from reflector.llm.base import LLM
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 5a9a404b..8f8cac71 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -65,6 +65,7 @@ async def test_transcript_rtc_and_websocket(
dummy_llm,
dummy_transcript,
dummy_processors,
+ dummy_diarization,
ensure_casing,
appserver,
sentence_tokenize,
@@ -204,6 +205,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_llm,
dummy_transcript,
dummy_processors,
+ dummy_diarization,
ensure_casing,
appserver,
sentence_tokenize,
From 3e7031d031567d46d77e316a4f40e51cf8dcaf85 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 1 Nov 2023 11:58:15 +0100
Subject: [PATCH 20/41] server: do not remove empty or recording transcripts by
default
We should have the possibility to delete or hide them later
---
server/reflector/db/transcripts.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 61a2c380..89025d53 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -167,8 +167,8 @@ class TranscriptController:
self,
user_id: str | None = None,
order_by: str | None = None,
- filter_empty: bool | None = True,
- filter_recording: bool | None = True,
+ filter_empty: bool | None = False,
+ filter_recording: bool | None = False,
) -> list[Transcript]:
"""
Get all transcripts
From 7fca7ae287b9131f21a3bb06cfbf114fd3441dc6 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Wed, 1 Nov 2023 12:00:33 +0100
Subject: [PATCH 21/41] ci: add redis service required for celery
---
.github/workflows/test_server.yml | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/.github/workflows/test_server.yml b/.github/workflows/test_server.yml
index 71bed5d0..9f3b9a6a 100644
--- a/.github/workflows/test_server.yml
+++ b/.github/workflows/test_server.yml
@@ -11,6 +11,11 @@ on:
jobs:
pytest:
runs-on: ubuntu-latest
+ services:
+ redis:
+ image: redis:6
+ ports:
+ - 6379:6379
steps:
- uses: actions/checkout@v3
- name: Install poetry
From dbf3c9fd2cfb845a3161394a2460a8f73fb83c2c Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 16:40:48 +0100
Subject: [PATCH 22/41] www: fix path
---
www/app/[domain]/transcripts/topicList.tsx | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx
index ef4a2889..56b02e6e 100644
--- a/www/app/[domain]/transcripts/topicList.tsx
+++ b/www/app/[domain]/transcripts/topicList.tsx
@@ -7,7 +7,7 @@ import {
import { formatTime } from "../../lib/time";
import ScrollToBottom from "./scrollToBottom";
import { Topic } from "./webSocketTypes";
-import { generateHighContrastColor } from "../lib/utils";
+import { generateHighContrastColor } from "../../lib/utils";
type TopicListProps = {
topics: Topic[];
From f5ce3dd75eb48cd8ececd4a34aea8afc887271f8 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 16:41:06 +0100
Subject: [PATCH 23/41] www: remove auth changes (will be done by sara PR)
---
.../transcripts/[transcriptId]/record/page.tsx | 15 +++------------
1 file changed, 3 insertions(+), 12 deletions(-)
diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
index 51a318a4..8e31327c 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx
@@ -14,8 +14,6 @@ import DisconnectedIndicator from "../../disconnectedIndicator";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { faGear } from "@fortawesome/free-solid-svg-icons";
import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock";
-import { featRequireLogin } from "../../../../../app/lib/utils";
-import { useFiefIsAuthenticated } from "@fief/fief/nextjs/react";
type TranscriptDetails = {
params: {
@@ -38,23 +36,16 @@ const TranscriptRecord = (details: TranscriptDetails) => {
}
}, []);
- const isAuthenticated = useFiefIsAuthenticated();
const api = getApi();
- const [transcriptId, setTranscriptId] = useState("");
- const transcript = useTranscript(api, transcriptId);
- const webRTC = useWebRTC(stream, transcriptId, api);
- const webSockets = useWebSockets(transcriptId);
+ const transcript = useTranscript(api, details.params.transcriptId);
+ const webRTC = useWebRTC(stream, details.params.transcriptId, api);
+ const webSockets = useWebSockets(details.params.transcriptId);
const { audioDevices, getAudioStream } = useAudioDevice();
const [hasRecorded, setHasRecorded] = useState(false);
const [transcriptStarted, setTranscriptStarted] = useState(false);
- useEffect(() => {
- if (featRequireLogin() && !isAuthenticated) return;
- setTranscriptId(details.params.transcriptId);
- }, [api, details.params.transcriptId, isAuthenticated]);
-
useEffect(() => {
if (!transcriptStarted && webSockets.transcriptText.length !== 0)
setTranscriptStarted(true);
From bba5643237141a711c7a47aa58a1149d2dc07311 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 16:43:10 +0100
Subject: [PATCH 24/41] www: update openapi
---
www/app/api/apis/DefaultApi.ts | 61 +++-------------------------------
1 file changed, 5 insertions(+), 56 deletions(-)
diff --git a/www/app/api/apis/DefaultApi.ts b/www/app/api/apis/DefaultApi.ts
index d51d42ca..5bb2e7e9 100644
--- a/www/app/api/apis/DefaultApi.ts
+++ b/www/app/api/apis/DefaultApi.ts
@@ -42,10 +42,6 @@ import {
UpdateTranscriptToJSON,
} from "../models";
-export interface RtcOfferRequest {
- rtcOffer: RtcOffer;
-}
-
export interface V1TranscriptDeleteRequest {
transcriptId: any;
}
@@ -56,6 +52,7 @@ export interface V1TranscriptGetRequest {
export interface V1TranscriptGetAudioMp3Request {
transcriptId: any;
+ token?: any;
}
export interface V1TranscriptGetAudioWaveformRequest {
@@ -132,58 +129,6 @@ export class DefaultApi extends runtime.BaseAPI {
return await response.value();
}
- /**
- * Rtc Offer
- */
- async rtcOfferRaw(
- requestParameters: RtcOfferRequest,
- initOverrides?: RequestInit | runtime.InitOverrideFunction,
- ): Promise> {
- if (
- requestParameters.rtcOffer === null ||
- requestParameters.rtcOffer === undefined
- ) {
- throw new runtime.RequiredError(
- "rtcOffer",
- "Required parameter requestParameters.rtcOffer was null or undefined when calling rtcOffer.",
- );
- }
-
- const queryParameters: any = {};
-
- const headerParameters: runtime.HTTPHeaders = {};
-
- headerParameters["Content-Type"] = "application/json";
-
- const response = await this.request(
- {
- path: `/offer`,
- method: "POST",
- headers: headerParameters,
- query: queryParameters,
- body: RtcOfferToJSON(requestParameters.rtcOffer),
- },
- initOverrides,
- );
-
- if (this.isJsonMime(response.headers.get("content-type"))) {
- return new runtime.JSONApiResponse(response);
- } else {
- return new runtime.TextApiResponse(response) as any;
- }
- }
-
- /**
- * Rtc Offer
- */
- async rtcOffer(
- requestParameters: RtcOfferRequest,
- initOverrides?: RequestInit | runtime.InitOverrideFunction,
- ): Promise {
- const response = await this.rtcOfferRaw(requestParameters, initOverrides);
- return await response.value();
- }
-
/**
* Transcript Delete
*/
@@ -325,6 +270,10 @@ export class DefaultApi extends runtime.BaseAPI {
const queryParameters: any = {};
+ if (requestParameters.token !== undefined) {
+ queryParameters["token"] = requestParameters.token;
+ }
+
const headerParameters: runtime.HTTPHeaders = {};
if (this.configuration && this.configuration.accessToken) {
From 19b5ba2c4c40c08f8cff57e71a7f62829436504b Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 16:53:57 +0100
Subject: [PATCH 25/41] server: add diarization logger information
---
.../processors/audio_diarization_modal.py | 22 ++++++++++++-------
1 file changed, 14 insertions(+), 8 deletions(-)
diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py
index 52be7c5d..3c8d1b45 100644
--- a/server/reflector/processors/audio_diarization_modal.py
+++ b/server/reflector/processors/audio_diarization_modal.py
@@ -22,15 +22,21 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
"audio_file_url": data.audio_url,
"timestamp": 0,
}
+ self.logger.info("Diarization started", audio_file_url=data.audio_url)
async with httpx.AsyncClient() as client:
- response = await client.post(
- self.diarization_url,
- headers=self.headers,
- params=params,
- timeout=None,
- )
- response.raise_for_status()
- return response.json()["text"]
+ try:
+ response = await client.post(
+ self.diarization_url,
+ headers=self.headers,
+ params=params,
+ timeout=None,
+ )
+ response.raise_for_status()
+ self.logger.info("Diarization finished")
+ return response.json()["text"]
+ except Exception:
+ self.logger.exception("Diarization failed after retrying")
+ raise
AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)
From 057c636c56c0529086f89811cada698c98251962 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 16:56:30 +0100
Subject: [PATCH 26/41] server: move logging to base implementation, not
specialization
---
.../reflector/processors/audio_diarization.py | 8 ++++++-
.../processors/audio_diarization_modal.py | 22 +++++++------------
2 files changed, 15 insertions(+), 15 deletions(-)
diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py
index d69f4b80..82c6a553 100644
--- a/server/reflector/processors/audio_diarization.py
+++ b/server/reflector/processors/audio_diarization.py
@@ -7,7 +7,13 @@ class AudioDiarizationProcessor(Processor):
OUTPUT_TYPE = TitleSummary
async def _push(self, data: AudioDiarizationInput):
- diarization = await self._diarize(data)
+ try:
+ self.logger.info("Diarization started", audio_file_url=data.audio_url)
+ diarization = await self._diarize(data)
+ self.logger.info("Diarization finished")
+ except Exception:
+ self.logger.exception("Diarization failed after retrying")
+ raise
# now reapply speaker to topics (if any)
# topics is a list[BaseModel] with an attribute words
diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py
index 3c8d1b45..52be7c5d 100644
--- a/server/reflector/processors/audio_diarization_modal.py
+++ b/server/reflector/processors/audio_diarization_modal.py
@@ -22,21 +22,15 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
"audio_file_url": data.audio_url,
"timestamp": 0,
}
- self.logger.info("Diarization started", audio_file_url=data.audio_url)
async with httpx.AsyncClient() as client:
- try:
- response = await client.post(
- self.diarization_url,
- headers=self.headers,
- params=params,
- timeout=None,
- )
- response.raise_for_status()
- self.logger.info("Diarization finished")
- return response.json()["text"]
- except Exception:
- self.logger.exception("Diarization failed after retrying")
- raise
+ response = await client.post(
+ self.diarization_url,
+ headers=self.headers,
+ params=params,
+ timeout=None,
+ )
+ response.raise_for_status()
+ return response.json()["text"]
AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)
From 2e738e9f17dc8bec18aa9fae153808d5401ddf60 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 17:31:47 +0100
Subject: [PATCH 27/41] server: remove unused/old post_transcript.py
---
server/reflector/worker/post_transcript.py | 170 ---------------------
1 file changed, 170 deletions(-)
delete mode 100644 server/reflector/worker/post_transcript.py
diff --git a/server/reflector/worker/post_transcript.py b/server/reflector/worker/post_transcript.py
deleted file mode 100644
index 1cfbc664..00000000
--- a/server/reflector/worker/post_transcript.py
+++ /dev/null
@@ -1,170 +0,0 @@
-from reflector.logger import logger
-from reflector.processors import (
- Pipeline,
- Processor,
- TranscriptFinalLongSummaryProcessor,
- TranscriptFinalShortSummaryProcessor,
- TranscriptFinalTitleProcessor,
-)
-from reflector.processors.base import BroadcastProcessor
-from reflector.processors.types import (
- FinalLongSummary,
- FinalShortSummary,
- FinalTitle,
- TitleSummary,
-)
-from reflector.processors.types import Transcript as ProcessorTranscript
-from reflector.tasks.worker import celery
-from reflector.views.rtc_offer import PipelineEvent, TranscriptionContext
-from reflector.views.transcripts import Transcript, transcripts_controller
-
-
-class TranscriptAudioDiarizationProcessor(Processor):
- INPUT_TYPE = Transcript
- OUTPUT_TYPE = TitleSummary
-
- async def _push(self, data: Transcript):
- # Gather diarization data
- diarization = [
- {"start": 0.0, "stop": 4.9, "speaker": 2},
- {"start": 5.6, "stop": 6.7, "speaker": 2},
- {"start": 7.3, "stop": 8.9, "speaker": 2},
- {"start": 7.3, "stop": 7.9, "speaker": 0},
- {"start": 9.4, "stop": 11.2, "speaker": 2},
- {"start": 9.7, "stop": 10.0, "speaker": 0},
- {"start": 10.0, "stop": 10.1, "speaker": 0},
- {"start": 11.7, "stop": 16.1, "speaker": 2},
- {"start": 11.8, "stop": 12.1, "speaker": 1},
- {"start": 16.4, "stop": 21.0, "speaker": 2},
- {"start": 21.1, "stop": 22.6, "speaker": 2},
- {"start": 24.7, "stop": 31.9, "speaker": 2},
- {"start": 32.0, "stop": 32.8, "speaker": 1},
- {"start": 33.4, "stop": 37.8, "speaker": 2},
- {"start": 37.9, "stop": 40.3, "speaker": 0},
- {"start": 39.2, "stop": 40.4, "speaker": 2},
- {"start": 40.7, "stop": 41.4, "speaker": 0},
- {"start": 41.6, "stop": 45.7, "speaker": 2},
- {"start": 46.4, "stop": 53.1, "speaker": 2},
- {"start": 53.6, "stop": 56.5, "speaker": 2},
- {"start": 54.9, "stop": 75.4, "speaker": 1},
- {"start": 57.3, "stop": 58.0, "speaker": 2},
- {"start": 65.7, "stop": 66.0, "speaker": 2},
- {"start": 75.8, "stop": 78.8, "speaker": 1},
- {"start": 79.0, "stop": 82.6, "speaker": 1},
- {"start": 83.2, "stop": 83.3, "speaker": 1},
- {"start": 84.5, "stop": 94.3, "speaker": 1},
- {"start": 95.1, "stop": 100.7, "speaker": 1},
- {"start": 100.7, "stop": 102.0, "speaker": 0},
- {"start": 100.7, "stop": 101.8, "speaker": 1},
- {"start": 102.0, "stop": 103.0, "speaker": 1},
- {"start": 103.0, "stop": 103.7, "speaker": 0},
- {"start": 103.7, "stop": 103.8, "speaker": 1},
- {"start": 103.8, "stop": 113.9, "speaker": 0},
- {"start": 114.7, "stop": 117.0, "speaker": 0},
- {"start": 117.0, "stop": 117.4, "speaker": 1},
- ]
-
- # now reapply speaker to topics (if any)
- # topics is a list[BaseModel] with an attribute words
- # words is a list[BaseModel] with text, start and speaker attribute
-
- # mutate in place
- for topic in data.topics:
- for word in topic.words:
- for d in diarization:
- if d["start"] <= word.start <= d["stop"]:
- word.speaker = d["speaker"]
-
- topics = data.topics[:]
-
- await transcripts_controller.update(
- data,
- {
- "topics": [topic.model_dump(mode="json") for topic in data.topics],
- },
- )
-
- # emit them
- for topic in topics:
- transcript = ProcessorTranscript(words=topic.words)
- await self.emit(
- TitleSummary(
- title=topic.title,
- summary=topic.summary,
- timestamp=topic.timestamp,
- duration=0,
- transcript=transcript,
- )
- )
-
-
-@celery.task(name="post_transcript")
-async def post_transcript_pipeline(transcript_id: str):
- # get transcript
- transcript = await transcripts_controller.get_by_id(transcript_id)
- if not transcript:
- logger.error("Transcript not found", transcript_id=transcript_id)
- return
-
- ctx = TranscriptionContext(logger=logger.bind(transcript_id=transcript_id))
- event_callback = None
- event_callback_args = None
-
- async def on_final_short_summary(summary: FinalShortSummary):
- ctx.logger.info("FinalShortSummary", final_short_summary=summary)
-
- # send to callback (eg. websocket)
- if event_callback:
- await event_callback(
- event=PipelineEvent.FINAL_SHORT_SUMMARY,
- args=event_callback_args,
- data=summary,
- )
-
- async def on_final_long_summary(summary: FinalLongSummary):
- ctx.logger.info("FinalLongSummary", final_summary=summary)
-
- # send to callback (eg. websocket)
- if event_callback:
- await event_callback(
- event=PipelineEvent.FINAL_LONG_SUMMARY,
- args=event_callback_args,
- data=summary,
- )
-
- async def on_final_title(title: FinalTitle):
- ctx.logger.info("FinalTitle", final_title=title)
-
- # send to callback (eg. websocket)
- if event_callback:
- await event_callback(
- event=PipelineEvent.FINAL_TITLE,
- args=event_callback_args,
- data=title,
- )
-
- ctx.logger.info("Starting pipeline (diarization)")
- ctx.pipeline = Pipeline(
- TranscriptAudioDiarizationProcessor(),
- BroadcastProcessor(
- processors=[
- TranscriptFinalTitleProcessor.as_threaded(),
- TranscriptFinalLongSummaryProcessor.as_threaded(),
- TranscriptFinalShortSummaryProcessor.as_threaded(),
- ]
- ),
- )
-
- await ctx.pipeline.push(transcript)
- await ctx.pipeline.flush()
-
-
-if __name__ == "__main__":
- import argparse
- import asyncio
-
- parser = argparse.ArgumentParser()
- parser.add_argument("transcript_id", type=str)
- args = parser.parse_args()
-
- asyncio.run(post_transcript_pipeline(args.transcript_id))
From 239fae6189dbd7864bfbe46d3ab311d8324c7755 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 19:02:02 +0100
Subject: [PATCH 28/41] hotfix/server: add migration script to migrate
transcript field to text
---
server/migrations/versions/9920ecfe2735_.py | 81 +++++++++++++++++++++
1 file changed, 81 insertions(+)
create mode 100644 server/migrations/versions/9920ecfe2735_.py
diff --git a/server/migrations/versions/9920ecfe2735_.py b/server/migrations/versions/9920ecfe2735_.py
new file mode 100644
index 00000000..da718af0
--- /dev/null
+++ b/server/migrations/versions/9920ecfe2735_.py
@@ -0,0 +1,81 @@
+"""empty message
+
+Revision ID: 9920ecfe2735
+Revises: 99365b0cd87b
+Create Date: 2023-11-02 18:55:17.019498
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import json
+from sqlalchemy.sql import table, column
+from sqlalchemy import select
+
+
+# revision identifiers, used by Alembic.
+revision: str = "9920ecfe2735"
+down_revision: Union[str, None] = "99365b0cd87b"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # bind the engine
+ bind = op.get_bind()
+
+ # Reflect the table
+ transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
+
+ # Select all rows from the transcript table
+ results = bind.execute(select([transcript.c.id, transcript.c.topics]))
+
+ for row in results:
+ transcript_id = row["id"]
+ topics_json = row["topics"]
+
+ # Process each topic in the topics JSON array
+ updated_topics = []
+ for topic in topics_json:
+ if "transcript" in topic:
+ # Rename key 'transcript' to 'text'
+ topic["text"] = topic.pop("transcript")
+ updated_topics.append(topic)
+
+ # Update the transcript table
+ bind.execute(
+ transcript.update()
+ .where(transcript.c.id == transcript_id)
+ .values(topics=updated_topics)
+ )
+
+
+def downgrade() -> None:
+ # bind the engine
+ bind = op.get_bind()
+
+ # Reflect the table
+ transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
+
+ # Select all rows from the transcript table
+ results = bind.execute(select([transcript.c.id, transcript.c.topics]))
+
+ for row in results:
+ transcript_id = row["id"]
+ topics_json = row["topics"]
+
+ # Process each topic in the topics JSON array
+ updated_topics = []
+ for topic in topics_json:
+ if "text" in topic:
+ # Rename key 'text' back to 'transcript'
+ topic["transcript"] = topic.pop("text")
+ updated_topics.append(topic)
+
+ # Update the transcript table
+ bind.execute(
+ transcript.update()
+ .where(transcript.c.id == transcript_id)
+ .values(topics=updated_topics)
+ )
From 3424550ea9f477d78a7b789b2b98fd6d210453cd Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 19:06:30 +0100
Subject: [PATCH 29/41] hotfix/server: add id in GetTranscriptTopic for the
frontend to work
---
server/reflector/views/transcripts.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index f83bc6de..f724bcdc 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -117,6 +117,7 @@ class GetTranscriptSegmentTopic(BaseModel):
class GetTranscriptTopic(BaseModel):
+ id: str
title: str
summary: str
timestamp: float
@@ -149,6 +150,7 @@ class GetTranscriptTopic(BaseModel):
for segment in transcript.as_segments()
]
return cls(
+ id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
From c87c30d33952709a19ff6444b4fb788893f790bb Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 19:09:13 +0100
Subject: [PATCH 30/41] hotfix/server: add follow_redirect on modal
---
server/reflector/llm/llm_modal.py | 1 +
server/reflector/processors/audio_diarization_modal.py | 1 +
server/reflector/processors/audio_transcript_modal.py | 1 +
server/reflector/processors/transcript_translator.py | 1 +
4 files changed, 4 insertions(+)
diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py
index 220730e5..4b81c5a0 100644
--- a/server/reflector/llm/llm_modal.py
+++ b/server/reflector/llm/llm_modal.py
@@ -47,6 +47,7 @@ class ModalLLM(LLM):
json=json_payload,
timeout=self.timeout,
retry_timeout=60 * 5,
+ follow_redirects=True,
)
response.raise_for_status()
text = response.json()["text"]
diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py
index 52be7c5d..53de2501 100644
--- a/server/reflector/processors/audio_diarization_modal.py
+++ b/server/reflector/processors/audio_diarization_modal.py
@@ -28,6 +28,7 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
headers=self.headers,
params=params,
timeout=None,
+ follow_redirects=True,
)
response.raise_for_status()
return response.json()["text"]
diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py
index 23c9d74e..0ca4710f 100644
--- a/server/reflector/processors/audio_transcript_modal.py
+++ b/server/reflector/processors/audio_transcript_modal.py
@@ -41,6 +41,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
timeout=self.timeout,
headers=self.headers,
params=json_payload,
+ follow_redirects=True,
)
self.logger.debug(
diff --git a/server/reflector/processors/transcript_translator.py b/server/reflector/processors/transcript_translator.py
index 77b8f5be..905ea423 100644
--- a/server/reflector/processors/transcript_translator.py
+++ b/server/reflector/processors/transcript_translator.py
@@ -50,6 +50,7 @@ class TranscriptTranslatorProcessor(Processor):
headers=self.headers,
params=json_payload,
timeout=self.timeout,
+ follow_redirects=True,
)
response.raise_for_status()
result = response.json()["text"]
From 37f6fe634544bbf38345d290afe21872ee7016ab Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 19:17:34 +0100
Subject: [PATCH 31/41] server: rename migration script for readability
---
...20ecfe2735_.py => 9920ecfe2735_rename_transcript_to_text.py} | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
rename server/migrations/versions/{9920ecfe2735_.py => 9920ecfe2735_rename_transcript_to_text.py} (97%)
diff --git a/server/migrations/versions/9920ecfe2735_.py b/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py
similarity index 97%
rename from server/migrations/versions/9920ecfe2735_.py
rename to server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py
index da718af0..90ff85ac 100644
--- a/server/migrations/versions/9920ecfe2735_.py
+++ b/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py
@@ -1,4 +1,4 @@
-"""empty message
+"""Migration transcript to text field in transcripts table
Revision ID: 9920ecfe2735
Revises: 99365b0cd87b
From 9642d0fd1e201fc804d61a35cbeb0c5a5217e4ab Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 19:40:45 +0100
Subject: [PATCH 32/41] hotfix/server: fix duplication of topics
---
server/reflector/db/transcripts.py | 6 +++---
server/reflector/pipelines/main_live_pipeline.py | 9 +++++++--
server/reflector/processors/types.py | 6 +++++-
3 files changed, 15 insertions(+), 6 deletions(-)
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 89025d53..5d190bdc 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -106,9 +106,9 @@ class Transcript(BaseModel):
return ev
def upsert_topic(self, topic: TranscriptTopic):
- existing_topic = next((t for t in self.topics if t.id == topic.id), None)
- if existing_topic:
- existing_topic.update_from(topic)
+ index = next((i for i, t in enumerate(self.topics) if t.id == topic.id), None)
+ if index is not None:
+ self.topics[index] = topic
else:
self.topics.append(topic)
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 88e1bffd..bf11bdf3 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -46,7 +46,9 @@ from reflector.processors import (
TranscriptTranslatorProcessor,
)
from reflector.processors.types import AudioDiarizationInput
-from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
+from reflector.processors.types import (
+ TitleSummaryWithId as TitleSummaryWithIdProcessorType,
+)
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
@@ -163,6 +165,8 @@ class PipelineMainBase(PipelineRunner):
text=data.transcript.text,
words=data.transcript.words,
)
+ if isinstance(data, TitleSummaryWithIdProcessorType):
+ topic.id = data.id
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.upsert_topic(transcript, topic)
@@ -302,7 +306,8 @@ class PipelineMainDiarization(PipelineMainBase):
# XXX translation is lost when converting our data model to the processor model
transcript = await self.get_transcript()
topics = [
- TitleSummaryProcessorType(
+ TitleSummaryWithIdProcessorType(
+ id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index b67f84b9..312f5433 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -167,6 +167,10 @@ class TitleSummary(BaseModel):
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
+class TitleSummaryWithId(TitleSummary):
+ id: str
+
+
class FinalLongSummary(BaseModel):
long_summary: str
duration: float
@@ -386,4 +390,4 @@ class TranslationLanguages(BaseModel):
class AudioDiarizationInput(BaseModel):
audio_url: str
- topics: list[TitleSummary]
+ topics: list[TitleSummaryWithId]
From eb76cd9bcd4540562ca3e658466904571a638406 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 19:54:44 +0100
Subject: [PATCH 33/41] server/www: rename topic text field to transcript
This alleviates the current issue with the Vercel deployment
---
...27dcb099_rename_back_text_to_transcript.py | 80 +++++++++++++++++++
.../9920ecfe2735_rename_transcript_to_text.py | 1 -
server/reflector/db/transcripts.py | 2 +-
.../reflector/pipelines/main_live_pipeline.py | 2 +-
server/reflector/views/transcripts.py | 4 +-
server/tests/test_transcripts_rtc_ws.py | 4 +-
www/app/[domain]/transcripts/topicList.tsx | 2 +-
.../[domain]/transcripts/webSocketTypes.ts | 2 +-
www/app/api/models/GetTranscriptTopic.ts | 17 +++-
9 files changed, 101 insertions(+), 13 deletions(-)
create mode 100644 server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py
diff --git a/server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py b/server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py
new file mode 100644
index 00000000..dffe6fa1
--- /dev/null
+++ b/server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py
@@ -0,0 +1,80 @@
+"""rename back text to transcript
+
+Revision ID: 38a927dcb099
+Revises: 9920ecfe2735
+Create Date: 2023-11-02 19:53:09.116240
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.sql import table, column
+from sqlalchemy import select
+
+
+# revision identifiers, used by Alembic.
+revision: str = '38a927dcb099'
+down_revision: Union[str, None] = '9920ecfe2735'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # bind the engine
+ bind = op.get_bind()
+
+ # Reflect the table
+ transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
+
+ # Select all rows from the transcript table
+ results = bind.execute(select([transcript.c.id, transcript.c.topics]))
+
+ for row in results:
+ transcript_id = row["id"]
+ topics_json = row["topics"]
+
+ # Process each topic in the topics JSON array
+ updated_topics = []
+ for topic in topics_json:
+ if "text" in topic:
+ # Rename key 'text' back to 'transcript'
+ topic["transcript"] = topic.pop("text")
+ updated_topics.append(topic)
+
+ # Update the transcript table
+ bind.execute(
+ transcript.update()
+ .where(transcript.c.id == transcript_id)
+ .values(topics=updated_topics)
+ )
+
+
+def downgrade() -> None:
+ # bind the engine
+ bind = op.get_bind()
+
+ # Reflect the table
+ transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
+
+ # Select all rows from the transcript table
+ results = bind.execute(select([transcript.c.id, transcript.c.topics]))
+
+ for row in results:
+ transcript_id = row["id"]
+ topics_json = row["topics"]
+
+ # Process each topic in the topics JSON array
+ updated_topics = []
+ for topic in topics_json:
+ if "transcript" in topic:
+ # Rename key 'transcript' to 'text'
+ topic["text"] = topic.pop("transcript")
+ updated_topics.append(topic)
+
+ # Update the transcript table
+ bind.execute(
+ transcript.update()
+ .where(transcript.c.id == transcript_id)
+ .values(topics=updated_topics)
+ )
diff --git a/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py b/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py
index 90ff85ac..caecaefd 100644
--- a/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py
+++ b/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py
@@ -9,7 +9,6 @@ from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
-import json
from sqlalchemy.sql import table, column
from sqlalchemy import select
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 5d190bdc..6ac2e32a 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -63,7 +63,7 @@ class TranscriptTopic(BaseModel):
summary: str
timestamp: float
duration: float | None = 0
- text: str | None = None
+ transcript: str | None = None
words: list[ProcessorWord] = []
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index bf11bdf3..477b0ce9 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -162,7 +162,7 @@ class PipelineMainBase(PipelineRunner):
title=data.title,
summary=data.summary,
timestamp=data.timestamp,
- text=data.transcript.text,
+ transcript=data.transcript.text,
words=data.transcript.words,
)
if isinstance(data, TitleSummaryWithIdProcessorType):
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index f724bcdc..77c3e149 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -121,7 +121,7 @@ class GetTranscriptTopic(BaseModel):
title: str
summary: str
timestamp: float
- text: str
+ transcript: str
segments: list[GetTranscriptSegmentTopic] = []
@classmethod
@@ -154,7 +154,7 @@ class GetTranscriptTopic(BaseModel):
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
- text=text,
+ transcript=text,
segments=segments,
)
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 8f8cac71..413c8b24 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -167,7 +167,7 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
- assert ev["data"]["text"].startswith("Hello world.")
+ assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames
@@ -316,7 +316,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
- assert ev["data"]["text"].startswith("Hello world.")
+ assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames
diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx
index 56b02e6e..e7454f79 100644
--- a/www/app/[domain]/transcripts/topicList.tsx
+++ b/www/app/[domain]/transcripts/topicList.tsx
@@ -132,7 +132,7 @@ export function TopicList({
))}
>
) : (
- <>{topic.text}>
+ <>{topic.transcript}>
)}
)}
diff --git a/www/app/[domain]/transcripts/webSocketTypes.ts b/www/app/[domain]/transcripts/webSocketTypes.ts
index abc67b33..112e7cc0 100644
--- a/www/app/[domain]/transcripts/webSocketTypes.ts
+++ b/www/app/[domain]/transcripts/webSocketTypes.ts
@@ -9,7 +9,7 @@ export type Topic = {
title: string;
summary: string;
id: string;
- text: string;
+ transcript: string;
segments: SegmentTopic[];
};
diff --git a/www/app/api/models/GetTranscriptTopic.ts b/www/app/api/models/GetTranscriptTopic.ts
index 7a7d4c90..460b8b39 100644
--- a/www/app/api/models/GetTranscriptTopic.ts
+++ b/www/app/api/models/GetTranscriptTopic.ts
@@ -19,6 +19,12 @@ import { exists, mapValues } from "../runtime";
* @interface GetTranscriptTopic
*/
export interface GetTranscriptTopic {
+ /**
+ *
+ * @type {any}
+ * @memberof GetTranscriptTopic
+ */
+ id: any | null;
/**
*
* @type {any}
@@ -42,7 +48,7 @@ export interface GetTranscriptTopic {
* @type {any}
* @memberof GetTranscriptTopic
*/
- text: any | null;
+ transcript: any | null;
/**
*
* @type {any}
@@ -56,10 +62,11 @@ export interface GetTranscriptTopic {
*/
export function instanceOfGetTranscriptTopic(value: object): boolean {
let isInstance = true;
+ isInstance = isInstance && "id" in value;
isInstance = isInstance && "title" in value;
isInstance = isInstance && "summary" in value;
isInstance = isInstance && "timestamp" in value;
- isInstance = isInstance && "text" in value;
+ isInstance = isInstance && "transcript" in value;
return isInstance;
}
@@ -76,10 +83,11 @@ export function GetTranscriptTopicFromJSONTyped(
return json;
}
return {
+ id: json["id"],
title: json["title"],
summary: json["summary"],
timestamp: json["timestamp"],
- text: json["text"],
+ transcript: json["transcript"],
segments: !exists(json, "segments") ? undefined : json["segments"],
};
}
@@ -94,10 +102,11 @@ export function GetTranscriptTopicToJSON(
return null;
}
return {
+ id: value.id,
title: value.title,
summary: value.summary,
timestamp: value.timestamp,
- text: value.text,
+ transcript: value.transcript,
segments: value.segments,
};
}
From aee959369f3a79859fa2cf992f6c7a1178c13264 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 20:08:21 +0100
Subject: [PATCH 34/41] hotfix/server: correctly load old topic
---
server/reflector/views/transcripts.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 77c3e149..7d79bf72 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -129,7 +129,7 @@ class GetTranscriptTopic(BaseModel):
if not topic.words:
# In previous version, words were missing
# Just output a segment with speaker 0
- text = topic.text
+ text = topic.transcript
segments = [
GetTranscriptSegmentTopic(
text=topic.text,
From 6f3e3741e788dbe0487d6e2a5e4f536428f2464b Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 20:12:57 +0100
Subject: [PATCH 35/41] hotfix/server: fix crash
---
server/reflector/views/transcripts.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 7d79bf72..e3668ecb 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -132,7 +132,7 @@ class GetTranscriptTopic(BaseModel):
text = topic.transcript
segments = [
GetTranscriptSegmentTopic(
- text=topic.text,
+ text=topic.transcript,
start=topic.timestamp,
speaker=0,
)
From c5893e0391a1d7b6bcd7dd9c7b0dc745e03e8a24 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 20:34:53 +0100
Subject: [PATCH 36/41] hotfix/server: do not pass a token for diarization/mp3
if there is no user
When decoding the token, if it is invalid (sub cannot be None), it just fails
---
server/reflector/pipelines/main_live_pipeline.py | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 477b0ce9..b2bc51ea 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -322,15 +322,19 @@ class PipelineMainDiarization(PipelineMainBase):
# from the diarization processor
from reflector.views.transcripts import create_access_token
- token = create_access_token(
- {"sub": transcript.user_id},
- expires_delta=timedelta(minutes=15),
- )
path = app.url_path_for(
"transcript_get_audio_mp3",
transcript_id=transcript.id,
)
- url = f"{settings.BASE_URL}{path}?token={token}"
+ url = f"{settings.BASE_URL}{path}"
+ if transcript.user_id:
+ # we pass token only if the user_id is set
+ # otherwise, the audio is public
+ token = create_access_token(
+ {"sub": transcript.user_id},
+ expires_delta=timedelta(minutes=15),
+ )
+ url += f"?token={token}"
audio_diarization_input = AudioDiarizationInput(
audio_url=url,
topics=topics,
From b9149d6e68116a9525e3495332928b455e09b7b1 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 21:00:14 +0100
Subject: [PATCH 37/41] server: ensure retry works even with 303 redirection
---
server/tests/test_retry_decorator.py | 26 ++++++++++++++++++++++++++
1 file changed, 26 insertions(+)
diff --git a/server/tests/test_retry_decorator.py b/server/tests/test_retry_decorator.py
index 22729eac..c60a490f 100644
--- a/server/tests/test_retry_decorator.py
+++ b/server/tests/test_retry_decorator.py
@@ -1,3 +1,4 @@
+import asyncio
import pytest
import httpx
from reflector.utils.retry import (
@@ -8,6 +9,31 @@ from reflector.utils.retry import (
)
+@pytest.mark.asyncio
+async def test_retry_redirect(httpx_mock):
+ async def custom_response(request: httpx.Request):
+ if request.url.path == "/hello":
+ await asyncio.sleep(1)
+ return httpx.Response(
+ status_code=303, headers={"location": "https://test_url/redirected"}
+ )
+ elif request.url.path == "/redirected":
+ return httpx.Response(status_code=200, json={"hello": "world"})
+ else:
+ raise Exception("Unexpected path")
+
+ httpx_mock.add_callback(custom_response)
+ async with httpx.AsyncClient() as client:
+ # timeout should not be triggered, as it will end up ok
+ # even though the first request is a 303 and took more than 0.5
+ resp = await retry(client.get)(
+ "https://test_url/hello",
+ retry_timeout=0.5,
+ follow_redirects=True,
+ )
+ assert resp.json() == {"hello": "world"}
+
+
@pytest.mark.asyncio
async def test_retry_httpx(httpx_mock):
# this code should be force a retry
From da926b79a00042875868c73a8b2aacfc1a548434 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 3 Nov 2023 10:34:14 +0100
Subject: [PATCH 38/41] Update README.md - trigger vercel build after outage
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index 150c4575..cb75c76b 100644
--- a/README.md
+++ b/README.md
@@ -169,3 +169,4 @@ poetry run python -m reflector.tools.process path/to/audio.wav
## AI Models
*(Documentation for this section is pending.)*
+
From 5f773a1e82bdf1d4afdb81d3c0042cb63609e0ed Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 3 Nov 2023 10:36:16 +0100
Subject: [PATCH 39/41] Update forbidden.tsx - edit to trigger vercel build
---
www/pages/forbidden.tsx | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/www/pages/forbidden.tsx b/www/pages/forbidden.tsx
index 31a746fc..ada3d424 100644
--- a/www/pages/forbidden.tsx
+++ b/www/pages/forbidden.tsx
@@ -1,7 +1,7 @@
import type { NextPage } from "next";
const Forbidden: NextPage = () => {
- return Sorry, you are not authorized to access this page.
;
+ return Sorry, you are not authorized to access this page
;
};
export default Forbidden;
From 22b1ce9fd208f3603c2923181aa6fb1a211dd455 Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Fri, 3 Nov 2023 11:22:10 +0100
Subject: [PATCH 40/41] www: fix build with text->transcript and duplication of
topics
---
www/app/[domain]/transcripts/useWebSockets.ts | 38 +++++++++++++------
.../[domain]/transcripts/webSocketTypes.ts | 15 +-------
2 files changed, 29 insertions(+), 24 deletions(-)
diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts
index 8196749e..5610c2a4 100644
--- a/www/app/[domain]/transcripts/useWebSockets.ts
+++ b/www/app/[domain]/transcripts/useWebSockets.ts
@@ -56,7 +56,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
timestamp: 10,
summary: "This is test topic 1",
title: "Topic 1: Introduction to Quantum Mechanics",
- text: "A brief overview of quantum mechanics and its principles.",
+ transcript:
+ "A brief overview of quantum mechanics and its principles.",
segments: [
{
speaker: 1,
@@ -96,7 +97,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
timestamp: 20,
summary: "This is test topic 2",
title: "Topic 2: Machine Learning Algorithms",
- text: "Understanding the different types of machine learning algorithms.",
+ transcript:
+ "Understanding the different types of machine learning algorithms.",
segments: [
{
speaker: 1,
@@ -115,7 +117,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
timestamp: 30,
summary: "This is test topic 3",
title: "Topic 3: Mental Health Awareness",
- text: "Ways to improve mental health and reduce stigma.",
+ transcript: "Ways to improve mental health and reduce stigma.",
segments: [
{
speaker: 1,
@@ -134,7 +136,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
timestamp: 40,
summary: "This is test topic 4",
title: "Topic 4: Basics of Productivity",
- text: "Tips and tricks to increase daily productivity.",
+ transcript: "Tips and tricks to increase daily productivity.",
segments: [
{
speaker: 1,
@@ -153,7 +155,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
timestamp: 50,
summary: "This is test topic 5",
title: "Topic 5: Future of Aviation",
- text: "Exploring the advancements and possibilities in aviation.",
+ transcript:
+ "Exploring the advancements and possibilities in aviation.",
segments: [
{
speaker: 1,
@@ -182,7 +185,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 1",
title:
"Topic 1: Introduction to Quantum Mechanics, a brief overview of quantum mechanics and its principles.",
- text: "A brief overview of quantum mechanics and its principles.",
+ transcript:
+ "A brief overview of quantum mechanics and its principles.",
segments: [
{
speaker: 1,
@@ -202,7 +206,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 2",
title:
"Topic 2: Machine Learning Algorithms, understanding the different types of machine learning algorithms.",
- text: "Understanding the different types of machine learning algorithms.",
+ transcript:
+ "Understanding the different types of machine learning algorithms.",
segments: [
{
speaker: 1,
@@ -222,7 +227,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 3",
title:
"Topic 3: Mental Health Awareness, ways to improve mental health and reduce stigma.",
- text: "Ways to improve mental health and reduce stigma.",
+ transcript: "Ways to improve mental health and reduce stigma.",
segments: [
{
speaker: 1,
@@ -242,7 +247,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 4",
title:
"Topic 4: Basics of Productivity, tips and tricks to increase daily productivity.",
- text: "Tips and tricks to increase daily productivity.",
+ transcript: "Tips and tricks to increase daily productivity.",
segments: [
{
speaker: 1,
@@ -262,7 +267,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "This is test topic 5",
title:
"Topic 5: Future of Aviation, exploring the advancements and possibilities in aviation.",
- text: "Exploring the advancements and possibilities in aviation.",
+ transcript:
+ "Exploring the advancements and possibilities in aviation.",
segments: [
{
speaker: 1,
@@ -308,7 +314,17 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
break;
case "TOPIC":
- setTopics((prevTopics) => [...prevTopics, message.data]);
+ setTopics((prevTopics) => {
+ const topic = message.data as Topic;
+ const index = prevTopics.findIndex(
+ (prevTopic) => prevTopic.id === topic.id,
+ );
+ if (index >= 0) {
+ prevTopics[index] = topic;
+ return prevTopics;
+ }
+ return [...prevTopics, topic];
+ });
console.debug("TOPIC event:", message.data);
break;
diff --git a/www/app/[domain]/transcripts/webSocketTypes.ts b/www/app/[domain]/transcripts/webSocketTypes.ts
index 112e7cc0..edd35eb6 100644
--- a/www/app/[domain]/transcripts/webSocketTypes.ts
+++ b/www/app/[domain]/transcripts/webSocketTypes.ts
@@ -1,17 +1,6 @@
-export type SegmentTopic = {
- speaker: number;
- start: number;
- text: string;
-};
+import { GetTranscriptTopic } from "../../api";
-export type Topic = {
- timestamp: number;
- title: string;
- summary: string;
- id: string;
- transcript: string;
- segments: SegmentTopic[];
-};
+export type Topic = GetTranscriptTopic;
export type Transcript = {
text: string;
From 907f4be67ab8f115603678fde3f3e7a064bc6f8c Mon Sep 17 00:00:00 2001
From: Mathieu Virbel
Date: Thu, 2 Nov 2023 11:46:45 +0100
Subject: [PATCH 41/41] www: fix mp3 download while authenticated
---
.../transcripts/[transcriptId]/page.tsx | 3 ++
www/app/[domain]/transcripts/recorder.tsx | 15 ++++--
www/app/[domain]/transcripts/useMp3.ts | 54 ++++++++++++++-----
www/package.json | 2 +-
www/yarn.lock | 8 +--
5 files changed, 60 insertions(+), 22 deletions(-)
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
index 3e30b97f..8e2184a0 100644
--- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx
+++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx
@@ -4,6 +4,7 @@ import getApi from "../../../lib/getApi";
import useTranscript from "../useTranscript";
import useTopics from "../useTopics";
import useWaveform from "../useWaveform";
+import useMp3 from "../useMp3";
import { TopicList } from "../topicList";
import Recorder from "../recorder";
import { Topic } from "../webSocketTypes";
@@ -31,6 +32,7 @@ export default function TranscriptDetails(details: TranscriptDetails) {
const waveform = useWaveform(api, transcriptId);
const useActiveTopic = useState(null);
const requireLogin = featureEnabled("requireLogin");
+ const mp3 = useMp3(api, transcriptId);
useEffect(() => {
if (requireLogin && !isAuthenticated) return;
@@ -70,6 +72,7 @@ export default function TranscriptDetails(details: TranscriptDetails) {
waveform={waveform?.waveform}
isPastMeeting={true}
transcriptId={transcript?.response?.id}
+ mp3Blob={mp3.blob}
/>
)}
diff --git a/www/app/[domain]/transcripts/recorder.tsx b/www/app/[domain]/transcripts/recorder.tsx
index 401b6c9e..d50c90e3 100644
--- a/www/app/[domain]/transcripts/recorder.tsx
+++ b/www/app/[domain]/transcripts/recorder.tsx
@@ -29,6 +29,7 @@ type RecorderProps = {
waveform?: AudioWaveform | null;
isPastMeeting: boolean;
transcriptId?: string | null;
+ mp3Blob?: Blob | null;
};
export default function Recorder(props: RecorderProps) {
@@ -107,11 +108,7 @@ export default function Recorder(props: RecorderProps) {
if (waveformRef.current) {
const _wavesurfer = WaveSurfer.create({
container: waveformRef.current,
- url: props.transcriptId
- ? `${process.env.NEXT_PUBLIC_API_URL}/v1/transcripts/${props.transcriptId}/audio/mp3`
- : undefined,
peaks: props.waveform?.data,
-
hideScrollbar: true,
autoCenter: true,
barWidth: 2,
@@ -145,6 +142,10 @@ export default function Recorder(props: RecorderProps) {
if (props.isPastMeeting) _wavesurfer.toggleInteraction(true);
+ if (props.mp3Blob) {
+ _wavesurfer.loadBlob(props.mp3Blob);
+ }
+
setWavesurfer(_wavesurfer);
return () => {
@@ -156,6 +157,12 @@ export default function Recorder(props: RecorderProps) {
}
}, []);
+ useEffect(() => {
+ if (!wavesurfer) return;
+ if (!props.mp3Blob) return;
+ wavesurfer.loadBlob(props.mp3Blob);
+ }, [props.mp3Blob]);
+
useEffect(() => {
topicsRef.current = props.topics;
if (!isRecording) renderMarkers();
diff --git a/www/app/[domain]/transcripts/useMp3.ts b/www/app/[domain]/transcripts/useMp3.ts
index 8bccf903..b7677180 100644
--- a/www/app/[domain]/transcripts/useMp3.ts
+++ b/www/app/[domain]/transcripts/useMp3.ts
@@ -1,36 +1,64 @@
-import { useEffect, useState } from "react";
+import { useContext, useEffect, useState } from "react";
import {
DefaultApi,
- V1TranscriptGetAudioMp3Request,
+ // V1TranscriptGetAudioMp3Request,
} from "../../api/apis/DefaultApi";
import {} from "../../api";
import { useError } from "../../(errors)/errorContext";
+import { DomainContext } from "../domainContext";
type Mp3Response = {
url: string | null;
+ blob: Blob | null;
loading: boolean;
error: Error | null;
};
const useMp3 = (api: DefaultApi, id: string): Mp3Response => {
const [url, setUrl] = useState(null);
+ const [blob, setBlob] = useState(null);
const [loading, setLoading] = useState(false);
const [error, setErrorState] = useState(null);
const { setError } = useError();
+ const { api_url } = useContext(DomainContext);
const getMp3 = (id: string) => {
- if (!id) throw new Error("Transcript ID is required to get transcript Mp3");
+ if (!id) return;
setLoading(true);
- const requestParameters: V1TranscriptGetAudioMp3Request = {
- transcriptId: id,
- };
- api
- .v1TranscriptGetAudioMp3(requestParameters)
- .then((result) => {
- setUrl(result);
- setLoading(false);
- console.debug("Transcript Mp3 loaded:", result);
+ // XXX Current API interface does not output a blob, we need to do it manually
+ // const requestParameters: V1TranscriptGetAudioMp3Request = {
+ // transcriptId: id,
+ // };
+ // api
+ // .v1TranscriptGetAudioMp3(requestParameters)
+ // .then((result) => {
+ // setUrl(result);
+ // setLoading(false);
+ // console.debug("Transcript Mp3 loaded:", result);
+ // })
+ // .catch((err) => {
+ // setError(err);
+ // setErrorState(err);
+ // });
+ const localUrl = `${api_url}/v1/transcripts/${id}/audio/mp3`;
+ if (localUrl == url) return;
+ const headers = new Headers();
+
+ if (api.configuration.configuration.accessToken) {
+ headers.set("Authorization", api.configuration.configuration.accessToken);
+ }
+
+ fetch(localUrl, {
+ method: "GET",
+ headers,
+ })
+ .then((response) => {
+ setUrl(localUrl);
+ response.blob().then((blob) => {
+ setBlob(blob);
+ setLoading(false);
+ });
})
.catch((err) => {
setError(err);
@@ -42,7 +70,7 @@ const useMp3 = (api: DefaultApi, id: string): Mp3Response => {
getMp3(id);
}, [id]);
- return { url, loading, error };
+ return { url, blob, loading, error };
};
export default useMp3;
diff --git a/www/package.json b/www/package.json
index edbc0790..55c7df73 100644
--- a/www/package.json
+++ b/www/package.json
@@ -35,7 +35,7 @@
"supports-color": "^9.4.0",
"tailwindcss": "^3.3.2",
"typescript": "^5.1.6",
- "wavesurfer.js": "^7.0.3"
+ "wavesurfer.js": "^7.4.2"
},
"main": "index.js",
"repository": "https://github.com/Monadical-SAS/reflector-ui.git",
diff --git a/www/yarn.lock b/www/yarn.lock
index a67822be..8ec03382 100644
--- a/www/yarn.lock
+++ b/www/yarn.lock
@@ -2638,10 +2638,10 @@ watchpack@2.4.0:
glob-to-regexp "^0.4.1"
graceful-fs "^4.1.2"
-wavesurfer.js@^7.0.3:
- version "7.0.3"
- resolved "https://registry.npmjs.org/wavesurfer.js/-/wavesurfer.js-7.0.3.tgz"
- integrity sha512-gJ3P+Bd3Q4E8qETjjg0pneaVqm2J7jegG2Cc6vqEF5YDDKQ3m8sKsvVfgVhJkacKkO9jFAGDu58Hw4zLr7xD0A==
+wavesurfer.js@^7.4.2:
+ version "7.4.2"
+ resolved "https://registry.yarnpkg.com/wavesurfer.js/-/wavesurfer.js-7.4.2.tgz#59f5c87193d4eeeb199858688ddac1ad7ba86b3a"
+ integrity sha512-4pNQ1porOCUBYBmd2F1TqVuBnB2wBPipaw2qI920zYLuPnada0Rd1CURgh8HRuPGKxijj2iyZDFN2UZwsaEuhA==
wcwidth@>=1.0.1, wcwidth@^1.0.1:
version "1.0.1"