From b3232543763364dfd93fbae2c56ea89e385f91a6 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 19 Oct 2023 21:05:13 +0200 Subject: [PATCH 01/41] server: move out profanity filter to transcript, and implement segmentation --- .../reflector/processors/audio_transcript.py | 10 --- .../processors/audio_transcript_modal.py | 3 - .../processors/audio_transcript_whisper.py | 1 - .../reflector/processors/transcript_liner.py | 1 - server/reflector/processors/types.py | 65 ++++++++++++++++++- server/reflector/views/transcripts.py | 17 ++++- 6 files changed, 78 insertions(+), 19 deletions(-) diff --git a/server/reflector/processors/audio_transcript.py b/server/reflector/processors/audio_transcript.py index f029b587..3f9dc85b 100644 --- a/server/reflector/processors/audio_transcript.py +++ b/server/reflector/processors/audio_transcript.py @@ -1,6 +1,4 @@ -from profanityfilter import ProfanityFilter from prometheus_client import Counter, Histogram - from reflector.processors.base import Processor from reflector.processors.types import AudioFile, Transcript @@ -40,8 +38,6 @@ class AudioTranscriptProcessor(Processor): self.m_transcript_call = self.m_transcript_call.labels(name) self.m_transcript_success = self.m_transcript_success.labels(name) self.m_transcript_failure = self.m_transcript_failure.labels(name) - self.profanity_filter = ProfanityFilter() - self.profanity_filter.set_censor("*") super().__init__(*args, **kwargs) async def _push(self, data: AudioFile): @@ -60,9 +56,3 @@ class AudioTranscriptProcessor(Processor): async def _transcript(self, data: AudioFile): raise NotImplementedError - - def filter_profanity(self, text: str) -> str: - """ - Remove censored words from the transcript - """ - return self.profanity_filter.censor(text) diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py index 201ed9d4..23c9d74e 100644 --- a/server/reflector/processors/audio_transcript_modal.py +++ b/server/reflector/processors/audio_transcript_modal.py @@ -48,10 +48,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor): ) response.raise_for_status() result = response.json() - text = result["text"][source_language] - text = self.filter_profanity(text) transcript = Transcript( - text=text, words=[ Word( text=word["text"], diff --git a/server/reflector/processors/audio_transcript_whisper.py b/server/reflector/processors/audio_transcript_whisper.py index e3bd595b..cd96e01a 100644 --- a/server/reflector/processors/audio_transcript_whisper.py +++ b/server/reflector/processors/audio_transcript_whisper.py @@ -30,7 +30,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor): ts = data.timestamp for segment in segments: - transcript.text += segment.text for word in segment.words: transcript.words.append( Word( diff --git a/server/reflector/processors/transcript_liner.py b/server/reflector/processors/transcript_liner.py index c1aa14a0..b4e7b5e3 100644 --- a/server/reflector/processors/transcript_liner.py +++ b/server/reflector/processors/transcript_liner.py @@ -36,7 +36,6 @@ class TranscriptLinerProcessor(Processor): # cut to the next . partial = Transcript(words=[]) for word in self.transcript.words[:]: - partial.text += word.text partial.words.append(word) if not self.is_sentence_terminated(word.text): continue diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index e867becf..686c5785 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -2,8 +2,12 @@ import io import tempfile from pathlib import Path +from profanityfilter import ProfanityFilter from pydantic import BaseModel, PrivateAttr +profanity_filter = ProfanityFilter() +profanity_filter.set_censor("*") + class AudioFile(BaseModel): name: str @@ -43,13 +47,29 @@ class Word(BaseModel): text: str start: float end: float + speaker: int = 0 + + +class TranscriptSegment(BaseModel): + text: str + start: float + speaker: int = 0 class Transcript(BaseModel): - text: str = "" translation: str | None = None words: list[Word] = None + @property + def raw_text(self): + # Uncensored text + return "".join([word.text for word in self.words]) + + @property + def text(self): + # Censored text + return profanity_filter.censor(self.raw_text).strip() + @property def human_timestamp(self): minutes = int(self.timestamp / 60) @@ -74,7 +94,6 @@ class Transcript(BaseModel): self.words = other.words else: self.words.extend(other.words) - self.text += other.text def add_offset(self, offset: float): for word in self.words: @@ -87,6 +106,48 @@ class Transcript(BaseModel): ] return Transcript(text=self.text, translation=self.translation, words=words) + def as_segments(self): + # from a list of word, create a list of segments + # join the word that are less than 2 seconds apart + # but separate if the speaker changes, or if the punctuation is a . , ; : ? ! + segments = [] + current_segment = None + last_word = None + BLANK_TIME_SECS = 2 + MAX_SEGMENT_LENGTH = 80 + for word in self.words: + if current_segment is None: + current_segment = TranscriptSegment( + text=word.text, + start=word.start, + speaker=word.speaker, + ) + continue + is_blank = False + if last_word: + is_blank = word.start - last_word.end > BLANK_TIME_SECS + if ( + word.speaker != current_segment.speaker + or ( + word.text in ".;:?!…" + and len(current_segment.text) > MAX_SEGMENT_LENGTH + ) + or is_blank + ): + # check which condition triggered + segments.append(current_segment) + current_segment = TranscriptSegment( + text=word.text, + start=word.start, + speaker=word.speaker, + ) + else: + current_segment.text += word.text + last_word = word + if current_segment: + segments.append(current_segment) + return segments + class TitleSummary(BaseModel): title: str diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index a7e01b8c..0a068c17 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -49,12 +49,18 @@ class TranscriptText(BaseModel): translation: str | None +class TranscriptSegmentTopic(BaseModel): + speaker: int + text: str + timestamp: float + + class TranscriptTopic(BaseModel): id: str = Field(default_factory=generate_uuid4) title: str summary: str - transcript: str | None = None timestamp: float + segments: list[TranscriptSegmentTopic] = [] class TranscriptFinalShortSummary(BaseModel): @@ -523,8 +529,15 @@ async def handle_rtc_event(event: PipelineEvent, args, data): topic = TranscriptTopic( title=data.title, summary=data.summary, - transcript=data.transcript.text, timestamp=data.timestamp, + segments=[ + TranscriptSegmentTopic( + speaker=segment.speaker, + text=segment.text, + timestamp=segment.start, + ) + for segment in data.transcript.as_segments() + ], ) resp = transcript.add_event(event=event, data=topic) transcript.upsert_topic(topic) From 6d074ed4570dbf85b304d7d5771ca6fe94c02114 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 19 Oct 2023 21:05:49 +0200 Subject: [PATCH 02/41] www: update frontend to support new transcript format in topics --- .../test_processor_transcript_segment.py | 146 ++++++++++++++++++ www/app/[domain]/transcripts/topicList.tsx | 12 +- .../[domain]/transcripts/webSocketTypes.ts | 8 +- www/app/api/.openapi-generator/FILES | 1 + www/app/api/models/TranscriptSegmentTopic.ts | 88 +++++++++++ www/app/api/models/TranscriptTopic.ts | 8 +- www/app/api/models/index.ts | 1 + 7 files changed, 258 insertions(+), 6 deletions(-) create mode 100644 server/tests/test_processor_transcript_segment.py create mode 100644 www/app/api/models/TranscriptSegmentTopic.ts diff --git a/server/tests/test_processor_transcript_segment.py b/server/tests/test_processor_transcript_segment.py new file mode 100644 index 00000000..3bb6182f --- /dev/null +++ b/server/tests/test_processor_transcript_segment.py @@ -0,0 +1,146 @@ +def test_processor_transcript_segment(): + from reflector.processors.types import Transcript, Word + + transcript = Transcript( + words=[ + Word(text=" the", start=5.12, end=5.48, speaker=0), + Word(text=" different", start=5.48, end=5.8, speaker=0), + Word(text=" projects", start=5.8, end=6.3, speaker=0), + Word(text=" that", start=6.3, end=6.5, speaker=0), + Word(text=" are", start=6.5, end=6.58, speaker=0), + Word(text=" going", start=6.58, end=6.82, speaker=0), + Word(text=" on", start=6.82, end=7.26, speaker=0), + Word(text=" to", start=7.26, end=7.4, speaker=0), + Word(text=" give", start=7.4, end=7.54, speaker=0), + Word(text=" you", start=7.54, end=7.9, speaker=0), + Word(text=" context", start=7.9, end=8.24, speaker=0), + Word(text=" and", start=8.24, end=8.66, speaker=0), + Word(text=" I", start=8.66, end=8.72, speaker=0), + Word(text=" think", start=8.72, end=8.82, speaker=0), + Word(text=" that's", start=8.82, end=9.04, speaker=0), + Word(text=" what", start=9.04, end=9.12, speaker=0), + Word(text=" we'll", start=9.12, end=9.24, speaker=0), + Word(text=" do", start=9.24, end=9.32, speaker=0), + Word(text=" this", start=9.32, end=9.52, speaker=0), + Word(text=" week.", start=9.52, end=9.76, speaker=0), + Word(text=" Um,", start=10.24, end=10.62, speaker=0), + Word(text=" so,", start=11.36, end=11.94, speaker=0), + Word(text=" um,", start=12.46, end=12.92, speaker=0), + Word(text=" what", start=13.74, end=13.94, speaker=0), + Word(text=" we're", start=13.94, end=14.1, speaker=0), + Word(text=" going", start=14.1, end=14.24, speaker=0), + Word(text=" to", start=14.24, end=14.34, speaker=0), + Word(text=" do", start=14.34, end=14.8, speaker=0), + Word(text=" at", start=14.8, end=14.98, speaker=0), + Word(text=" H", start=14.98, end=15.04, speaker=0), + Word(text=" of", start=15.04, end=15.16, speaker=0), + Word(text=" you,", start=15.16, end=15.26, speaker=0), + Word(text=" maybe.", start=15.28, end=15.34, speaker=0), + Word(text=" you", start=15.36, end=15.52, speaker=0), + Word(text=" can", start=15.52, end=15.62, speaker=0), + Word(text=" introduce", start=15.62, end=15.98, speaker=0), + Word(text=" yourself", start=15.98, end=16.42, speaker=0), + Word(text=" to", start=16.42, end=16.68, speaker=0), + Word(text=" the", start=16.68, end=16.72, speaker=0), + Word(text=" team", start=16.72, end=17.52, speaker=0), + Word(text=" quickly", start=17.87, end=18.65, speaker=0), + Word(text=" and", start=18.65, end=19.63, speaker=0), + Word(text=" Oh,", start=20.91, end=21.55, speaker=0), + Word(text=" this", start=21.67, end=21.83, speaker=0), + Word(text=" is", start=21.83, end=22.17, speaker=0), + Word(text=" a", start=22.17, end=22.35, speaker=0), + Word(text=" reflector", start=22.35, end=22.89, speaker=0), + Word(text=" translating", start=22.89, end=23.33, speaker=0), + Word(text=" into", start=23.33, end=23.73, speaker=0), + Word(text=" French", start=23.73, end=23.95, speaker=0), + Word(text=" for", start=23.95, end=24.15, speaker=0), + Word(text=" me.", start=24.15, end=24.43, speaker=0), + Word(text=" This", start=27.87, end=28.19, speaker=0), + Word(text=" is", start=28.19, end=28.45, speaker=0), + Word(text=" all", start=28.45, end=28.79, speaker=0), + Word(text=" the", start=28.79, end=29.15, speaker=0), + Word(text=" way,", start=29.15, end=29.15, speaker=0), + Word(text=" please,", start=29.53, end=29.59, speaker=0), + Word(text=" please,", start=29.73, end=29.77, speaker=0), + Word(text=" please,", start=29.77, end=29.83, speaker=0), + Word(text=" please.", start=29.83, end=29.97, speaker=0), + Word(text=" Yeah,", start=29.97, end=30.17, speaker=0), + Word(text=" that's", start=30.25, end=30.33, speaker=0), + Word(text=" all", start=30.33, end=30.49, speaker=0), + Word(text=" it's", start=30.49, end=30.69, speaker=0), + Word(text=" right.", start=30.69, end=30.69, speaker=0), + Word(text=" Right.", start=30.72, end=30.98, speaker=1), + Word(text=" Yeah,", start=31.56, end=31.72, speaker=2), + Word(text=" that's", start=31.86, end=31.98, speaker=2), + Word(text=" right.", start=31.98, end=32.2, speaker=2), + Word(text=" Because", start=32.38, end=32.46, speaker=0), + Word(text=" I", start=32.46, end=32.58, speaker=0), + Word(text=" thought", start=32.58, end=32.78, speaker=0), + Word(text=" I'd", start=32.78, end=33.0, speaker=0), + Word(text=" be", start=33.0, end=33.02, speaker=0), + Word(text=" able", start=33.02, end=33.18, speaker=0), + Word(text=" to", start=33.18, end=33.34, speaker=0), + Word(text=" pull", start=33.34, end=33.52, speaker=0), + Word(text=" out.", start=33.52, end=33.68, speaker=0), + Word(text=" Yeah,", start=33.7, end=33.9, speaker=0), + Word(text=" that", start=33.9, end=34.02, speaker=0), + Word(text=" was", start=34.02, end=34.24, speaker=0), + Word(text=" the", start=34.24, end=34.34, speaker=0), + Word(text=" one", start=34.34, end=34.44, speaker=0), + Word(text=" before", start=34.44, end=34.7, speaker=0), + Word(text=" that.", start=34.7, end=35.24, speaker=0), + Word(text=" Friends,", start=35.84, end=36.46, speaker=0), + Word(text=" if", start=36.64, end=36.7, speaker=0), + Word(text=" you", start=36.7, end=36.7, speaker=0), + Word(text=" have", start=36.7, end=37.24, speaker=0), + Word(text=" tell", start=37.24, end=37.44, speaker=0), + Word(text=" us", start=37.44, end=37.68, speaker=0), + Word(text=" if", start=37.68, end=37.82, speaker=0), + Word(text=" it's", start=37.82, end=38.04, speaker=0), + Word(text=" good,", start=38.04, end=38.58, speaker=0), + Word(text=" exceptionally", start=38.96, end=39.1, speaker=0), + Word(text=" good", start=39.1, end=39.6, speaker=0), + Word(text=" and", start=39.6, end=39.86, speaker=0), + Word(text=" tell", start=39.86, end=40.0, speaker=0), + Word(text=" us", start=40.0, end=40.06, speaker=0), + Word(text=" when", start=40.06, end=40.2, speaker=0), + Word(text=" it's", start=40.2, end=40.34, speaker=0), + Word(text=" exceptionally", start=40.34, end=40.6, speaker=0), + Word(text=" bad.", start=40.6, end=40.94, speaker=0), + Word(text=" We", start=40.96, end=41.26, speaker=0), + Word(text=" don't", start=41.26, end=41.44, speaker=0), + Word(text=" need", start=41.44, end=41.66, speaker=0), + Word(text=" that", start=41.66, end=41.82, speaker=0), + Word(text=" at", start=41.82, end=41.94, speaker=0), + Word(text=" the", start=41.94, end=41.98, speaker=0), + Word(text=" middle", start=41.98, end=42.18, speaker=0), + Word(text=" of", start=42.18, end=42.36, speaker=0), + Word(text=" age.", start=42.36, end=42.7, speaker=0), + Word(text=" Okay,", start=43.26, end=43.44, speaker=0), + Word(text=" yeah,", start=43.68, end=43.76, speaker=0), + Word(text=" that", start=43.78, end=44.3, speaker=0), + Word(text=" sentence", start=44.3, end=44.72, speaker=0), + Word(text=" right", start=44.72, end=45.1, speaker=0), + Word(text=" before.", start=45.1, end=45.56, speaker=0), + Word(text=" it", start=46.08, end=46.36, speaker=0), + Word(text=" realizing", start=46.36, end=47.0, speaker=0), + Word(text=" that", start=47.0, end=47.28, speaker=0), + Word(text=" I", start=47.28, end=47.28, speaker=0), + Word(text=" was", start=47.28, end=47.64, speaker=0), + Word(text=" saying", start=47.64, end=48.06, speaker=0), + Word(text=" that", start=48.06, end=48.44, speaker=0), + Word(text=" it's", start=48.44, end=48.54, speaker=0), + Word(text=" interesting", start=48.54, end=48.78, speaker=0), + Word(text=" that", start=48.78, end=48.96, speaker=0), + Word(text=" it's", start=48.96, end=49.08, speaker=0), + Word(text=" translating", start=49.08, end=49.32, speaker=0), + Word(text=" the", start=49.32, end=49.56, speaker=0), + Word(text=" French", start=49.56, end=49.76, speaker=0), + Word(text=" was", start=49.76, end=50.16, speaker=0), + Word(text=" completely", start=50.16, end=50.4, speaker=0), + Word(text=" wrong.", start=50.4, end=50.7, speaker=0), + ] + ) + + for segment in transcript.as_segments(): + print(segment) diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx index e5de09c8..4fedacb0 100644 --- a/www/app/[domain]/transcripts/topicList.tsx +++ b/www/app/[domain]/transcripts/topicList.tsx @@ -103,7 +103,17 @@ export function TopicList({ /> {activeTopic?.id == topic.id && ( -
{topic.transcript}
+
+ {topic.segments.map((segment, index) => ( +

+ [{formatTime(segment.timestamp)}] Speaker{" "} + {segment.speaker}: {segment.text} +

+ ))} +
)} ))} diff --git a/www/app/[domain]/transcripts/webSocketTypes.ts b/www/app/[domain]/transcripts/webSocketTypes.ts index 450b3b1c..3e02e26a 100644 --- a/www/app/[domain]/transcripts/webSocketTypes.ts +++ b/www/app/[domain]/transcripts/webSocketTypes.ts @@ -1,9 +1,15 @@ +export type SegmentTopic = { + speaker: number; + start: number; + text: string; +}; + export type Topic = { timestamp: number; title: string; - transcript: string; summary: string; id: string; + segments: SegmentTopic[]; }; export type Transcript = { diff --git a/www/app/api/.openapi-generator/FILES b/www/app/api/.openapi-generator/FILES index 16763a8d..9192cb1f 100644 --- a/www/app/api/.openapi-generator/FILES +++ b/www/app/api/.openapi-generator/FILES @@ -8,6 +8,7 @@ models/GetTranscript.ts models/HTTPValidationError.ts models/PageGetTranscript.ts models/RtcOffer.ts +models/TranscriptSegmentTopic.ts models/TranscriptTopic.ts models/UpdateTranscript.ts models/UserInfo.ts diff --git a/www/app/api/models/TranscriptSegmentTopic.ts b/www/app/api/models/TranscriptSegmentTopic.ts new file mode 100644 index 00000000..73496a67 --- /dev/null +++ b/www/app/api/models/TranscriptSegmentTopic.ts @@ -0,0 +1,88 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * FastAPI + * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + +import { exists, mapValues } from "../runtime"; +/** + * + * @export + * @interface TranscriptSegmentTopic + */ +export interface TranscriptSegmentTopic { + /** + * + * @type {any} + * @memberof TranscriptSegmentTopic + */ + speaker: any | null; + /** + * + * @type {any} + * @memberof TranscriptSegmentTopic + */ + text: any | null; + /** + * + * @type {any} + * @memberof TranscriptSegmentTopic + */ + timestamp: any | null; +} + +/** + * Check if a given object implements the TranscriptSegmentTopic interface. + */ +export function instanceOfTranscriptSegmentTopic(value: object): boolean { + let isInstance = true; + isInstance = isInstance && "speaker" in value; + isInstance = isInstance && "text" in value; + isInstance = isInstance && "timestamp" in value; + + return isInstance; +} + +export function TranscriptSegmentTopicFromJSON( + json: any, +): TranscriptSegmentTopic { + return TranscriptSegmentTopicFromJSONTyped(json, false); +} + +export function TranscriptSegmentTopicFromJSONTyped( + json: any, + ignoreDiscriminator: boolean, +): TranscriptSegmentTopic { + if (json === undefined || json === null) { + return json; + } + return { + speaker: json["speaker"], + text: json["text"], + timestamp: json["timestamp"], + }; +} + +export function TranscriptSegmentTopicToJSON( + value?: TranscriptSegmentTopic | null, +): any { + if (value === undefined) { + return undefined; + } + if (value === null) { + return null; + } + return { + speaker: value.speaker, + text: value.text, + timestamp: value.timestamp, + }; +} diff --git a/www/app/api/models/TranscriptTopic.ts b/www/app/api/models/TranscriptTopic.ts index 8b395374..f22496b2 100644 --- a/www/app/api/models/TranscriptTopic.ts +++ b/www/app/api/models/TranscriptTopic.ts @@ -42,13 +42,13 @@ export interface TranscriptTopic { * @type {any} * @memberof TranscriptTopic */ - transcript?: any | null; + timestamp: any | null; /** * * @type {any} * @memberof TranscriptTopic */ - timestamp: any | null; + segments?: any | null; } /** @@ -78,8 +78,8 @@ export function TranscriptTopicFromJSONTyped( id: !exists(json, "id") ? undefined : json["id"], title: json["title"], summary: json["summary"], - transcript: !exists(json, "transcript") ? undefined : json["transcript"], timestamp: json["timestamp"], + segments: !exists(json, "segments") ? undefined : json["segments"], }; } @@ -94,7 +94,7 @@ export function TranscriptTopicToJSON(value?: TranscriptTopic | null): any { id: value.id, title: value.title, summary: value.summary, - transcript: value.transcript, timestamp: value.timestamp, + segments: value.segments, }; } diff --git a/www/app/api/models/index.ts b/www/app/api/models/index.ts index 99874641..6a7d7318 100644 --- a/www/app/api/models/index.ts +++ b/www/app/api/models/index.ts @@ -7,6 +7,7 @@ export * from "./GetTranscript"; export * from "./HTTPValidationError"; export * from "./PageGetTranscript"; export * from "./RtcOffer"; +export * from "./TranscriptSegmentTopic"; export * from "./TranscriptTopic"; export * from "./UpdateTranscript"; export * from "./UserInfo"; From 00eb9bbf3c5ccdb8f7f3f9364ae952d821897759 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 20 Oct 2023 16:06:35 +0200 Subject: [PATCH 03/41] server: improve split algorithm --- server/reflector/processors/types.py | 35 ++++++++++------------------ 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index 686c5785..ba0cccf9 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -1,10 +1,13 @@ import io +import re import tempfile from pathlib import Path from profanityfilter import ProfanityFilter from pydantic import BaseModel, PrivateAttr +PUNC_RE = re.compile(r"[.;:?!…]") + profanity_filter = ProfanityFilter() profanity_filter.set_censor("*") @@ -106,15 +109,14 @@ class Transcript(BaseModel): ] return Transcript(text=self.text, translation=self.translation, words=words) - def as_segments(self): + def as_segments(self) -> list[TranscriptSegment]: # from a list of word, create a list of segments # join the word that are less than 2 seconds apart # but separate if the speaker changes, or if the punctuation is a . , ; : ? ! segments = [] current_segment = None - last_word = None - BLANK_TIME_SECS = 2 - MAX_SEGMENT_LENGTH = 80 + MAX_SEGMENT_LENGTH = 120 + for word in self.words: if current_segment is None: current_segment = TranscriptSegment( @@ -123,27 +125,14 @@ class Transcript(BaseModel): speaker=word.speaker, ) continue - is_blank = False - if last_word: - is_blank = word.start - last_word.end > BLANK_TIME_SECS - if ( - word.speaker != current_segment.speaker - or ( - word.text in ".;:?!…" - and len(current_segment.text) > MAX_SEGMENT_LENGTH - ) - or is_blank + current_segment.text += word.text + + have_punc = PUNC_RE.search(word.text) + if word.speaker != current_segment.speaker or ( + have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH) ): - # check which condition triggered segments.append(current_segment) - current_segment = TranscriptSegment( - text=word.text, - start=word.start, - speaker=word.speaker, - ) - else: - current_segment.text += word.text - last_word = word + current_segment = None if current_segment: segments.append(current_segment) return segments From 21e408b32391f1cefd69401e40ba558e421cdef8 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 20 Oct 2023 16:07:12 +0200 Subject: [PATCH 04/41] server: include transcripts words in database, but keep back compatible api --- server/reflector/views/transcripts.py | 80 +++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 11 deletions(-) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 0a068c17..1d9fd4bd 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -17,6 +17,8 @@ from fastapi_pagination import Page, paginate from pydantic import BaseModel, Field from reflector.db import database, transcripts from reflector.logger import logger +from reflector.processors.types import Transcript as ProcessorTranscript +from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings from reflector.utils.audio_waveform import get_audio_waveform from starlette.concurrency import run_in_threadpool @@ -60,7 +62,8 @@ class TranscriptTopic(BaseModel): title: str summary: str timestamp: float - segments: list[TranscriptSegmentTopic] = [] + text: str | None = None + words: list[ProcessorWord] = [] class TranscriptFinalShortSummary(BaseModel): @@ -304,6 +307,53 @@ async def transcripts_create( # ============================================================== +class GetTranscriptSegmentTopic(BaseModel): + text: str + start: float + speaker: int + + +class GetTranscriptTopic(BaseModel): + title: str + summary: str + timestamp: float + text: str + segments: list[GetTranscriptSegmentTopic] = [] + + @classmethod + def from_transcript_topic(cls, topic: TranscriptTopic): + if not topic.words: + # In previous version, words were missing + # Just output a segment with speaker 0 + text = topic.text + segments = [ + GetTranscriptSegmentTopic( + text=topic.text, + start=topic.timestamp, + speaker=0, + ) + ] + else: + # New versions include words + transcript = ProcessorTranscript(words=topic.words) + text = transcript.text + segments = [ + GetTranscriptSegmentTopic( + text=segment.text, + start=segment.start, + speaker=segment.speaker, + ) + for segment in transcript.as_segments() + ] + return cls( + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + text=text, + segments=segments, + ) + + @router.get("/transcripts/{transcript_id}", response_model=GetTranscript) async def transcript_get( transcript_id: str, @@ -412,7 +462,10 @@ async def transcript_get_audio_waveform( return transcript.audio_waveform -@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) +@router.get( + "/transcripts/{transcript_id}/topics", + response_model=list[GetTranscriptTopic], +) async def transcript_get_topics( transcript_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], @@ -421,7 +474,11 @@ async def transcript_get_topics( transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - return transcript.topics + + # convert to GetTranscriptTopic + return [ + GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics + ] @router.get("/transcripts/{transcript_id}/events") @@ -498,6 +555,13 @@ async def transcript_events_websocket( async def handle_rtc_event(event: PipelineEvent, args, data): + try: + return await handle_rtc_event_once(event, args, data) + except Exception: + logger.exception("Error handling RTC event") + + +async def handle_rtc_event_once(event: PipelineEvent, args, data): # OFC the current implementation is not good, # but it's just a POC before persistence. It won't query the # transcript from the database for each event. @@ -530,14 +594,8 @@ async def handle_rtc_event(event: PipelineEvent, args, data): title=data.title, summary=data.summary, timestamp=data.timestamp, - segments=[ - TranscriptSegmentTopic( - speaker=segment.speaker, - text=segment.text, - timestamp=segment.start, - ) - for segment in data.transcript.as_segments() - ], + text=data.transcript.text, + words=data.transcript.words, ) resp = transcript.add_event(event=event, data=topic) transcript.upsert_topic(topic) From 01d7add6cc32d941b9e4e4e4a71616ad86e46531 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 20 Oct 2023 16:07:29 +0200 Subject: [PATCH 05/41] www: update openapi and display --- www/app/[domain]/transcripts/topicList.tsx | 29 +++-- .../[domain]/transcripts/webSocketTypes.ts | 1 + www/app/api/.openapi-generator/FILES | 4 +- .../api/models/GetTranscriptSegmentTopic.ts | 88 +++++++++++++++ www/app/api/models/GetTranscriptTopic.ts | 103 ++++++++++++++++++ www/app/api/models/index.ts | 4 +- 6 files changed, 216 insertions(+), 13 deletions(-) create mode 100644 www/app/api/models/GetTranscriptSegmentTopic.ts create mode 100644 www/app/api/models/GetTranscriptTopic.ts diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx index 4fedacb0..d10cb13f 100644 --- a/www/app/[domain]/transcripts/topicList.tsx +++ b/www/app/[domain]/transcripts/topicList.tsx @@ -104,15 +104,26 @@ export function TopicList({ {activeTopic?.id == topic.id && (
- {topic.segments.map((segment, index) => ( -

- [{formatTime(segment.timestamp)}] Speaker{" "} - {segment.speaker}: {segment.text} -

- ))} + {topic.segments ? ( + <> + {topic.segments.map((segment, index: number) => ( +

+ + [{formatTime(segment.start)}] + + +  Speaker {segment.speaker} + + {segment.text} +

+ ))} + + ) : ( + <>{topic.text} + )}
)} diff --git a/www/app/[domain]/transcripts/webSocketTypes.ts b/www/app/[domain]/transcripts/webSocketTypes.ts index 3e02e26a..abc67b33 100644 --- a/www/app/[domain]/transcripts/webSocketTypes.ts +++ b/www/app/[domain]/transcripts/webSocketTypes.ts @@ -9,6 +9,7 @@ export type Topic = { title: string; summary: string; id: string; + text: string; segments: SegmentTopic[]; }; diff --git a/www/app/api/.openapi-generator/FILES b/www/app/api/.openapi-generator/FILES index 9192cb1f..532a6a16 100644 --- a/www/app/api/.openapi-generator/FILES +++ b/www/app/api/.openapi-generator/FILES @@ -5,11 +5,11 @@ models/AudioWaveform.ts models/CreateTranscript.ts models/DeletionStatus.ts models/GetTranscript.ts +models/GetTranscriptSegmentTopic.ts +models/GetTranscriptTopic.ts models/HTTPValidationError.ts models/PageGetTranscript.ts models/RtcOffer.ts -models/TranscriptSegmentTopic.ts -models/TranscriptTopic.ts models/UpdateTranscript.ts models/UserInfo.ts models/ValidationError.ts diff --git a/www/app/api/models/GetTranscriptSegmentTopic.ts b/www/app/api/models/GetTranscriptSegmentTopic.ts new file mode 100644 index 00000000..cc2049b1 --- /dev/null +++ b/www/app/api/models/GetTranscriptSegmentTopic.ts @@ -0,0 +1,88 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * FastAPI + * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + +import { exists, mapValues } from "../runtime"; +/** + * + * @export + * @interface GetTranscriptSegmentTopic + */ +export interface GetTranscriptSegmentTopic { + /** + * + * @type {any} + * @memberof GetTranscriptSegmentTopic + */ + text: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptSegmentTopic + */ + start: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptSegmentTopic + */ + speaker: any | null; +} + +/** + * Check if a given object implements the GetTranscriptSegmentTopic interface. + */ +export function instanceOfGetTranscriptSegmentTopic(value: object): boolean { + let isInstance = true; + isInstance = isInstance && "text" in value; + isInstance = isInstance && "start" in value; + isInstance = isInstance && "speaker" in value; + + return isInstance; +} + +export function GetTranscriptSegmentTopicFromJSON( + json: any, +): GetTranscriptSegmentTopic { + return GetTranscriptSegmentTopicFromJSONTyped(json, false); +} + +export function GetTranscriptSegmentTopicFromJSONTyped( + json: any, + ignoreDiscriminator: boolean, +): GetTranscriptSegmentTopic { + if (json === undefined || json === null) { + return json; + } + return { + text: json["text"], + start: json["start"], + speaker: json["speaker"], + }; +} + +export function GetTranscriptSegmentTopicToJSON( + value?: GetTranscriptSegmentTopic | null, +): any { + if (value === undefined) { + return undefined; + } + if (value === null) { + return null; + } + return { + text: value.text, + start: value.start, + speaker: value.speaker, + }; +} diff --git a/www/app/api/models/GetTranscriptTopic.ts b/www/app/api/models/GetTranscriptTopic.ts new file mode 100644 index 00000000..7a7d4c90 --- /dev/null +++ b/www/app/api/models/GetTranscriptTopic.ts @@ -0,0 +1,103 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * FastAPI + * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + +import { exists, mapValues } from "../runtime"; +/** + * + * @export + * @interface GetTranscriptTopic + */ +export interface GetTranscriptTopic { + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + title: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + summary: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + timestamp: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + text: any | null; + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + segments?: any | null; +} + +/** + * Check if a given object implements the GetTranscriptTopic interface. + */ +export function instanceOfGetTranscriptTopic(value: object): boolean { + let isInstance = true; + isInstance = isInstance && "title" in value; + isInstance = isInstance && "summary" in value; + isInstance = isInstance && "timestamp" in value; + isInstance = isInstance && "text" in value; + + return isInstance; +} + +export function GetTranscriptTopicFromJSON(json: any): GetTranscriptTopic { + return GetTranscriptTopicFromJSONTyped(json, false); +} + +export function GetTranscriptTopicFromJSONTyped( + json: any, + ignoreDiscriminator: boolean, +): GetTranscriptTopic { + if (json === undefined || json === null) { + return json; + } + return { + title: json["title"], + summary: json["summary"], + timestamp: json["timestamp"], + text: json["text"], + segments: !exists(json, "segments") ? undefined : json["segments"], + }; +} + +export function GetTranscriptTopicToJSON( + value?: GetTranscriptTopic | null, +): any { + if (value === undefined) { + return undefined; + } + if (value === null) { + return null; + } + return { + title: value.title, + summary: value.summary, + timestamp: value.timestamp, + text: value.text, + segments: value.segments, + }; +} diff --git a/www/app/api/models/index.ts b/www/app/api/models/index.ts index 6a7d7318..9e691b09 100644 --- a/www/app/api/models/index.ts +++ b/www/app/api/models/index.ts @@ -4,11 +4,11 @@ export * from "./AudioWaveform"; export * from "./CreateTranscript"; export * from "./DeletionStatus"; export * from "./GetTranscript"; +export * from "./GetTranscriptSegmentTopic"; +export * from "./GetTranscriptTopic"; export * from "./HTTPValidationError"; export * from "./PageGetTranscript"; export * from "./RtcOffer"; -export * from "./TranscriptSegmentTopic"; -export * from "./TranscriptTopic"; export * from "./UpdateTranscript"; export * from "./UserInfo"; export * from "./ValidationError"; From f4cffc0e66c4d42aa69644b1456a0b63ca3ff8e7 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 20 Oct 2023 16:14:30 +0200 Subject: [PATCH 06/41] server: add tests on segmentation and fix issue with speaker --- server/reflector/processors/types.py | 20 ++++++++++++++++--- .../test_processor_transcript_segment.py | 19 ++++++++++++++++-- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index ba0cccf9..d2c32d17 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -125,16 +125,30 @@ class Transcript(BaseModel): speaker=word.speaker, ) continue + + # If the word is attach to another speaker, push the current segment + # and start a new one + if word.speaker != current_segment.speaker: + segments.append(current_segment) + current_segment = TranscriptSegment( + text=word.text, + start=word.start, + speaker=word.speaker, + ) + continue + + # if the word is the end of a sentence, and we have enough content, + # add the word to the current segment and push it current_segment.text += word.text have_punc = PUNC_RE.search(word.text) - if word.speaker != current_segment.speaker or ( - have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH) - ): + if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH): segments.append(current_segment) current_segment = None + if current_segment: segments.append(current_segment) + return segments diff --git a/server/tests/test_processor_transcript_segment.py b/server/tests/test_processor_transcript_segment.py index 3bb6182f..6fde0dd1 100644 --- a/server/tests/test_processor_transcript_segment.py +++ b/server/tests/test_processor_transcript_segment.py @@ -142,5 +142,20 @@ def test_processor_transcript_segment(): ] ) - for segment in transcript.as_segments(): - print(segment) + segments = transcript.as_segments() + assert len(segments) == 7 + + # check speaker order + assert segments[0].speaker == 0 + assert segments[1].speaker == 0 + assert segments[2].speaker == 0 + assert segments[3].speaker == 1 + assert segments[4].speaker == 2 + assert segments[5].speaker == 0 + assert segments[6].speaker == 0 + + # check the timing (first entry, and first of others speakers) + assert segments[0].start == 5.12 + assert segments[3].start == 30.72 + assert segments[4].start == 31.56 + assert segments[5].start == 32.38 From 5f00673ebddd91a2105d44871b597de7d941a988 Mon Sep 17 00:00:00 2001 From: Koper Date: Wed, 25 Oct 2023 15:13:52 +0100 Subject: [PATCH 07/41] Implemented speaker color functions --- www/app/[domain]/transcripts/topicList.tsx | 16 +- www/app/[domain]/transcripts/useWebSockets.ts | 167 ++++++++++++++++-- www/app/lib/utils.ts | 122 +++++++++++++ 3 files changed, 286 insertions(+), 19 deletions(-) diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx index d10cb13f..ef4a2889 100644 --- a/www/app/[domain]/transcripts/topicList.tsx +++ b/www/app/[domain]/transcripts/topicList.tsx @@ -7,6 +7,7 @@ import { import { formatTime } from "../../lib/time"; import ScrollToBottom from "./scrollToBottom"; import { Topic } from "./webSocketTypes"; +import { generateHighContrastColor } from "../lib/utils"; type TopicListProps = { topics: Topic[]; @@ -114,9 +115,18 @@ export function TopicList({ [{formatTime(segment.start)}] - -  Speaker {segment.speaker} - + + {" "} + (Speaker {segment.speaker}): + {" "} {segment.text}

))} diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts index 6bd7bf48..8196749e 100644 --- a/www/app/[domain]/transcripts/useWebSockets.ts +++ b/www/app/[domain]/transcripts/useWebSockets.ts @@ -56,38 +56,116 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { timestamp: 10, summary: "This is test topic 1", title: "Topic 1: Introduction to Quantum Mechanics", - transcript: - "A brief overview of quantum mechanics and its principles.", + text: "A brief overview of quantum mechanics and its principles.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + , + { + speaker: 3, + start: 90, + text: "This is the third speaker", + }, + { + speaker: 4, + start: 90, + text: "This is the fourth speaker", + }, + { + speaker: 5, + start: 123, + text: "This is the fifth speaker", + }, + { + speaker: 6, + start: 300, + text: "This is the sixth speaker", + }, + ], }, { id: "2", timestamp: 20, summary: "This is test topic 2", title: "Topic 2: Machine Learning Algorithms", - transcript: - "Understanding the different types of machine learning algorithms.", + text: "Understanding the different types of machine learning algorithms.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "3", timestamp: 30, summary: "This is test topic 3", title: "Topic 3: Mental Health Awareness", - transcript: "Ways to improve mental health and reduce stigma.", + text: "Ways to improve mental health and reduce stigma.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "4", timestamp: 40, summary: "This is test topic 4", title: "Topic 4: Basics of Productivity", - transcript: "Tips and tricks to increase daily productivity.", + text: "Tips and tricks to increase daily productivity.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "5", timestamp: 50, summary: "This is test topic 5", title: "Topic 5: Future of Aviation", - transcript: - "Exploring the advancements and possibilities in aviation.", + text: "Exploring the advancements and possibilities in aviation.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, ]); @@ -104,8 +182,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 1", title: "Topic 1: Introduction to Quantum Mechanics, a brief overview of quantum mechanics and its principles.", - transcript: - "A brief overview of quantum mechanics and its principles.", + text: "A brief overview of quantum mechanics and its principles.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "2", @@ -113,8 +202,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 2", title: "Topic 2: Machine Learning Algorithms, understanding the different types of machine learning algorithms.", - transcript: - "Understanding the different types of machine learning algorithms.", + text: "Understanding the different types of machine learning algorithms.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "3", @@ -122,7 +222,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 3", title: "Topic 3: Mental Health Awareness, ways to improve mental health and reduce stigma.", - transcript: "Ways to improve mental health and reduce stigma.", + text: "Ways to improve mental health and reduce stigma.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "4", @@ -130,7 +242,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 4", title: "Topic 4: Basics of Productivity, tips and tricks to increase daily productivity.", - transcript: "Tips and tricks to increase daily productivity.", + text: "Tips and tricks to increase daily productivity.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, { id: "5", @@ -138,8 +262,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 5", title: "Topic 5: Future of Aviation, exploring the advancements and possibilities in aviation.", - transcript: - "Exploring the advancements and possibilities in aviation.", + text: "Exploring the advancements and possibilities in aviation.", + segments: [ + { + speaker: 1, + start: 0, + text: "This is the transcription of an example title", + }, + { + speaker: 2, + start: 10, + text: "This is the second speaker", + }, + ], }, ]); diff --git a/www/app/lib/utils.ts b/www/app/lib/utils.ts index db775f07..43ed8bf8 100644 --- a/www/app/lib/utils.ts +++ b/www/app/lib/utils.ts @@ -1,3 +1,125 @@ export function isDevelopment() { return process.env.NEXT_PUBLIC_ENV === "development"; } + +// Function to calculate WCAG contrast ratio +export const getContrastRatio = ( + foreground: [number, number, number], + background: [number, number, number], +) => { + const [r1, g1, b1] = foreground; + const [r2, g2, b2] = background; + + const lum1 = + 0.2126 * Math.pow(r1 / 255, 2.2) + + 0.7152 * Math.pow(g1 / 255, 2.2) + + 0.0722 * Math.pow(b1 / 255, 2.2); + const lum2 = + 0.2126 * Math.pow(r2 / 255, 2.2) + + 0.7152 * Math.pow(g2 / 255, 2.2) + + 0.0722 * Math.pow(b2 / 255, 2.2); + + return (Math.max(lum1, lum2) + 0.05) / (Math.min(lum1, lum2) + 0.05); +}; + +// Function to hash string into 32-bit integer +// 🔴 DO NOT USE FOR CRYPTOGRAPHY PURPOSES 🔴 + +export function murmurhash3_32_gc(key: string, seed: number = 0) { + let remainder, bytes, h1, h1b, c1, c2, k1, i; + + remainder = key.length & 3; // key.length % 4 + bytes = key.length - remainder; + h1 = seed; + c1 = 0xcc9e2d51; + c2 = 0x1b873593; + i = 0; + + while (i < bytes) { + k1 = + (key.charCodeAt(i) & 0xff) | + ((key.charCodeAt(++i) & 0xff) << 8) | + ((key.charCodeAt(++i) & 0xff) << 16) | + ((key.charCodeAt(++i) & 0xff) << 24); + + ++i; + + k1 = + ((k1 & 0xffff) * c1 + ((((k1 >>> 16) * c1) & 0xffff) << 16)) & 0xffffffff; + k1 = (k1 << 15) | (k1 >>> 17); + k1 = + ((k1 & 0xffff) * c2 + ((((k1 >>> 16) * c2) & 0xffff) << 16)) & 0xffffffff; + + h1 ^= k1; + h1 = (h1 << 13) | (h1 >>> 19); + h1b = + ((h1 & 0xffff) * 5 + ((((h1 >>> 16) * 5) & 0xffff) << 16)) & 0xffffffff; + h1 = (h1b & 0xffff) + 0x6b64 + ((((h1b >>> 16) + 0xe654) & 0xffff) << 16); + } + + k1 = 0; + + switch (remainder) { + case 3: + k1 ^= (key.charCodeAt(i + 2) & 0xff) << 16; + case 2: + k1 ^= (key.charCodeAt(i + 1) & 0xff) << 8; + case 1: + k1 ^= key.charCodeAt(i) & 0xff; + + k1 = + ((k1 & 0xffff) * c1 + ((((k1 >>> 16) * c1) & 0xffff) << 16)) & + 0xffffffff; + k1 = (k1 << 15) | (k1 >>> 17); + k1 = + ((k1 & 0xffff) * c2 + ((((k1 >>> 16) * c2) & 0xffff) << 16)) & + 0xffffffff; + h1 ^= k1; + } + + h1 ^= key.length; + + h1 ^= h1 >>> 16; + h1 = + ((h1 & 0xffff) * 0x85ebca6b + + ((((h1 >>> 16) * 0x85ebca6b) & 0xffff) << 16)) & + 0xffffffff; + h1 ^= h1 >>> 13; + h1 = + ((h1 & 0xffff) * 0xc2b2ae35 + + ((((h1 >>> 16) * 0xc2b2ae35) & 0xffff) << 16)) & + 0xffffffff; + h1 ^= h1 >>> 16; + + return h1 >>> 0; +} + +// Generates a color that is guaranteed to have high contrast with the given background color (optional) + +export const generateHighContrastColor = ( + name: string, + backgroundColor: [number, number, number] | null = null, +) => { + const hash = murmurhash3_32_gc(name); + console.log(name, hash); + + let red = (hash & 0xff0000) >> 16; + let green = (hash & 0x00ff00) >> 8; + let blue = hash & 0x0000ff; + + const getCssColor = (red: number, green: number, blue: number) => + `rgb(${red}, ${green}, ${blue})`; + + if (!backgroundColor) return getCssColor(red, green, blue); + + const contrast = getContrastRatio([red, green, blue], backgroundColor); + + // Adjust the color to achieve better contrast if necessary (WCAG recommends at least 4.5:1 for text) + if (contrast < 4.5) { + red = Math.abs(255 - red); + green = Math.abs(255 - green); + blue = Math.abs(255 - blue); + } + + return getCssColor(red, green, blue); +}; From 16a857927286cf82c9c766d4cec3d5ef9d4fdd5a Mon Sep 17 00:00:00 2001 From: Koper Date: Wed, 25 Oct 2023 15:17:34 +0100 Subject: [PATCH 08/41] Removed console.log --- www/app/lib/utils.ts | 2 -- 1 file changed, 2 deletions(-) diff --git a/www/app/lib/utils.ts b/www/app/lib/utils.ts index 43ed8bf8..6b72ddea 100644 --- a/www/app/lib/utils.ts +++ b/www/app/lib/utils.ts @@ -101,8 +101,6 @@ export const generateHighContrastColor = ( backgroundColor: [number, number, number] | null = null, ) => { const hash = murmurhash3_32_gc(name); - console.log(name, hash); - let red = (hash & 0xff0000) >> 16; let green = (hash & 0x00ff00) >> 8; let blue = hash & 0x0000ff; From 8bebb2a76907bb9095bd2d3d244e7f92afd8e38d Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 20 Oct 2023 18:00:59 +0200 Subject: [PATCH 09/41] server: start moving to an external celery task --- server/poetry.lock | 210 +++++++++++++++++++++- server/pyproject.toml | 1 + server/reflector/settings.py | 4 + server/reflector/tasks/boot.py | 2 + server/reflector/tasks/post_transcript.py | 170 ++++++++++++++++++ server/reflector/tasks/worker.py | 6 + 6 files changed, 392 insertions(+), 1 deletion(-) create mode 100644 server/reflector/tasks/boot.py create mode 100644 server/reflector/tasks/post_transcript.py create mode 100644 server/reflector/tasks/worker.py diff --git a/server/poetry.lock b/server/poetry.lock index 330c23e3..0df46097 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -308,6 +308,20 @@ typing-extensions = ">=4" [package.extras] tz = ["python-dateutil"] +[[package]] +name = "amqp" +version = "5.1.1" +description = "Low-level AMQP client for Python (fork of amqplib)." +optional = false +python-versions = ">=3.6" +files = [ + {file = "amqp-5.1.1-py3-none-any.whl", hash = "sha256:6f0956d2c23d8fa6e7691934d8c3930eadb44972cbbd1a7ae3a520f735d43359"}, + {file = "amqp-5.1.1.tar.gz", hash = "sha256:2c1b13fecc0893e946c65cbd5f36427861cffa4ea2201d8f6fca22e2a373b5e2"}, +] + +[package.dependencies] +vine = ">=5.0.0" + [[package]] name = "annotated-types" version = "0.6.0" @@ -474,6 +488,17 @@ files = [ {file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"}, ] +[[package]] +name = "billiard" +version = "4.1.0" +description = "Python multiprocessing fork with improvements and bugfixes" +optional = false +python-versions = ">=3.7" +files = [ + {file = "billiard-4.1.0-py3-none-any.whl", hash = "sha256:0f50d6be051c6b2b75bfbc8bfd85af195c5739c281d3f5b86a5640c65563614a"}, + {file = "billiard-4.1.0.tar.gz", hash = "sha256:1ad2eeae8e28053d729ba3373d34d9d6e210f6e4d8bf0a9c64f92bd053f1edf5"}, +] + [[package]] name = "black" version = "23.9.1" @@ -556,6 +581,61 @@ urllib3 = ">=1.25.4,<1.27" [package.extras] crt = ["awscrt (==0.16.26)"] +[[package]] +name = "celery" +version = "5.3.4" +description = "Distributed Task Queue." +optional = false +python-versions = ">=3.8" +files = [ + {file = "celery-5.3.4-py3-none-any.whl", hash = "sha256:1e6ed40af72695464ce98ca2c201ad0ef8fd192246f6c9eac8bba343b980ad34"}, + {file = "celery-5.3.4.tar.gz", hash = "sha256:9023df6a8962da79eb30c0c84d5f4863d9793a466354cc931d7f72423996de28"}, +] + +[package.dependencies] +billiard = ">=4.1.0,<5.0" +click = ">=8.1.2,<9.0" +click-didyoumean = ">=0.3.0" +click-plugins = ">=1.1.1" +click-repl = ">=0.2.0" +kombu = ">=5.3.2,<6.0" +python-dateutil = ">=2.8.2" +tzdata = ">=2022.7" +vine = ">=5.0.0,<6.0" + +[package.extras] +arangodb = ["pyArango (>=2.0.2)"] +auth = ["cryptography (==41.0.3)"] +azureblockblob = ["azure-storage-blob (>=12.15.0)"] +brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"] +cassandra = ["cassandra-driver (>=3.25.0,<4)"] +consul = ["python-consul2 (==0.1.5)"] +cosmosdbsql = ["pydocumentdb (==2.3.5)"] +couchbase = ["couchbase (>=3.0.0)"] +couchdb = ["pycouchdb (==1.14.2)"] +django = ["Django (>=2.2.28)"] +dynamodb = ["boto3 (>=1.26.143)"] +elasticsearch = ["elasticsearch (<8.0)"] +eventlet = ["eventlet (>=0.32.0)"] +gevent = ["gevent (>=1.5.0)"] +librabbitmq = ["librabbitmq (>=2.0.0)"] +memcache = ["pylibmc (==1.6.3)"] +mongodb = ["pymongo[srv] (>=4.0.2)"] +msgpack = ["msgpack (==1.0.5)"] +pymemcache = ["python-memcached (==1.59)"] +pyro = ["pyro4 (==4.82)"] +pytest = ["pytest-celery (==0.0.0)"] +redis = ["redis (>=4.5.2,!=4.5.5,<5.0.0)"] +s3 = ["boto3 (>=1.26.143)"] +slmq = ["softlayer-messaging (>=1.0.3)"] +solar = ["ephem (==4.1.4)"] +sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] +sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.0)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] +tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"] +yaml = ["PyYAML (>=3.10)"] +zookeeper = ["kazoo (>=1.3.1)"] +zstd = ["zstandard (==0.21.0)"] + [[package]] name = "certifi" version = "2023.7.22" @@ -744,6 +824,55 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "click-didyoumean" +version = "0.3.0" +description = "Enables git-like *did-you-mean* feature in click" +optional = false +python-versions = ">=3.6.2,<4.0.0" +files = [ + {file = "click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"}, + {file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"}, +] + +[package.dependencies] +click = ">=7" + +[[package]] +name = "click-plugins" +version = "1.1.1" +description = "An extension module for click to enable registering CLI commands via setuptools entry-points." +optional = false +python-versions = "*" +files = [ + {file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"}, + {file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"}, +] + +[package.dependencies] +click = ">=4.0" + +[package.extras] +dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"] + +[[package]] +name = "click-repl" +version = "0.3.0" +description = "REPL plugin for Click" +optional = false +python-versions = ">=3.6" +files = [ + {file = "click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9"}, + {file = "click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812"}, +] + +[package.dependencies] +click = ">=7.0" +prompt-toolkit = ">=3.0.36" + +[package.extras] +testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"] + [[package]] name = "colorama" version = "0.4.6" @@ -1624,6 +1753,38 @@ files = [ cryptography = ">=3.4" deprecated = "*" +[[package]] +name = "kombu" +version = "5.3.2" +description = "Messaging library for Python." +optional = false +python-versions = ">=3.8" +files = [ + {file = "kombu-5.3.2-py3-none-any.whl", hash = "sha256:b753c9cfc9b1e976e637a7cbc1a65d446a22e45546cd996ea28f932082b7dc9e"}, + {file = "kombu-5.3.2.tar.gz", hash = "sha256:0ba213f630a2cb2772728aef56ac6883dc3a2f13435e10048f6e97d48506dbbd"}, +] + +[package.dependencies] +amqp = ">=5.1.1,<6.0.0" +vine = "*" + +[package.extras] +azureservicebus = ["azure-servicebus (>=7.10.0)"] +azurestoragequeues = ["azure-identity (>=1.12.0)", "azure-storage-queue (>=12.6.0)"] +confluentkafka = ["confluent-kafka (==2.1.1)"] +consul = ["python-consul2"] +librabbitmq = ["librabbitmq (>=2.0.0)"] +mongodb = ["pymongo (>=4.1.1)"] +msgpack = ["msgpack"] +pyro = ["pyro4"] +qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"] +redis = ["redis (>=4.5.2)"] +slmq = ["softlayer-messaging (>=1.0.3)"] +sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] +sqs = ["boto3 (>=1.26.143)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] +yaml = ["PyYAML (>=3.10)"] +zookeeper = ["kazoo (>=2.8.0)"] + [[package]] name = "levenshtein" version = "0.21.1" @@ -2151,6 +2312,20 @@ files = [ fastapi = ">=0.38.1,<1.0.0" prometheus-client = ">=0.8.0,<1.0.0" +[[package]] +name = "prompt-toolkit" +version = "3.0.39" +description = "Library for building powerful interactive command lines in Python" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"}, + {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"}, +] + +[package.dependencies] +wcwidth = "*" + [[package]] name = "protobuf" version = "4.24.4" @@ -3438,6 +3613,17 @@ files = [ {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, ] +[[package]] +name = "tzdata" +version = "2023.3" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"}, + {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"}, +] + [[package]] name = "urllib3" version = "1.26.17" @@ -3523,6 +3709,17 @@ dev = ["Cython (>=0.29.32,<0.30.0)", "Sphinx (>=4.1.2,<4.2.0)", "aiohttp", "flak docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] test = ["Cython (>=0.29.32,<0.30.0)", "aiohttp", "flake8 (>=3.9.2,<3.10.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=22.0.0,<22.1.0)", "pycodestyle (>=2.7.0,<2.8.0)"] +[[package]] +name = "vine" +version = "5.0.0" +description = "Promises, promises, promises." +optional = false +python-versions = ">=3.6" +files = [ + {file = "vine-5.0.0-py2.py3-none-any.whl", hash = "sha256:4c9dceab6f76ed92105027c49c823800dd33cacce13bdedc5b914e3514b7fb30"}, + {file = "vine-5.0.0.tar.gz", hash = "sha256:7d3b1624a953da82ef63462013bbd271d3eb75751489f9807598e8f340bd637e"}, +] + [[package]] name = "watchfiles" version = "0.20.0" @@ -3557,6 +3754,17 @@ files = [ [package.dependencies] anyio = ">=3.0.0" +[[package]] +name = "wcwidth" +version = "0.2.8" +description = "Measures the displayed width of unicode strings in a terminal" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.8-py2.py3-none-any.whl", hash = "sha256:77f719e01648ed600dfa5402c347481c0992263b81a027344f3e1ba25493a704"}, + {file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"}, +] + [[package]] name = "websockets" version = "11.0.3" @@ -3838,4 +4046,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "61578467a70980ff9c2dc0cd787b6410b91d7c5fd2bb4c46b6951ec82690ef67" +content-hash = "fda9f13784a64add559abb2266d60eeef8f28d2b5f369633630f4fed14daa99c" diff --git a/server/pyproject.toml b/server/pyproject.toml index e3b44774..7b1b7936 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -33,6 +33,7 @@ prometheus-fastapi-instrumentator = "^6.1.0" sentencepiece = "^0.1.99" protobuf = "^4.24.3" profanityfilter = "^2.0.6" +celery = "^5.3.4" [tool.poetry.group.dev.dependencies] diff --git a/server/reflector/settings.py b/server/reflector/settings.py index e0ffd826..1503948a 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -113,5 +113,9 @@ class Settings(BaseSettings): # Min transcript length to generate topic + summary MIN_TRANSCRIPT_LENGTH: int = 750 + # Celery + CELERY_BROKER_URL: str = "redis://localhost:6379/1" + CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" + settings = Settings() diff --git a/server/reflector/tasks/boot.py b/server/reflector/tasks/boot.py new file mode 100644 index 00000000..88cc2d6f --- /dev/null +++ b/server/reflector/tasks/boot.py @@ -0,0 +1,2 @@ +import reflector.tasks.post_transcript # noqa +import reflector.tasks.worker # noqa diff --git a/server/reflector/tasks/post_transcript.py b/server/reflector/tasks/post_transcript.py new file mode 100644 index 00000000..1cfbc664 --- /dev/null +++ b/server/reflector/tasks/post_transcript.py @@ -0,0 +1,170 @@ +from reflector.logger import logger +from reflector.processors import ( + Pipeline, + Processor, + TranscriptFinalLongSummaryProcessor, + TranscriptFinalShortSummaryProcessor, + TranscriptFinalTitleProcessor, +) +from reflector.processors.base import BroadcastProcessor +from reflector.processors.types import ( + FinalLongSummary, + FinalShortSummary, + FinalTitle, + TitleSummary, +) +from reflector.processors.types import Transcript as ProcessorTranscript +from reflector.tasks.worker import celery +from reflector.views.rtc_offer import PipelineEvent, TranscriptionContext +from reflector.views.transcripts import Transcript, transcripts_controller + + +class TranscriptAudioDiarizationProcessor(Processor): + INPUT_TYPE = Transcript + OUTPUT_TYPE = TitleSummary + + async def _push(self, data: Transcript): + # Gather diarization data + diarization = [ + {"start": 0.0, "stop": 4.9, "speaker": 2}, + {"start": 5.6, "stop": 6.7, "speaker": 2}, + {"start": 7.3, "stop": 8.9, "speaker": 2}, + {"start": 7.3, "stop": 7.9, "speaker": 0}, + {"start": 9.4, "stop": 11.2, "speaker": 2}, + {"start": 9.7, "stop": 10.0, "speaker": 0}, + {"start": 10.0, "stop": 10.1, "speaker": 0}, + {"start": 11.7, "stop": 16.1, "speaker": 2}, + {"start": 11.8, "stop": 12.1, "speaker": 1}, + {"start": 16.4, "stop": 21.0, "speaker": 2}, + {"start": 21.1, "stop": 22.6, "speaker": 2}, + {"start": 24.7, "stop": 31.9, "speaker": 2}, + {"start": 32.0, "stop": 32.8, "speaker": 1}, + {"start": 33.4, "stop": 37.8, "speaker": 2}, + {"start": 37.9, "stop": 40.3, "speaker": 0}, + {"start": 39.2, "stop": 40.4, "speaker": 2}, + {"start": 40.7, "stop": 41.4, "speaker": 0}, + {"start": 41.6, "stop": 45.7, "speaker": 2}, + {"start": 46.4, "stop": 53.1, "speaker": 2}, + {"start": 53.6, "stop": 56.5, "speaker": 2}, + {"start": 54.9, "stop": 75.4, "speaker": 1}, + {"start": 57.3, "stop": 58.0, "speaker": 2}, + {"start": 65.7, "stop": 66.0, "speaker": 2}, + {"start": 75.8, "stop": 78.8, "speaker": 1}, + {"start": 79.0, "stop": 82.6, "speaker": 1}, + {"start": 83.2, "stop": 83.3, "speaker": 1}, + {"start": 84.5, "stop": 94.3, "speaker": 1}, + {"start": 95.1, "stop": 100.7, "speaker": 1}, + {"start": 100.7, "stop": 102.0, "speaker": 0}, + {"start": 100.7, "stop": 101.8, "speaker": 1}, + {"start": 102.0, "stop": 103.0, "speaker": 1}, + {"start": 103.0, "stop": 103.7, "speaker": 0}, + {"start": 103.7, "stop": 103.8, "speaker": 1}, + {"start": 103.8, "stop": 113.9, "speaker": 0}, + {"start": 114.7, "stop": 117.0, "speaker": 0}, + {"start": 117.0, "stop": 117.4, "speaker": 1}, + ] + + # now reapply speaker to topics (if any) + # topics is a list[BaseModel] with an attribute words + # words is a list[BaseModel] with text, start and speaker attribute + + # mutate in place + for topic in data.topics: + for word in topic.words: + for d in diarization: + if d["start"] <= word.start <= d["stop"]: + word.speaker = d["speaker"] + + topics = data.topics[:] + + await transcripts_controller.update( + data, + { + "topics": [topic.model_dump(mode="json") for topic in data.topics], + }, + ) + + # emit them + for topic in topics: + transcript = ProcessorTranscript(words=topic.words) + await self.emit( + TitleSummary( + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + duration=0, + transcript=transcript, + ) + ) + + +@celery.task(name="post_transcript") +async def post_transcript_pipeline(transcript_id: str): + # get transcript + transcript = await transcripts_controller.get_by_id(transcript_id) + if not transcript: + logger.error("Transcript not found", transcript_id=transcript_id) + return + + ctx = TranscriptionContext(logger=logger.bind(transcript_id=transcript_id)) + event_callback = None + event_callback_args = None + + async def on_final_short_summary(summary: FinalShortSummary): + ctx.logger.info("FinalShortSummary", final_short_summary=summary) + + # send to callback (eg. websocket) + if event_callback: + await event_callback( + event=PipelineEvent.FINAL_SHORT_SUMMARY, + args=event_callback_args, + data=summary, + ) + + async def on_final_long_summary(summary: FinalLongSummary): + ctx.logger.info("FinalLongSummary", final_summary=summary) + + # send to callback (eg. websocket) + if event_callback: + await event_callback( + event=PipelineEvent.FINAL_LONG_SUMMARY, + args=event_callback_args, + data=summary, + ) + + async def on_final_title(title: FinalTitle): + ctx.logger.info("FinalTitle", final_title=title) + + # send to callback (eg. websocket) + if event_callback: + await event_callback( + event=PipelineEvent.FINAL_TITLE, + args=event_callback_args, + data=title, + ) + + ctx.logger.info("Starting pipeline (diarization)") + ctx.pipeline = Pipeline( + TranscriptAudioDiarizationProcessor(), + BroadcastProcessor( + processors=[ + TranscriptFinalTitleProcessor.as_threaded(), + TranscriptFinalLongSummaryProcessor.as_threaded(), + TranscriptFinalShortSummaryProcessor.as_threaded(), + ] + ), + ) + + await ctx.pipeline.push(transcript) + await ctx.pipeline.flush() + + +if __name__ == "__main__": + import argparse + import asyncio + + parser = argparse.ArgumentParser() + parser.add_argument("transcript_id", type=str) + args = parser.parse_args() + + asyncio.run(post_transcript_pipeline(args.transcript_id)) diff --git a/server/reflector/tasks/worker.py b/server/reflector/tasks/worker.py new file mode 100644 index 00000000..4379a1b7 --- /dev/null +++ b/server/reflector/tasks/worker.py @@ -0,0 +1,6 @@ +from celery import Celery +from reflector.settings import settings + +celery = Celery(__name__) +celery.conf.broker_url = settings.CELERY_BROKER_URL +celery.conf.result_backend = settings.CELERY_RESULT_BACKEND From 00c06b7971e935f4c681a4e737c9815f6aff9133 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 25 Oct 2023 16:56:43 +0200 Subject: [PATCH 10/41] server: use redis pubsub for interprocess websocket communication --- server/docker-compose.yml | 22 +++-- server/poetry.lock | 20 +++- server/pyproject.toml | 1 + server/reflector/settings.py | 4 + server/reflector/views/transcripts.py | 61 ++++--------- server/reflector/ws_manager.py | 127 ++++++++++++++++++++++++++ 6 files changed, 183 insertions(+), 52 deletions(-) create mode 100644 server/reflector/ws_manager.py diff --git a/server/docker-compose.yml b/server/docker-compose.yml index 374130fa..4e5a21e8 100644 --- a/server/docker-compose.yml +++ b/server/docker-compose.yml @@ -1,15 +1,19 @@ version: "3.9" services: - server: - build: - context: . + # server: + # build: + # context: . + # ports: + # - 1250:1250 + # environment: + # LLM_URL: "${LLM_URL}" + # MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}" + # volumes: + # - model-cache:/root/.cache + redis: + image: redis:7.2 ports: - - 1250:1250 - environment: - LLM_URL: "${LLM_URL}" - MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}" - volumes: - - model-cache:/root/.cache + - 6379:6379 volumes: model-cache: diff --git a/server/poetry.lock b/server/poetry.lock index 0df46097..35d98382 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -2919,6 +2919,24 @@ files = [ [package.extras] full = ["numpy"] +[[package]] +name = "redis" +version = "5.0.1" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"}, + {file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""} + +[package.extras] +hiredis = ["hiredis (>=1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "regex" version = "2023.10.3" @@ -4046,4 +4064,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "fda9f13784a64add559abb2266d60eeef8f28d2b5f369633630f4fed14daa99c" +content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63" diff --git a/server/pyproject.toml b/server/pyproject.toml index 7b1b7936..ed231a4f 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -34,6 +34,7 @@ sentencepiece = "^0.1.99" protobuf = "^4.24.3" profanityfilter = "^2.0.6" celery = "^5.3.4" +redis = "^5.0.1" [tool.poetry.group.dev.dependencies] diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 1503948a..d7cc2c33 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -117,5 +117,9 @@ class Settings(BaseSettings): CELERY_BROKER_URL: str = "redis://localhost:6379/1" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" + # Redis + REDIS_HOST: str = "localhost" + REDIS_PORT: int = 6379 + settings = Settings() diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 1d9fd4bd..9480461f 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -21,12 +21,14 @@ from reflector.processors.types import Transcript as ProcessorTranscript from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings from reflector.utils.audio_waveform import get_audio_waveform +from reflector.ws_manager import get_ws_manager from starlette.concurrency import run_in_threadpool from ._range_requests_response import range_requests_response from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base router = APIRouter() +ws_manager = get_ws_manager() # ============================================================== # Models to move to a database, but required for the API to work @@ -487,40 +489,10 @@ async def transcript_get_websocket_events(transcript_id: str): # ============================================================== -# Websocket Manager +# Websocket # ============================================================== -class WebsocketManager: - def __init__(self): - self.active_connections = {} - - async def connect(self, transcript_id: str, websocket: WebSocket): - await websocket.accept() - if transcript_id not in self.active_connections: - self.active_connections[transcript_id] = [] - self.active_connections[transcript_id].append(websocket) - - def disconnect(self, transcript_id: str, websocket: WebSocket): - if transcript_id not in self.active_connections: - return - self.active_connections[transcript_id].remove(websocket) - if not self.active_connections[transcript_id]: - del self.active_connections[transcript_id] - - async def send_json(self, transcript_id: str, message): - if transcript_id not in self.active_connections: - return - for connection in self.active_connections[transcript_id][:]: - try: - await connection.send_json(message) - except Exception: - self.active_connections[transcript_id].remove(connection) - - -ws_manager = WebsocketManager() - - @router.websocket("/transcripts/{transcript_id}/events") async def transcript_events_websocket( transcript_id: str, @@ -532,21 +504,25 @@ async def transcript_events_websocket( if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - await ws_manager.connect(transcript_id, websocket) + # connect to websocket manager + # use ts:transcript_id as room id + room_id = f"ts:{transcript_id}" + await ws_manager.add_user_to_room(room_id, websocket) - # on first connection, send all events - for event in transcript.events: - await websocket.send_json(event.model_dump(mode="json")) - - # XXX if transcript is final (locked=True and status=ended) - # XXX send a final event to the client and close the connection - - # endless loop to wait for new events try: + # on first connection, send all events only to the current user + for event in transcript.events: + await websocket.send_json(event.model_dump(mode="json")) + + # XXX if transcript is final (locked=True and status=ended) + # XXX send a final event to the client and close the connection + + # endless loop to wait for new events + # we do not have command system now, while True: await websocket.receive() except (RuntimeError, WebSocketDisconnect): - ws_manager.disconnect(transcript_id, websocket) + await ws_manager.remove_user_from_room(room_id, websocket) # ============================================================== @@ -658,7 +634,8 @@ async def handle_rtc_event_once(event: PipelineEvent, args, data): return # transmit to websocket clients - await ws_manager.send_json(transcript_id, resp.model_dump(mode="json")) + room_id = f"ts:{transcript_id}" + await ws_manager.send_json(room_id, resp.model_dump(mode="json")) @router.post("/transcripts/{transcript_id}/record/webrtc") diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py new file mode 100644 index 00000000..43475c1d --- /dev/null +++ b/server/reflector/ws_manager.py @@ -0,0 +1,127 @@ +""" +Websocket manager +================= + +This module contains the WebsocketManager class, which is responsible for +managing websockets and handling websocket connections. + +It uses the RedisPubSubManager class to subscribe to Redis channels and +broadcast messages to all connected websockets. +""" + +import asyncio +import json + +import redis.asyncio as redis +from fastapi import WebSocket + +ws_manager = None + + +class RedisPubSubManager: + def __init__(self, host="localhost", port=6379): + self.redis_host = host + self.redis_port = port + self.redis_connection = None + self.pubsub = None + + async def get_redis_connection(self) -> redis.Redis: + return redis.Redis( + host=self.redis_host, + port=self.redis_port, + auto_close_connection_pool=False, + ) + + async def connect(self) -> None: + self.redis_connection = await self.get_redis_connection() + self.pubsub = self.redis_connection.pubsub() + + async def disconnect(self) -> None: + if self.redis_connection is None: + return + await self.redis_connection.close() + self.redis_connection = None + + async def send_json(self, room_id: str, message: str) -> None: + message = json.dumps(message) + await self.redis_connection.publish(room_id, message) + + async def subscribe(self, room_id: str) -> redis.Redis: + await self.pubsub.subscribe(room_id) + return self.pubsub + + async def unsubscribe(self, room_id: str) -> None: + await self.pubsub.unsubscribe(room_id) + + +class WebsocketManager: + def __init__(self, pubsub_client: RedisPubSubManager = None): + self.rooms: dict = {} + self.pubsub_client = pubsub_client + + async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None: + await websocket.accept() + + if room_id in self.rooms: + self.rooms[room_id].append(websocket) + else: + self.rooms[room_id] = [websocket] + + await self.pubsub_client.connect() + pubsub_subscriber = await self.pubsub_client.subscribe(room_id) + asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber)) + + async def send_json(self, room_id: str, message: dict) -> None: + await self.pubsub_client.send_json(room_id, message) + + async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None: + self.rooms[room_id].remove(websocket) + + if len(self.rooms[room_id]) == 0: + del self.rooms[room_id] + await self.pubsub_client.unsubscribe(room_id) + + async def _pubsub_data_reader(self, pubsub_subscriber): + while True: + message = await pubsub_subscriber.get_message( + ignore_subscribe_messages=True + ) + if message is not None: + room_id = message["channel"].decode("utf-8") + all_sockets = self.rooms[room_id] + for socket in all_sockets: + data = json.loads(message["data"].decode("utf-8")) + await socket.send_json(data) + + +def get_pubsub_client() -> RedisPubSubManager: + """ + Returns the RedisPubSubManager instance for managing Redis pubsub. + """ + from reflector.settings import settings + + return RedisPubSubManager( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + ) + + +def get_ws_manager() -> WebsocketManager: + """ + Returns the WebsocketManager instance for managing websockets. + + This function initializes and returns the WebsocketManager instance, + which is responsible for managing websockets and handling websocket + connections. + + Returns: + WebsocketManager: The initialized WebsocketManager instance. + + Raises: + ImportError: If the 'reflector.settings' module cannot be imported. + RedisConnectionError: If there is an error connecting to the Redis server. + """ + global ws_manager + pubsub_client = get_pubsub_client() + ws_manager = WebsocketManager(pubsub_client=pubsub_client) + return ws_manager From 367912869d2a0834dbb4d4ccc8d9793dbf119d2f Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 25 Oct 2023 19:49:15 +0200 Subject: [PATCH 11/41] server: make processors in broadcast to be executed in parallel --- server/reflector/processors/base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py index 6771e11e..46bfb4a5 100644 --- a/server/reflector/processors/base.py +++ b/server/reflector/processors/base.py @@ -290,12 +290,12 @@ class BroadcastProcessor(Processor): processor.set_pipeline(pipeline) async def _push(self, data): - for processor in self.processors: - await processor.push(data) + coros = [processor.push(data) for processor in self.processors] + await asyncio.gather(*coros) async def _flush(self): - for processor in self.processors: - await processor.flush() + coros = [processor.flush() for processor in self.processors] + await asyncio.gather(*coros) def connect(self, processor: Processor): for processor in self.processors: @@ -333,6 +333,7 @@ class Pipeline(Processor): self.logger.info("Pipeline created") self.processors = processors + self.options = None self.prefs = {} for processor in processors: From a45b30ee70d080fead2d6ae82ee1251c729a35bf Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 25 Oct 2023 19:50:08 +0200 Subject: [PATCH 12/41] www: ensure login waited before recording if you refresh the record page, it does not work and return 404 because the transcript is accessed without token --- .../[domain]/transcripts/[transcriptId]/page.tsx | 2 +- .../transcripts/[transcriptId]/record/page.tsx | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index d4f40428..3e30b97f 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -35,7 +35,7 @@ export default function TranscriptDetails(details: TranscriptDetails) { useEffect(() => { if (requireLogin && !isAuthenticated) return; setTranscriptId(details.params.transcriptId); - }, [api]); + }, [api, details.params.transcriptId, isAuthenticated]); if (transcript?.error /** || topics?.error || waveform?.error **/) { return ( diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx index 8e31327c..51a318a4 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx @@ -14,6 +14,8 @@ import DisconnectedIndicator from "../../disconnectedIndicator"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import { faGear } from "@fortawesome/free-solid-svg-icons"; import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock"; +import { featRequireLogin } from "../../../../../app/lib/utils"; +import { useFiefIsAuthenticated } from "@fief/fief/nextjs/react"; type TranscriptDetails = { params: { @@ -36,16 +38,23 @@ const TranscriptRecord = (details: TranscriptDetails) => { } }, []); + const isAuthenticated = useFiefIsAuthenticated(); const api = getApi(); - const transcript = useTranscript(api, details.params.transcriptId); - const webRTC = useWebRTC(stream, details.params.transcriptId, api); - const webSockets = useWebSockets(details.params.transcriptId); + const [transcriptId, setTranscriptId] = useState(""); + const transcript = useTranscript(api, transcriptId); + const webRTC = useWebRTC(stream, transcriptId, api); + const webSockets = useWebSockets(transcriptId); const { audioDevices, getAudioStream } = useAudioDevice(); const [hasRecorded, setHasRecorded] = useState(false); const [transcriptStarted, setTranscriptStarted] = useState(false); + useEffect(() => { + if (featRequireLogin() && !isAuthenticated) return; + setTranscriptId(details.params.transcriptId); + }, [api, details.params.transcriptId, isAuthenticated]); + useEffect(() => { if (!transcriptStarted && webSockets.transcriptText.length !== 0) setTranscriptStarted(true); From 433c0500ccc137c56ec11284df46a679699c5982 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 25 Oct 2023 19:50:27 +0200 Subject: [PATCH 13/41] server: refactor to separate websocket management + start pipeline runner --- server/reflector/views/rtc_offer.py | 160 ++++++++++++++++++++------ server/reflector/views/transcripts.py | 17 ++- server/reflector/ws_manager.py | 23 ++-- 3 files changed, 148 insertions(+), 52 deletions(-) diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 5662d989..48d804cc 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -2,6 +2,7 @@ import asyncio from enum import StrEnum from json import dumps, loads from pathlib import Path +from typing import Callable import av from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription @@ -38,7 +39,7 @@ m_rtc_sessions = Gauge("rtc_sessions", "Number of active RTC sessions") class TranscriptionContext(object): def __init__(self, logger): self.logger = logger - self.pipeline = None + self.pipeline_runner = None self.data_channel = None self.status = "idle" self.topics = [] @@ -60,7 +61,7 @@ class AudioStreamTrack(MediaStreamTrack): ctx = self.ctx frame = await self.track.recv() try: - await ctx.pipeline.push(frame) + await ctx.pipeline_runner.push(frame) except Exception as e: ctx.logger.error("Pipeline error", error=e) return frame @@ -84,6 +85,113 @@ class PipelineEvent(StrEnum): FINAL_TITLE = "FINAL_TITLE" +class PipelineOptions(BaseModel): + audio_filename: Path | None = None + source_language: str = "en" + target_language: str = "en" + + on_transcript: Callable | None = None + on_topic: Callable | None = None + on_final_title: Callable | None = None + on_final_short_summary: Callable | None = None + on_final_long_summary: Callable | None = None + + +class PipelineRunner(object): + """ + Pipeline runner designed to be executed in a asyncio task + """ + + def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None): + self.pipeline = pipeline + self.q_cmd = asyncio.Queue() + self.ev_done = asyncio.Event() + self.status = "idle" + self.status_callback = status_callback + + async def update_status(self, status): + print("update_status", status) + self.status = status + if self.status_callback: + try: + await self.status_callback(status) + except Exception as e: + logger.error("PipelineRunner status_callback error", error=e) + + async def add_cmd(self, cmd: str, data): + await self.q_cmd.put([cmd, data]) + + async def push(self, data): + await self.add_cmd("PUSH", data) + + async def flush(self): + await self.add_cmd("FLUSH", None) + + async def run(self): + try: + await self.update_status("running") + while not self.ev_done.is_set(): + cmd, data = await self.q_cmd.get() + func = getattr(self, f"cmd_{cmd.lower()}") + if func: + await func(data) + else: + raise Exception(f"Unknown command {cmd}") + except Exception as e: + await self.update_status("error") + logger.error("PipelineRunner error", error=e) + + async def cmd_push(self, data): + if self.status == "idle": + await self.update_status("recording") + await self.pipeline.push(data) + + async def cmd_flush(self, data): + await self.update_status("processing") + await self.pipeline.flush() + await self.update_status("ended") + self.ev_done.set() + + def start(self): + print("start task") + asyncio.get_event_loop().create_task(self.run()) + + +async def pipeline_live_create(options: PipelineOptions): + # create a context for the whole rtc transaction + # add a customised logger to the context + processors = [] + if options.audio_filename is not None: + processors += [AudioFileWriterProcessor(path=options.audio_filename)] + processors += [ + AudioChunkerProcessor(), + AudioMergeProcessor(), + AudioTranscriptAutoProcessor.as_threaded(), + TranscriptLinerProcessor(), + TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript), + TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic), + BroadcastProcessor( + processors=[ + TranscriptFinalTitleProcessor.as_threaded( + callback=options.on_final_title + ), + TranscriptFinalLongSummaryProcessor.as_threaded( + callback=options.on_final_long_summary + ), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=options.on_final_short_summary + ), + ] + ), + ] + pipeline = Pipeline(*processors) + pipeline.options = options + pipeline.set_pref("audio:source_language", options.source_language) + pipeline.set_pref("audio:target_language", options.target_language) + + return pipeline + + async def rtc_offer_base( params: RtcOffer, request: Request, @@ -211,37 +319,24 @@ async def rtc_offer_base( data=title, ) - # create a context for the whole rtc transaction - # add a customised logger to the context - processors = [] - if audio_filename is not None: - processors += [AudioFileWriterProcessor(path=audio_filename)] - processors += [ - AudioChunkerProcessor(), - AudioMergeProcessor(), - AudioTranscriptAutoProcessor.as_threaded(), - TranscriptLinerProcessor(), - TranscriptTranslatorProcessor.as_threaded(callback=on_transcript), - TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic), - BroadcastProcessor( - processors=[ - TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title), - TranscriptFinalLongSummaryProcessor.as_threaded( - callback=on_final_long_summary - ), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=on_final_short_summary - ), - ] - ), - ] - ctx.pipeline = Pipeline(*processors) - ctx.pipeline.set_pref("audio:source_language", source_language) - ctx.pipeline.set_pref("audio:target_language", target_language) - # handle RTC peer connection pc = RTCPeerConnection() + # create pipeline + options = PipelineOptions( + audio_filename=audio_filename, + source_language=source_language, + target_language=target_language, + on_transcript=on_transcript, + on_topic=on_topic, + on_final_short_summary=on_final_short_summary, + on_final_long_summary=on_final_long_summary, + on_final_title=on_final_title, + ) + pipeline = await pipeline_live_create(options) + ctx.pipeline_runner = PipelineRunner(pipeline, update_status) + ctx.pipeline_runner.start() + async def flush_pipeline_and_quit(close=True): # may be called twice # 1. either the client ask to sotp the meeting @@ -249,12 +344,10 @@ async def rtc_offer_base( # - when we receive the close event, we do nothing. # 2. or the client close the connection # and there is nothing to do because it is already closed - await update_status("processing") - await ctx.pipeline.flush() + await ctx.pipeline_runner.flush() if close: ctx.logger.debug("Closing peer connection") await pc.close() - await update_status("ended") if pc in sessions: sessions.remove(pc) m_rtc_sessions.dec() @@ -287,7 +380,6 @@ async def rtc_offer_base( def on_track(track): ctx.logger.info(f"Track {track.kind} received") pc.addTrack(AudioStreamTrack(ctx, track)) - asyncio.get_event_loop().create_task(update_status("recording")) await pc.setRemoteDescription(offer) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 9480461f..9f02eb6d 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -483,16 +483,16 @@ async def transcript_get_topics( ] -@router.get("/transcripts/{transcript_id}/events") -async def transcript_get_websocket_events(transcript_id: str): - pass - - # ============================================================== # Websocket # ============================================================== +@router.get("/transcripts/{transcript_id}/events") +async def transcript_get_websocket_events(transcript_id: str): + pass + + @router.websocket("/transcripts/{transcript_id}/events") async def transcript_events_websocket( transcript_id: str, @@ -512,6 +512,13 @@ async def transcript_events_websocket( try: # on first connection, send all events only to the current user for event in transcript.events: + # for now, do not send TRANSCRIPT or STATUS options - theses are live event + # not necessary to be sent to the client; but keep the rest + name = event.event + if name == PipelineEvent.TRANSCRIPT: + continue + if name == PipelineEvent.STATUS: + continue await websocket.send_json(event.model_dump(mode="json")) # XXX if transcript is final (locked=True and status=ended) diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py index 43475c1d..1dfe9e3d 100644 --- a/server/reflector/ws_manager.py +++ b/server/reflector/ws_manager.py @@ -33,6 +33,8 @@ class RedisPubSubManager: ) async def connect(self) -> None: + if self.redis_connection is not None: + return self.redis_connection = await self.get_redis_connection() self.pubsub = self.redis_connection.pubsub() @@ -43,6 +45,8 @@ class RedisPubSubManager: self.redis_connection = None async def send_json(self, room_id: str, message: str) -> None: + if not self.redis_connection: + await self.connect() message = json.dumps(message) await self.redis_connection.publish(room_id, message) @@ -94,18 +98,6 @@ class WebsocketManager: await socket.send_json(data) -def get_pubsub_client() -> RedisPubSubManager: - """ - Returns the RedisPubSubManager instance for managing Redis pubsub. - """ - from reflector.settings import settings - - return RedisPubSubManager( - host=settings.REDIS_HOST, - port=settings.REDIS_PORT, - ) - - def get_ws_manager() -> WebsocketManager: """ Returns the WebsocketManager instance for managing websockets. @@ -122,6 +114,11 @@ def get_ws_manager() -> WebsocketManager: RedisConnectionError: If there is an error connecting to the Redis server. """ global ws_manager - pubsub_client = get_pubsub_client() + from reflector.settings import settings + + pubsub_client = RedisPubSubManager( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + ) ws_manager = WebsocketManager(pubsub_client=pubsub_client) return ws_manager From 1c42473da029bb2f88b090a7323ac4f68b818275 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 26 Oct 2023 19:00:56 +0200 Subject: [PATCH 14/41] server: refactor with clearer pipeline instanciation and linked to model --- server/reflector/db/__init__.py | 23 +- server/reflector/db/transcripts.py | 284 +++++++++++++++ .../reflector/pipelines/main_live_pipeline.py | 230 ++++++++++++ server/reflector/pipelines/runner.py | 117 ++++++ server/reflector/processors/__init__.py | 8 +- server/reflector/views/rtc_offer.py | 267 +------------- server/reflector/views/transcripts.py | 341 +----------------- server/reflector/ws_manager.py | 4 +- 8 files changed, 658 insertions(+), 616 deletions(-) create mode 100644 server/reflector/db/transcripts.py create mode 100644 server/reflector/pipelines/main_live_pipeline.py create mode 100644 server/reflector/pipelines/runner.py diff --git a/server/reflector/db/__init__.py b/server/reflector/db/__init__.py index b68dfe20..9871c633 100644 --- a/server/reflector/db/__init__.py +++ b/server/reflector/db/__init__.py @@ -1,32 +1,13 @@ import databases import sqlalchemy - from reflector.events import subscribers_shutdown, subscribers_startup from reflector.settings import settings database = databases.Database(settings.DATABASE_URL) metadata = sqlalchemy.MetaData() - -transcripts = sqlalchemy.Table( - "transcript", - metadata, - sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), - sqlalchemy.Column("name", sqlalchemy.String), - sqlalchemy.Column("status", sqlalchemy.String), - sqlalchemy.Column("locked", sqlalchemy.Boolean), - sqlalchemy.Column("duration", sqlalchemy.Integer), - sqlalchemy.Column("created_at", sqlalchemy.DateTime), - sqlalchemy.Column("title", sqlalchemy.String, nullable=True), - sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True), - sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), - sqlalchemy.Column("topics", sqlalchemy.JSON), - sqlalchemy.Column("events", sqlalchemy.JSON), - sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), - sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), - # with user attached, optional - sqlalchemy.Column("user_id", sqlalchemy.String), -) +# import models +import reflector.db.transcripts # noqa engine = sqlalchemy.create_engine( settings.DATABASE_URL, connect_args={"check_same_thread": False} diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py new file mode 100644 index 00000000..2b9fc6b2 --- /dev/null +++ b/server/reflector/db/transcripts.py @@ -0,0 +1,284 @@ +import json +from contextlib import asynccontextmanager +from datetime import datetime +from pathlib import Path +from typing import Any +from uuid import uuid4 + +import sqlalchemy +from pydantic import BaseModel, Field +from reflector.db import database, metadata +from reflector.processors.types import Word as ProcessorWord +from reflector.settings import settings +from reflector.utils.audio_waveform import get_audio_waveform + +transcripts = sqlalchemy.Table( + "transcript", + metadata, + sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), + sqlalchemy.Column("name", sqlalchemy.String), + sqlalchemy.Column("status", sqlalchemy.String), + sqlalchemy.Column("locked", sqlalchemy.Boolean), + sqlalchemy.Column("duration", sqlalchemy.Integer), + sqlalchemy.Column("created_at", sqlalchemy.DateTime), + sqlalchemy.Column("title", sqlalchemy.String, nullable=True), + sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True), + sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), + sqlalchemy.Column("topics", sqlalchemy.JSON), + sqlalchemy.Column("events", sqlalchemy.JSON), + sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), + sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), + # with user attached, optional + sqlalchemy.Column("user_id", sqlalchemy.String), +) + + +def generate_uuid4(): + return str(uuid4()) + + +def generate_transcript_name(): + now = datetime.utcnow() + return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" + + +class AudioWaveform(BaseModel): + data: list[float] + + +class TranscriptText(BaseModel): + text: str + translation: str | None + + +class TranscriptSegmentTopic(BaseModel): + speaker: int + text: str + timestamp: float + + +class TranscriptTopic(BaseModel): + id: str = Field(default_factory=generate_uuid4) + title: str + summary: str + timestamp: float + text: str | None = None + words: list[ProcessorWord] = [] + + +class TranscriptFinalShortSummary(BaseModel): + short_summary: str + + +class TranscriptFinalLongSummary(BaseModel): + long_summary: str + + +class TranscriptFinalTitle(BaseModel): + title: str + + +class TranscriptEvent(BaseModel): + event: str + data: dict + + +class Transcript(BaseModel): + id: str = Field(default_factory=generate_uuid4) + user_id: str | None = None + name: str = Field(default_factory=generate_transcript_name) + status: str = "idle" + locked: bool = False + duration: float = 0 + created_at: datetime = Field(default_factory=datetime.utcnow) + title: str | None = None + short_summary: str | None = None + long_summary: str | None = None + topics: list[TranscriptTopic] = [] + events: list[TranscriptEvent] = [] + source_language: str = "en" + target_language: str = "en" + + def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: + ev = TranscriptEvent(event=event, data=data.model_dump()) + self.events.append(ev) + return ev + + def upsert_topic(self, topic: TranscriptTopic): + existing_topic = next((t for t in self.topics if t.id == topic.id), None) + if existing_topic: + existing_topic.update_from(topic) + else: + self.topics.append(topic) + + def events_dump(self, mode="json"): + return [event.model_dump(mode=mode) for event in self.events] + + def topics_dump(self, mode="json"): + return [topic.model_dump(mode=mode) for topic in self.topics] + + def convert_audio_to_waveform(self, segments_count=256): + fn = self.audio_waveform_filename + if fn.exists(): + return + waveform = get_audio_waveform( + path=self.audio_mp3_filename, segments_count=segments_count + ) + try: + with open(fn, "w") as fd: + json.dump(waveform, fd) + except Exception: + # remove file if anything happen during the write + fn.unlink(missing_ok=True) + raise + return waveform + + def unlink(self): + self.data_path.unlink(missing_ok=True) + + @property + def data_path(self): + return Path(settings.DATA_DIR) / self.id + + @property + def audio_mp3_filename(self): + return self.data_path / "audio.mp3" + + @property + def audio_waveform_filename(self): + return self.data_path / "audio.json" + + @property + def audio_waveform(self): + try: + with open(self.audio_waveform_filename) as fd: + data = json.load(fd) + except json.JSONDecodeError: + # unlink file if it's corrupted + self.audio_waveform_filename.unlink(missing_ok=True) + return None + + return AudioWaveform(data=data) + + +class TranscriptController: + async def get_all( + self, + user_id: str | None = None, + order_by: str | None = None, + filter_empty: bool | None = True, + filter_recording: bool | None = True, + ) -> list[Transcript]: + """ + Get all transcripts + + If `user_id` is specified, only return transcripts that belong to the user. + Otherwise, return all anonymous transcripts. + + Parameters: + - `order_by`: field to order by, e.g. "-created_at" + - `filter_empty`: filter out empty transcripts + - `filter_recording`: filter out transcripts that are currently recording + """ + query = transcripts.select().where(transcripts.c.user_id == user_id) + + if order_by is not None: + field = getattr(transcripts.c, order_by[1:]) + if order_by.startswith("-"): + field = field.desc() + query = query.order_by(field) + + if filter_empty: + query = query.filter(transcripts.c.status != "idle") + + if filter_recording: + query = query.filter(transcripts.c.status != "recording") + + results = await database.fetch_all(query) + return results + + async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None: + """ + Get a transcript by id + """ + query = transcripts.select().where(transcripts.c.id == transcript_id) + if "user_id" in kwargs: + query = query.where(transcripts.c.user_id == kwargs["user_id"]) + result = await database.fetch_one(query) + if not result: + return None + return Transcript(**result) + + async def add( + self, + name: str, + source_language: str = "en", + target_language: str = "en", + user_id: str | None = None, + ): + """ + Add a new transcript + """ + transcript = Transcript( + name=name, + source_language=source_language, + target_language=target_language, + user_id=user_id, + ) + query = transcripts.insert().values(**transcript.model_dump()) + await database.execute(query) + return transcript + + async def update(self, transcript: Transcript, values: dict): + """ + Update a transcript fields with key/values in values + """ + query = ( + transcripts.update() + .where(transcripts.c.id == transcript.id) + .values(**values) + ) + await database.execute(query) + for key, value in values.items(): + setattr(transcript, key, value) + + async def remove_by_id( + self, + transcript_id: str, + user_id: str | None = None, + ) -> None: + """ + Remove a transcript by id + """ + transcript = await self.get_by_id(transcript_id, user_id=user_id) + if not transcript: + return + if user_id is not None and transcript.user_id != user_id: + return + transcript.unlink() + query = transcripts.delete().where(transcripts.c.id == transcript_id) + await database.execute(query) + + @asynccontextmanager + async def transaction(self): + """ + A context manager for database transaction + """ + async with database.transaction(): + yield + + async def append_event( + self, + transcript: Transcript, + event: str, + data: Any, + ) -> TranscriptEvent: + """ + Append an event to a transcript + """ + resp = transcript.add_event(event=event, data=data) + await self.update(transcript, {"events": transcript.events_dump()}) + return resp + + +transcripts_controller = TranscriptController() diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py new file mode 100644 index 00000000..30f7ead3 --- /dev/null +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -0,0 +1,230 @@ +""" +Main reflector pipeline for live streaming +========================================== + +This is the default pipeline used in the API. + +It is decoupled to: +- PipelineMainLive: have limited processing during live +- PipelineMainPost: do heavy lifting after the live + +It is directly linked to our data model. +""" + +from pathlib import Path + +from reflector.db.transcripts import ( + Transcript, + TranscriptFinalLongSummary, + TranscriptFinalShortSummary, + TranscriptFinalTitle, + TranscriptText, + TranscriptTopic, + transcripts_controller, +) +from reflector.pipelines.runner import PipelineRunner +from reflector.processors import ( + AudioChunkerProcessor, + AudioFileWriterProcessor, + AudioMergeProcessor, + AudioTranscriptAutoProcessor, + BroadcastProcessor, + Pipeline, + TranscriptFinalLongSummaryProcessor, + TranscriptFinalShortSummaryProcessor, + TranscriptFinalTitleProcessor, + TranscriptLinerProcessor, + TranscriptTopicDetectorProcessor, + TranscriptTranslatorProcessor, +) +from reflector.tasks.worker import celery +from reflector.ws_manager import WebsocketManager, get_ws_manager + + +def broadcast_to_socket(func): + """ + Decorator to broadcast transcript event to websockets + concerning this transcript + """ + + async def wrapper(self, *args, **kwargs): + resp = await func(self, *args, **kwargs) + if resp is None: + return + await self.ws_manager.send_json( + room_id=self.ws_room_id, + message=resp.model_dump(mode="json"), + ) + + return wrapper + + +class PipelineMainBase(PipelineRunner): + transcript_id: str + ws_room_id: str | None = None + ws_manager: WebsocketManager | None = None + + def prepare(self): + # prepare websocket + self.ws_room_id = f"ts:{self.transcript_id}" + self.ws_manager = get_ws_manager() + + async def get_transcript(self) -> Transcript: + # fetch the transcript + result = await transcripts_controller.get_by_id( + transcript_id=self.transcript_id + ) + if not result: + raise Exception("Transcript not found") + return result + + +class PipelineMainLive(PipelineMainBase): + audio_filename: Path | None = None + source_language: str = "en" + target_language: str = "en" + + @broadcast_to_socket + async def on_transcript(self, data): + async with transcripts_controller.transaction(): + transcript = await self.get_transcript() + return await transcripts_controller.append_event( + transcript=transcript, + event="TRANSCRIPT", + data=TranscriptText(text=data.text, translation=data.translation), + ) + + @broadcast_to_socket + async def on_topic(self, data): + topic = TranscriptTopic( + title=data.title, + summary=data.summary, + timestamp=data.timestamp, + text=data.transcript.text, + words=data.transcript.words, + ) + async with transcripts_controller.transaction(): + transcript = await self.get_transcript() + return await transcripts_controller.append_event( + transcript=transcript, + event="TOPIC", + data=topic, + ) + + async def create(self) -> Pipeline: + # create a context for the whole rtc transaction + # add a customised logger to the context + self.prepare() + transcript = await self.get_transcript() + + processors = [ + AudioFileWriterProcessor(path=transcript.audio_mp3_filename), + AudioChunkerProcessor(), + AudioMergeProcessor(), + AudioTranscriptAutoProcessor.as_threaded(), + TranscriptLinerProcessor(), + TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), + TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), + ] + pipeline = Pipeline(*processors) + pipeline.options = self + pipeline.set_pref("audio:source_language", transcript.source_language) + pipeline.set_pref("audio:target_language", transcript.target_language) + + # when the pipeline ends, connect to the post pipeline + async def on_ended(): + task_pipeline_main_post.delay(transcript_id=self.transcript_id) + + pipeline.on_ended = self + + return pipeline + + +class PipelineMainPost(PipelineMainBase): + """ + Implement the rest of the main pipeline, triggered after PipelineMainLive ended. + """ + + @broadcast_to_socket + async def on_final_title(self, data): + final_title = TranscriptFinalTitle(title=data.title) + async with transcripts_controller.transaction(): + transcript = await self.get_transcript() + if not transcript.title: + transcripts_controller.update( + self.transcript, + { + "title": final_title.title, + }, + ) + return await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_TITLE", + data=final_title, + ) + + @broadcast_to_socket + async def on_final_long_summary(self, data): + final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) + async with transcripts_controller.transaction(): + transcript = await self.get_transcript() + await transcripts_controller.update( + transcript, + { + "long_summary": final_long_summary.long_summary, + }, + ) + return await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_LONG_SUMMARY", + data=final_long_summary, + ) + + @broadcast_to_socket + async def on_final_short_summary(self, data): + final_short_summary = TranscriptFinalShortSummary( + short_summary=data.short_summary + ) + async with transcripts_controller.transaction(): + transcript = await self.get_transcript() + await transcripts_controller.update( + transcript, + { + "short_summary": final_short_summary.short_summary, + }, + ) + return await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_SHORT_SUMMARY", + data=final_short_summary, + ) + + async def create(self) -> Pipeline: + # create a context for the whole rtc transaction + # add a customised logger to the context + self.prepare() + processors = [ + # add diarization + BroadcastProcessor( + processors=[ + TranscriptFinalTitleProcessor.as_threaded( + callback=self.on_final_title + ), + TranscriptFinalLongSummaryProcessor.as_threaded( + callback=self.on_final_long_summary + ), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=self.on_final_short_summary + ), + ] + ), + ] + pipeline = Pipeline(*processors) + pipeline.options = self + + return pipeline + + +@celery.task +def task_pipeline_main_post(transcript_id: str): + pass diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py new file mode 100644 index 00000000..ce84fec4 --- /dev/null +++ b/server/reflector/pipelines/runner.py @@ -0,0 +1,117 @@ +""" +Pipeline Runner +=============== + +Pipeline runner designed to be executed in a asyncio task. + +It is meant to be subclassed, and implement a create() method +that expose/return a Pipeline instance. + +During its lifecycle, it will emit the following status: +- started: the pipeline has been started +- push: the pipeline received at least one data +- flush: the pipeline is flushing +- ended: the pipeline has ended +- error: the pipeline has ended with an error +""" + +import asyncio +from typing import Callable + +from pydantic import BaseModel, ConfigDict +from reflector.logger import logger +from reflector.processors import Pipeline + + +class PipelineRunner(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + status: str = "idle" + on_status: Callable | None = None + on_ended: Callable | None = None + pipeline: Pipeline | None = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._q_cmd = asyncio.Queue() + self._ev_done = asyncio.Event() + self._is_first_push = True + + def create(self) -> Pipeline: + """ + Create the pipeline if not specified earlier. + Should be implemented in a subclass + """ + raise NotImplementedError() + + def start(self): + """ + Start the pipeline as a coroutine task + """ + asyncio.get_event_loop().create_task(self.run()) + + async def push(self, data): + """ + Push data to the pipeline + """ + await self._add_cmd("PUSH", data) + + async def flush(self): + """ + Flush the pipeline + """ + await self._add_cmd("FLUSH", None) + + async def _add_cmd(self, cmd: str, data): + """ + Enqueue a command to be executed in the runner. + Currently supported commands: PUSH, FLUSH + """ + await self._q_cmd.put([cmd, data]) + + async def _set_status(self, status): + print("set_status", status) + self.status = status + if self.on_status: + try: + await self.on_status(status) + except Exception as e: + logger.error("PipelineRunner status_callback error", error=e) + + async def run(self): + try: + # create the pipeline if not yet done + await self._set_status("init") + self._is_first_push = True + if not self.pipeline: + self.pipeline = await self.create() + + # start the loop + await self._set_status("started") + while not self._ev_done.is_set(): + cmd, data = await self._q_cmd.get() + func = getattr(self, f"cmd_{cmd.lower()}") + if func: + await func(data) + else: + raise Exception(f"Unknown command {cmd}") + except Exception as e: + logger.error("PipelineRunner error", error=e) + await self._set_status("error") + self._ev_done.set() + if self.on_ended: + await self.on_ended() + + async def cmd_push(self, data): + if self._is_first_push: + await self._set_status("push") + self._is_first_push = False + await self.pipeline.push(data) + + async def cmd_flush(self, data): + await self._set_status("flush") + await self.pipeline.flush() + await self._set_status("ended") + self._ev_done.set() + if self.on_ended: + await self.on_ended() diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py index 96a3941d..960c6a35 100644 --- a/server/reflector/processors/__init__.py +++ b/server/reflector/processors/__init__.py @@ -3,7 +3,13 @@ from .audio_file_writer import AudioFileWriterProcessor # noqa: F401 from .audio_merge import AudioMergeProcessor # noqa: F401 from .audio_transcript import AudioTranscriptProcessor # noqa: F401 from .audio_transcript_auto import AudioTranscriptAutoProcessor # noqa: F401 -from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor # noqa: F401 +from .base import ( # noqa: F401 + BroadcastProcessor, + Pipeline, + PipelineEvent, + Processor, + ThreadedProcessor, +) from .transcript_final_long_summary import ( # noqa: F401 TranscriptFinalLongSummaryProcessor, ) diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 48d804cc..5d10c181 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -1,8 +1,6 @@ import asyncio from enum import StrEnum -from json import dumps, loads -from pathlib import Path -from typing import Callable +from json import loads import av from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription @@ -11,25 +9,7 @@ from prometheus_client import Gauge from pydantic import BaseModel from reflector.events import subscribers_shutdown from reflector.logger import logger -from reflector.processors import ( - AudioChunkerProcessor, - AudioFileWriterProcessor, - AudioMergeProcessor, - AudioTranscriptAutoProcessor, - FinalLongSummary, - FinalShortSummary, - Pipeline, - TitleSummary, - Transcript, - TranscriptFinalLongSummaryProcessor, - TranscriptFinalShortSummaryProcessor, - TranscriptFinalTitleProcessor, - TranscriptLinerProcessor, - TranscriptTopicDetectorProcessor, - TranscriptTranslatorProcessor, -) -from reflector.processors.base import BroadcastProcessor -from reflector.processors.types import FinalTitle +from reflector.pipelines.runner import PipelineRunner sessions = [] router = APIRouter() @@ -85,121 +65,10 @@ class PipelineEvent(StrEnum): FINAL_TITLE = "FINAL_TITLE" -class PipelineOptions(BaseModel): - audio_filename: Path | None = None - source_language: str = "en" - target_language: str = "en" - - on_transcript: Callable | None = None - on_topic: Callable | None = None - on_final_title: Callable | None = None - on_final_short_summary: Callable | None = None - on_final_long_summary: Callable | None = None - - -class PipelineRunner(object): - """ - Pipeline runner designed to be executed in a asyncio task - """ - - def __init__(self, pipeline: Pipeline, status_callback: Callable | None = None): - self.pipeline = pipeline - self.q_cmd = asyncio.Queue() - self.ev_done = asyncio.Event() - self.status = "idle" - self.status_callback = status_callback - - async def update_status(self, status): - print("update_status", status) - self.status = status - if self.status_callback: - try: - await self.status_callback(status) - except Exception as e: - logger.error("PipelineRunner status_callback error", error=e) - - async def add_cmd(self, cmd: str, data): - await self.q_cmd.put([cmd, data]) - - async def push(self, data): - await self.add_cmd("PUSH", data) - - async def flush(self): - await self.add_cmd("FLUSH", None) - - async def run(self): - try: - await self.update_status("running") - while not self.ev_done.is_set(): - cmd, data = await self.q_cmd.get() - func = getattr(self, f"cmd_{cmd.lower()}") - if func: - await func(data) - else: - raise Exception(f"Unknown command {cmd}") - except Exception as e: - await self.update_status("error") - logger.error("PipelineRunner error", error=e) - - async def cmd_push(self, data): - if self.status == "idle": - await self.update_status("recording") - await self.pipeline.push(data) - - async def cmd_flush(self, data): - await self.update_status("processing") - await self.pipeline.flush() - await self.update_status("ended") - self.ev_done.set() - - def start(self): - print("start task") - asyncio.get_event_loop().create_task(self.run()) - - -async def pipeline_live_create(options: PipelineOptions): - # create a context for the whole rtc transaction - # add a customised logger to the context - processors = [] - if options.audio_filename is not None: - processors += [AudioFileWriterProcessor(path=options.audio_filename)] - processors += [ - AudioChunkerProcessor(), - AudioMergeProcessor(), - AudioTranscriptAutoProcessor.as_threaded(), - TranscriptLinerProcessor(), - TranscriptTranslatorProcessor.as_threaded(callback=options.on_transcript), - TranscriptTopicDetectorProcessor.as_threaded(callback=options.on_topic), - BroadcastProcessor( - processors=[ - TranscriptFinalTitleProcessor.as_threaded( - callback=options.on_final_title - ), - TranscriptFinalLongSummaryProcessor.as_threaded( - callback=options.on_final_long_summary - ), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=options.on_final_short_summary - ), - ] - ), - ] - pipeline = Pipeline(*processors) - pipeline.options = options - pipeline.set_pref("audio:source_language", options.source_language) - pipeline.set_pref("audio:target_language", options.target_language) - - return pipeline - - async def rtc_offer_base( params: RtcOffer, request: Request, - event_callback=None, - event_callback_args=None, - audio_filename: Path | None = None, - source_language: str = "en", - target_language: str = "en", + pipeline_runner: PipelineRunner, ): # build an rtc session offer = RTCSessionDescription(sdp=params.sdp, type=params.type) @@ -209,132 +78,9 @@ async def rtc_offer_base( clientid = f"{peername[0]}:{peername[1]}" ctx = TranscriptionContext(logger=logger.bind(client=clientid)) - async def update_status(status: str): - changed = ctx.status != status - if changed: - ctx.status = status - if event_callback: - await event_callback( - event=PipelineEvent.STATUS, - args=event_callback_args, - data=StrValue(value=status), - ) - - # build pipeline callback - async def on_transcript(transcript: Transcript): - ctx.logger.info("Transcript", transcript=transcript) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = { - "cmd": "SHOW_TRANSCRIPTION", - "text": transcript.text, - } - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.TRANSCRIPT, - args=event_callback_args, - data=transcript, - ) - - async def on_topic(topic: TitleSummary): - # FIXME: make it incremental with the frontend, not send everything - ctx.logger.info("Topic", topic=topic) - ctx.topics.append( - { - "title": topic.title, - "timestamp": topic.timestamp, - "transcript": topic.transcript.text, - "desc": topic.summary, - } - ) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = {"cmd": "UPDATE_TOPICS", "topics": ctx.topics} - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.TOPIC, args=event_callback_args, data=topic - ) - - async def on_final_short_summary(summary: FinalShortSummary): - ctx.logger.info("FinalShortSummary", final_short_summary=summary) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = { - "cmd": "DISPLAY_FINAL_SHORT_SUMMARY", - "summary": summary.short_summary, - "duration": summary.duration, - } - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_SHORT_SUMMARY, - args=event_callback_args, - data=summary, - ) - - async def on_final_long_summary(summary: FinalLongSummary): - ctx.logger.info("FinalLongSummary", final_summary=summary) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = { - "cmd": "DISPLAY_FINAL_LONG_SUMMARY", - "summary": summary.long_summary, - "duration": summary.duration, - } - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_LONG_SUMMARY, - args=event_callback_args, - data=summary, - ) - - async def on_final_title(title: FinalTitle): - ctx.logger.info("FinalTitle", final_title=title) - - # send to RTC - if ctx.data_channel.readyState == "open": - result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title} - ctx.data_channel.send(dumps(result)) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_TITLE, - args=event_callback_args, - data=title, - ) - # handle RTC peer connection pc = RTCPeerConnection() - - # create pipeline - options = PipelineOptions( - audio_filename=audio_filename, - source_language=source_language, - target_language=target_language, - on_transcript=on_transcript, - on_topic=on_topic, - on_final_short_summary=on_final_short_summary, - on_final_long_summary=on_final_long_summary, - on_final_title=on_final_title, - ) - pipeline = await pipeline_live_create(options) - ctx.pipeline_runner = PipelineRunner(pipeline, update_status) + ctx.pipeline_runner = pipeline_runner ctx.pipeline_runner.start() async def flush_pipeline_and_quit(close=True): @@ -400,8 +146,3 @@ async def rtc_clean_sessions(_): logger.debug(f"Closing session {pc}") await pc.close() sessions.clear() - - -@router.post("/offer") -async def rtc_offer(params: RtcOffer, request: Request): - return await rtc_offer_base(params, request) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 9f02eb6d..e949d645 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,8 +1,5 @@ -import json from datetime import datetime -from pathlib import Path from typing import Annotated, Optional -from uuid import uuid4 import reflector.auth as auth from fastapi import ( @@ -15,12 +12,13 @@ from fastapi import ( ) from fastapi_pagination import Page, paginate from pydantic import BaseModel, Field -from reflector.db import database, transcripts -from reflector.logger import logger +from reflector.db.transcripts import ( + AudioWaveform, + TranscriptTopic, + transcripts_controller, +) from reflector.processors.types import Transcript as ProcessorTranscript -from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings -from reflector.utils.audio_waveform import get_audio_waveform from reflector.ws_manager import get_ws_manager from starlette.concurrency import run_in_threadpool @@ -30,216 +28,6 @@ from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base router = APIRouter() ws_manager = get_ws_manager() -# ============================================================== -# Models to move to a database, but required for the API to work -# ============================================================== - - -def generate_uuid4(): - return str(uuid4()) - - -def generate_transcript_name(): - now = datetime.utcnow() - return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" - - -class AudioWaveform(BaseModel): - data: list[float] - - -class TranscriptText(BaseModel): - text: str - translation: str | None - - -class TranscriptSegmentTopic(BaseModel): - speaker: int - text: str - timestamp: float - - -class TranscriptTopic(BaseModel): - id: str = Field(default_factory=generate_uuid4) - title: str - summary: str - timestamp: float - text: str | None = None - words: list[ProcessorWord] = [] - - -class TranscriptFinalShortSummary(BaseModel): - short_summary: str - - -class TranscriptFinalLongSummary(BaseModel): - long_summary: str - - -class TranscriptFinalTitle(BaseModel): - title: str - - -class TranscriptEvent(BaseModel): - event: str - data: dict - - -class Transcript(BaseModel): - id: str = Field(default_factory=generate_uuid4) - user_id: str | None = None - name: str = Field(default_factory=generate_transcript_name) - status: str = "idle" - locked: bool = False - duration: float = 0 - created_at: datetime = Field(default_factory=datetime.utcnow) - title: str | None = None - short_summary: str | None = None - long_summary: str | None = None - topics: list[TranscriptTopic] = [] - events: list[TranscriptEvent] = [] - source_language: str = "en" - target_language: str = "en" - - def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: - ev = TranscriptEvent(event=event, data=data.model_dump()) - self.events.append(ev) - return ev - - def upsert_topic(self, topic: TranscriptTopic): - existing_topic = next((t for t in self.topics if t.id == topic.id), None) - if existing_topic: - existing_topic.update_from(topic) - else: - self.topics.append(topic) - - def events_dump(self, mode="json"): - return [event.model_dump(mode=mode) for event in self.events] - - def topics_dump(self, mode="json"): - return [topic.model_dump(mode=mode) for topic in self.topics] - - def convert_audio_to_waveform(self, segments_count=256): - fn = self.audio_waveform_filename - if fn.exists(): - return - waveform = get_audio_waveform( - path=self.audio_mp3_filename, segments_count=segments_count - ) - try: - with open(fn, "w") as fd: - json.dump(waveform, fd) - except Exception: - # remove file if anything happen during the write - fn.unlink(missing_ok=True) - raise - return waveform - - def unlink(self): - self.data_path.unlink(missing_ok=True) - - @property - def data_path(self): - return Path(settings.DATA_DIR) / self.id - - @property - def audio_mp3_filename(self): - return self.data_path / "audio.mp3" - - @property - def audio_waveform_filename(self): - return self.data_path / "audio.json" - - @property - def audio_waveform(self): - try: - with open(self.audio_waveform_filename) as fd: - data = json.load(fd) - except json.JSONDecodeError: - # unlink file if it's corrupted - self.audio_waveform_filename.unlink(missing_ok=True) - return None - - return AudioWaveform(data=data) - - -class TranscriptController: - async def get_all( - self, - user_id: str | None = None, - order_by: str | None = None, - filter_empty: bool | None = False, - filter_recording: bool | None = False, - ) -> list[Transcript]: - query = transcripts.select().where(transcripts.c.user_id == user_id) - - if order_by is not None: - field = getattr(transcripts.c, order_by[1:]) - if order_by.startswith("-"): - field = field.desc() - query = query.order_by(field) - - if filter_empty: - query = query.filter(transcripts.c.status != "idle") - - if filter_recording: - query = query.filter(transcripts.c.status != "recording") - - results = await database.fetch_all(query) - return results - - async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None: - query = transcripts.select().where(transcripts.c.id == transcript_id) - if "user_id" in kwargs: - query = query.where(transcripts.c.user_id == kwargs["user_id"]) - result = await database.fetch_one(query) - if not result: - return None - return Transcript(**result) - - async def add( - self, - name: str, - source_language: str = "en", - target_language: str = "en", - user_id: str | None = None, - ): - transcript = Transcript( - name=name, - source_language=source_language, - target_language=target_language, - user_id=user_id, - ) - query = transcripts.insert().values(**transcript.model_dump()) - await database.execute(query) - return transcript - - async def update(self, transcript: Transcript, values: dict): - query = ( - transcripts.update() - .where(transcripts.c.id == transcript.id) - .values(**values) - ) - await database.execute(query) - for key, value in values.items(): - setattr(transcript, key, value) - - async def remove_by_id( - self, transcript_id: str, user_id: str | None = None - ) -> None: - transcript = await self.get_by_id(transcript_id, user_id=user_id) - if not transcript: - return - if user_id is not None and transcript.user_id != user_id: - return - transcript.unlink() - query = transcripts.delete().where(transcripts.c.id == transcript_id) - await database.execute(query) - - -transcripts_controller = TranscriptController() - - # ============================================================== # Transcripts list # ============================================================== @@ -537,114 +325,6 @@ async def transcript_events_websocket( # ============================================================== -async def handle_rtc_event(event: PipelineEvent, args, data): - try: - return await handle_rtc_event_once(event, args, data) - except Exception: - logger.exception("Error handling RTC event") - - -async def handle_rtc_event_once(event: PipelineEvent, args, data): - # OFC the current implementation is not good, - # but it's just a POC before persistence. It won't query the - # transcript from the database for each event. - # print(f"Event: {event}", args, data) - transcript_id = args - transcript = await transcripts_controller.get_by_id(transcript_id) - if not transcript: - return - - # event send to websocket clients may not be the same as the event - # received from the pipeline. For example, the pipeline will send - # a TRANSCRIPT event with all words, but this is not what we want - # to send to the websocket client. - - # FIXME don't do copy - if event == PipelineEvent.TRANSCRIPT: - resp = transcript.add_event( - event=event, - data=TranscriptText(text=data.text, translation=data.translation), - ) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - }, - ) - - elif event == PipelineEvent.TOPIC: - topic = TranscriptTopic( - title=data.title, - summary=data.summary, - timestamp=data.timestamp, - text=data.transcript.text, - words=data.transcript.words, - ) - resp = transcript.add_event(event=event, data=topic) - transcript.upsert_topic(topic) - - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "topics": transcript.topics_dump(), - }, - ) - - elif event == PipelineEvent.FINAL_TITLE: - final_title = TranscriptFinalTitle(title=data.title) - resp = transcript.add_event(event=event, data=final_title) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "title": final_title.title, - }, - ) - - elif event == PipelineEvent.FINAL_LONG_SUMMARY: - final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) - resp = transcript.add_event(event=event, data=final_long_summary) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "long_summary": final_long_summary.long_summary, - }, - ) - - elif event == PipelineEvent.FINAL_SHORT_SUMMARY: - final_short_summary = TranscriptFinalShortSummary( - short_summary=data.short_summary - ) - resp = transcript.add_event(event=event, data=final_short_summary) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "short_summary": final_short_summary.short_summary, - }, - ) - - elif event == PipelineEvent.STATUS: - resp = transcript.add_event(event=event, data=data) - await transcripts_controller.update( - transcript, - { - "events": transcript.events_dump(), - "status": data.value, - }, - ) - - else: - logger.warning(f"Unknown event: {event}") - return - - # transmit to websocket clients - room_id = f"ts:{transcript_id}" - await ws_manager.send_json(room_id, resp.model_dump(mode="json")) - - @router.post("/transcripts/{transcript_id}/record/webrtc") async def transcript_record_webrtc( transcript_id: str, @@ -660,13 +340,14 @@ async def transcript_record_webrtc( if transcript.locked: raise HTTPException(status_code=400, detail="Transcript is locked") + # create a pipeline runner + from reflector.pipelines.main_live_pipeline import PipelineMainLive + + pipeline_runner = PipelineMainLive(transcript_id=transcript_id) + # FIXME do not allow multiple recording at the same time return await rtc_offer_base( params, request, - event_callback=handle_rtc_event, - event_callback_args=transcript_id, - audio_filename=transcript.audio_mp3_filename, - source_language=transcript.source_language, - target_language=transcript.target_language, + pipeline_runner=pipeline_runner, ) diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py index 1dfe9e3d..7650807b 100644 --- a/server/reflector/ws_manager.py +++ b/server/reflector/ws_manager.py @@ -14,6 +14,7 @@ import json import redis.asyncio as redis from fastapi import WebSocket +from reflector.settings import settings ws_manager = None @@ -114,7 +115,8 @@ def get_ws_manager() -> WebsocketManager: RedisConnectionError: If there is an error connecting to the Redis server. """ global ws_manager - from reflector.settings import settings + if ws_manager: + return ws_manager pubsub_client = RedisPubSubManager( host=settings.REDIS_HOST, From 07c4d080c22fb896982be1da9928fc16ffd0b6c8 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 27 Oct 2023 15:59:27 +0200 Subject: [PATCH 15/41] server: refactor with diarization, logic works --- server/poetry.lock | 16 +- server/pyproject.toml | 1 + server/reflector/db/transcripts.py | 14 +- .../reflector/pipelines/main_live_pipeline.py | 256 ++++++++++++------ server/reflector/pipelines/runner.py | 47 +++- server/reflector/processors/__init__.py | 1 + .../reflector/processors/audio_diarization.py | 65 +++++ server/reflector/processors/types.py | 5 + server/reflector/tasks/boot.py | 2 - server/reflector/tasks/worker.py | 6 - server/reflector/views/rtc_offer.py | 18 +- server/reflector/views/transcripts.py | 25 +- server/reflector/worker/app.py | 11 + .../{tasks => worker}/post_transcript.py | 0 server/reflector/ws_manager.py | 10 +- server/tests/conftest.py | 25 +- server/tests/test_transcripts_rtc_ws.py | 54 +++- 17 files changed, 387 insertions(+), 169 deletions(-) create mode 100644 server/reflector/processors/audio_diarization.py delete mode 100644 server/reflector/tasks/boot.py delete mode 100644 server/reflector/tasks/worker.py create mode 100644 server/reflector/worker/app.py rename server/reflector/{tasks => worker}/post_transcript.py (100%) diff --git a/server/poetry.lock b/server/poetry.lock index 35d98382..8783625b 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -2676,6 +2676,20 @@ pytest = ">=7.0.0" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] +[[package]] +name = "pytest-celery" +version = "0.0.0" +description = "pytest-celery a shim pytest plugin to enable celery.contrib.pytest" +optional = false +python-versions = "*" +files = [ + {file = "pytest-celery-0.0.0.tar.gz", hash = "sha256:cfd060fc32676afa1e4f51b2938f903f7f75d952186b8c6cf631628c4088f406"}, + {file = "pytest_celery-0.0.0-py2.py3-none-any.whl", hash = "sha256:63dec132df3a839226ecb003ffdbb0c2cb88dd328550957e979c942766578060"}, +] + +[package.dependencies] +celery = ">=4.4.0" + [[package]] name = "pytest-cov" version = "4.1.0" @@ -4064,4 +4078,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "6d2e8a8e0d5d928481f9a33210d44863a1921e18147fa57dc6889d877697aa63" +content-hash = "07e42e7512fd5d51b656207a05092c53905c15e6a5ce548e015cdc05bd1baa7d" diff --git a/server/pyproject.toml b/server/pyproject.toml index ed231a4f..c8614006 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -49,6 +49,7 @@ pytest-asyncio = "^0.21.1" pytest = "^7.4.0" httpx-ws = "^0.4.1" pytest-httpx = "^0.23.1" +pytest-celery = "^0.0.0" [tool.poetry.group.aws.dependencies] diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 2b9fc6b2..61a2c380 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -62,6 +62,7 @@ class TranscriptTopic(BaseModel): title: str summary: str timestamp: float + duration: float | None = 0 text: str | None = None words: list[ProcessorWord] = [] @@ -264,7 +265,7 @@ class TranscriptController: """ A context manager for database transaction """ - async with database.transaction(): + async with database.transaction(isolation="serializable"): yield async def append_event( @@ -280,5 +281,16 @@ class TranscriptController: await self.update(transcript, {"events": transcript.events_dump()}) return resp + async def upsert_topic( + self, + transcript: Transcript, + topic: TranscriptTopic, + ) -> TranscriptEvent: + """ + Append an event to a transcript + """ + transcript.upsert_topic(topic) + await self.update(transcript, {"topics": transcript.topics_dump()}) + transcripts_controller = TranscriptController() diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 30f7ead3..4159c889 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -11,8 +11,12 @@ It is decoupled to: It is directly linked to our data model. """ +import asyncio +from contextlib import asynccontextmanager from pathlib import Path +from celery import shared_task +from pydantic import BaseModel from reflector.db.transcripts import ( Transcript, TranscriptFinalLongSummary, @@ -25,6 +29,7 @@ from reflector.db.transcripts import ( from reflector.pipelines.runner import PipelineRunner from reflector.processors import ( AudioChunkerProcessor, + AudioDiarizationProcessor, AudioFileWriterProcessor, AudioMergeProcessor, AudioTranscriptAutoProcessor, @@ -37,11 +42,13 @@ from reflector.processors import ( TranscriptTopicDetectorProcessor, TranscriptTranslatorProcessor, ) -from reflector.tasks.worker import celery +from reflector.processors.types import AudioDiarizationInput +from reflector.processors.types import TitleSummary as TitleSummaryProcessorType +from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.ws_manager import WebsocketManager, get_ws_manager -def broadcast_to_socket(func): +def broadcast_to_sockets(func): """ Decorator to broadcast transcript event to websockets concerning this transcript @@ -59,6 +66,10 @@ def broadcast_to_socket(func): return wrapper +class StrValue(BaseModel): + value: str + + class PipelineMainBase(PipelineRunner): transcript_id: str ws_room_id: str | None = None @@ -66,6 +77,7 @@ class PipelineMainBase(PipelineRunner): def prepare(self): # prepare websocket + self._lock = asyncio.Lock() self.ws_room_id = f"ts:{self.transcript_id}" self.ws_manager = get_ws_manager() @@ -78,15 +90,59 @@ class PipelineMainBase(PipelineRunner): raise Exception("Transcript not found") return result + @asynccontextmanager + async def transaction(self): + async with self._lock: + async with transcripts_controller.transaction(): + yield -class PipelineMainLive(PipelineMainBase): - audio_filename: Path | None = None - source_language: str = "en" - target_language: str = "en" + @broadcast_to_sockets + async def on_status(self, status): + # if it's the first part, update the status of the transcript + # but do not set the ended status yet. + if isinstance(self, PipelineMainLive): + status_mapping = { + "started": "recording", + "push": "recording", + "flush": "processing", + "error": "error", + } + elif isinstance(self, PipelineMainDiarization): + status_mapping = { + "push": "processing", + "flush": "processing", + "error": "error", + "ended": "ended", + } + else: + raise Exception(f"Runner {self.__class__} is missing status mapping") - @broadcast_to_socket + # mutate to model status + status = status_mapping.get(status) + if not status: + return + + # when the status of the pipeline changes, update the transcript + async with self.transaction(): + transcript = await self.get_transcript() + if status == transcript.status: + return + resp = await transcripts_controller.append_event( + transcript=transcript, + event="STATUS", + data=StrValue(value=status), + ) + await transcripts_controller.update( + transcript, + { + "status": status, + }, + ) + return resp + + @broadcast_to_sockets async def on_transcript(self, data): - async with transcripts_controller.transaction(): + async with self.transaction(): transcript = await self.get_transcript() return await transcripts_controller.append_event( transcript=transcript, @@ -94,7 +150,7 @@ class PipelineMainLive(PipelineMainBase): data=TranscriptText(text=data.text, translation=data.translation), ) - @broadcast_to_socket + @broadcast_to_sockets async def on_topic(self, data): topic = TranscriptTopic( title=data.title, @@ -103,14 +159,75 @@ class PipelineMainLive(PipelineMainBase): text=data.transcript.text, words=data.transcript.words, ) - async with transcripts_controller.transaction(): + async with self.transaction(): transcript = await self.get_transcript() + await transcripts_controller.upsert_topic(transcript, topic) return await transcripts_controller.append_event( transcript=transcript, event="TOPIC", data=topic, ) + @broadcast_to_sockets + async def on_title(self, data): + final_title = TranscriptFinalTitle(title=data.title) + async with self.transaction(): + transcript = await self.get_transcript() + if not transcript.title: + transcripts_controller.update( + transcript, + { + "title": final_title.title, + }, + ) + return await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_TITLE", + data=final_title, + ) + + @broadcast_to_sockets + async def on_long_summary(self, data): + final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) + async with self.transaction(): + transcript = await self.get_transcript() + await transcripts_controller.update( + transcript, + { + "long_summary": final_long_summary.long_summary, + }, + ) + return await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_LONG_SUMMARY", + data=final_long_summary, + ) + + @broadcast_to_sockets + async def on_short_summary(self, data): + final_short_summary = TranscriptFinalShortSummary( + short_summary=data.short_summary + ) + async with self.transaction(): + transcript = await self.get_transcript() + await transcripts_controller.update( + transcript, + { + "short_summary": final_short_summary.short_summary, + }, + ) + return await transcripts_controller.append_event( + transcript=transcript, + event="FINAL_SHORT_SUMMARY", + data=final_short_summary, + ) + + +class PipelineMainLive(PipelineMainBase): + audio_filename: Path | None = None + source_language: str = "en" + target_language: str = "en" + async def create(self) -> Pipeline: # create a context for the whole rtc transaction # add a customised logger to the context @@ -125,96 +242,49 @@ class PipelineMainLive(PipelineMainBase): TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), + BroadcastProcessor( + processors=[ + TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), + TranscriptFinalLongSummaryProcessor.as_threaded( + callback=self.on_long_summary + ), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=self.on_short_summary + ), + ] + ), ] pipeline = Pipeline(*processors) pipeline.options = self pipeline.set_pref("audio:source_language", transcript.source_language) pipeline.set_pref("audio:target_language", transcript.target_language) - # when the pipeline ends, connect to the post pipeline - async def on_ended(): - task_pipeline_main_post.delay(transcript_id=self.transcript_id) - - pipeline.on_ended = self - return pipeline + async def on_ended(self): + # when the pipeline ends, connect to the post pipeline + task_pipeline_main_post.delay(transcript_id=self.transcript_id) -class PipelineMainPost(PipelineMainBase): + +class PipelineMainDiarization(PipelineMainBase): """ - Implement the rest of the main pipeline, triggered after PipelineMainLive ended. + Diarization is a long time process, so we do it in a separate pipeline + When done, adjust the short and final summary """ - @broadcast_to_socket - async def on_final_title(self, data): - final_title = TranscriptFinalTitle(title=data.title) - async with transcripts_controller.transaction(): - transcript = await self.get_transcript() - if not transcript.title: - transcripts_controller.update( - self.transcript, - { - "title": final_title.title, - }, - ) - return await transcripts_controller.append_event( - transcript=transcript, - event="FINAL_TITLE", - data=final_title, - ) - - @broadcast_to_socket - async def on_final_long_summary(self, data): - final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary) - async with transcripts_controller.transaction(): - transcript = await self.get_transcript() - await transcripts_controller.update( - transcript, - { - "long_summary": final_long_summary.long_summary, - }, - ) - return await transcripts_controller.append_event( - transcript=transcript, - event="FINAL_LONG_SUMMARY", - data=final_long_summary, - ) - - @broadcast_to_socket - async def on_final_short_summary(self, data): - final_short_summary = TranscriptFinalShortSummary( - short_summary=data.short_summary - ) - async with transcripts_controller.transaction(): - transcript = await self.get_transcript() - await transcripts_controller.update( - transcript, - { - "short_summary": final_short_summary.short_summary, - }, - ) - return await transcripts_controller.append_event( - transcript=transcript, - event="FINAL_SHORT_SUMMARY", - data=final_short_summary, - ) - async def create(self) -> Pipeline: # create a context for the whole rtc transaction # add a customised logger to the context self.prepare() processors = [ - # add diarization + AudioDiarizationProcessor(), BroadcastProcessor( processors=[ - TranscriptFinalTitleProcessor.as_threaded( - callback=self.on_final_title - ), TranscriptFinalLongSummaryProcessor.as_threaded( - callback=self.on_final_long_summary + callback=self.on_long_summary ), TranscriptFinalShortSummaryProcessor.as_threaded( - callback=self.on_final_short_summary + callback=self.on_short_summary ), ] ), @@ -222,9 +292,35 @@ class PipelineMainPost(PipelineMainBase): pipeline = Pipeline(*processors) pipeline.options = self + # now let's start the pipeline by pushing information to the + # first processor diarization processor + # XXX translation is lost when converting our data model to the processor model + transcript = await self.get_transcript() + topics = [ + TitleSummaryProcessorType( + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + duration=topic.duration, + transcript=TranscriptProcessorType(words=topic.words), + ) + for topic in transcript.topics + ] + + audio_diarization_input = AudioDiarizationInput( + audio_filename=transcript.audio_mp3_filename, + topics=topics, + ) + + # as tempting to use pipeline.push, prefer to use the runner + # to let the start just do one job. + self.push(audio_diarization_input) + self.flush() + return pipeline -@celery.task +@shared_task def task_pipeline_main_post(transcript_id: str): - pass + runner = PipelineMainDiarization(transcript_id=transcript_id) + runner.start_sync() diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index ce84fec4..0575cf96 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -16,7 +16,6 @@ During its lifecycle, it will emit the following status: """ import asyncio -from typing import Callable from pydantic import BaseModel, ConfigDict from reflector.logger import logger @@ -27,8 +26,6 @@ class PipelineRunner(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) status: str = "idle" - on_status: Callable | None = None - on_ended: Callable | None = None pipeline: Pipeline | None = None def __init__(self, **kwargs): @@ -36,6 +33,10 @@ class PipelineRunner(BaseModel): self._q_cmd = asyncio.Queue() self._ev_done = asyncio.Event() self._is_first_push = True + self._logger = logger.bind( + runner=id(self), + runner_cls=self.__class__.__name__, + ) def create(self) -> Pipeline: """ @@ -50,33 +51,51 @@ class PipelineRunner(BaseModel): """ asyncio.get_event_loop().create_task(self.run()) - async def push(self, data): + def start_sync(self): + """ + Start the pipeline synchronously (for non-asyncio apps) + """ + asyncio.run(self.run()) + + def push(self, data): """ Push data to the pipeline """ - await self._add_cmd("PUSH", data) + self._add_cmd("PUSH", data) - async def flush(self): + def flush(self): """ Flush the pipeline """ - await self._add_cmd("FLUSH", None) + self._add_cmd("FLUSH", None) - async def _add_cmd(self, cmd: str, data): + async def on_status(self, status): + """ + Called when the status of the pipeline changes + """ + pass + + async def on_ended(self): + """ + Called when the pipeline ends + """ + pass + + def _add_cmd(self, cmd: str, data): """ Enqueue a command to be executed in the runner. Currently supported commands: PUSH, FLUSH """ - await self._q_cmd.put([cmd, data]) + self._q_cmd.put_nowait([cmd, data]) async def _set_status(self, status): - print("set_status", status) + self._logger.debug("Runner status updated", status=status) self.status = status if self.on_status: try: await self.on_status(status) - except Exception as e: - logger.error("PipelineRunner status_callback error", error=e) + except Exception: + self._logger.exception("Runer error while setting status") async def run(self): try: @@ -95,8 +114,8 @@ class PipelineRunner(BaseModel): await func(data) else: raise Exception(f"Unknown command {cmd}") - except Exception as e: - logger.error("PipelineRunner error", error=e) + except Exception: + self._logger.exception("Runner error") await self._set_status("error") self._ev_done.set() if self.on_ended: diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py index 960c6a35..01a3a174 100644 --- a/server/reflector/processors/__init__.py +++ b/server/reflector/processors/__init__.py @@ -1,4 +1,5 @@ from .audio_chunker import AudioChunkerProcessor # noqa: F401 +from .audio_diarization import AudioDiarizationProcessor # noqa: F401 from .audio_file_writer import AudioFileWriterProcessor # noqa: F401 from .audio_merge import AudioMergeProcessor # noqa: F401 from .audio_transcript import AudioTranscriptProcessor # noqa: F401 diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py new file mode 100644 index 00000000..8db8e8e5 --- /dev/null +++ b/server/reflector/processors/audio_diarization.py @@ -0,0 +1,65 @@ +from reflector.processors.base import Processor +from reflector.processors.types import AudioDiarizationInput, TitleSummary + + +class AudioDiarizationProcessor(Processor): + INPUT_TYPE = AudioDiarizationInput + OUTPUT_TYPE = TitleSummary + + async def _push(self, data: AudioDiarizationInput): + # Gather diarization data + diarization = [ + {"start": 0.0, "stop": 4.9, "speaker": 2}, + {"start": 5.6, "stop": 6.7, "speaker": 2}, + {"start": 7.3, "stop": 8.9, "speaker": 2}, + {"start": 7.3, "stop": 7.9, "speaker": 0}, + {"start": 9.4, "stop": 11.2, "speaker": 2}, + {"start": 9.7, "stop": 10.0, "speaker": 0}, + {"start": 10.0, "stop": 10.1, "speaker": 0}, + {"start": 11.7, "stop": 16.1, "speaker": 2}, + {"start": 11.8, "stop": 12.1, "speaker": 1}, + {"start": 16.4, "stop": 21.0, "speaker": 2}, + {"start": 21.1, "stop": 22.6, "speaker": 2}, + {"start": 24.7, "stop": 31.9, "speaker": 2}, + {"start": 32.0, "stop": 32.8, "speaker": 1}, + {"start": 33.4, "stop": 37.8, "speaker": 2}, + {"start": 37.9, "stop": 40.3, "speaker": 0}, + {"start": 39.2, "stop": 40.4, "speaker": 2}, + {"start": 40.7, "stop": 41.4, "speaker": 0}, + {"start": 41.6, "stop": 45.7, "speaker": 2}, + {"start": 46.4, "stop": 53.1, "speaker": 2}, + {"start": 53.6, "stop": 56.5, "speaker": 2}, + {"start": 54.9, "stop": 75.4, "speaker": 1}, + {"start": 57.3, "stop": 58.0, "speaker": 2}, + {"start": 65.7, "stop": 66.0, "speaker": 2}, + {"start": 75.8, "stop": 78.8, "speaker": 1}, + {"start": 79.0, "stop": 82.6, "speaker": 1}, + {"start": 83.2, "stop": 83.3, "speaker": 1}, + {"start": 84.5, "stop": 94.3, "speaker": 1}, + {"start": 95.1, "stop": 100.7, "speaker": 1}, + {"start": 100.7, "stop": 102.0, "speaker": 0}, + {"start": 100.7, "stop": 101.8, "speaker": 1}, + {"start": 102.0, "stop": 103.0, "speaker": 1}, + {"start": 103.0, "stop": 103.7, "speaker": 0}, + {"start": 103.7, "stop": 103.8, "speaker": 1}, + {"start": 103.8, "stop": 113.9, "speaker": 0}, + {"start": 114.7, "stop": 117.0, "speaker": 0}, + {"start": 117.0, "stop": 117.4, "speaker": 1}, + ] + + # now reapply speaker to topics (if any) + # topics is a list[BaseModel] with an attribute words + # words is a list[BaseModel] with text, start and speaker attribute + + print("IN DIARIZATION PROCESSOR", data) + + # mutate in place + for topic in data.topics: + for word in topic.transcript.words: + for d in diarization: + if d["start"] <= word.start <= d["stop"]: + word.speaker = d["speaker"] + + # emit them + for topic in data.topics: + await self.emit(topic) diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index d2c32d17..3ec21491 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -382,3 +382,8 @@ class TranslationLanguages(BaseModel): def is_supported(self, lang_id: str) -> bool: return lang_id in self.supported_languages + + +class AudioDiarizationInput(BaseModel): + audio_filename: Path + topics: list[TitleSummary] diff --git a/server/reflector/tasks/boot.py b/server/reflector/tasks/boot.py deleted file mode 100644 index 88cc2d6f..00000000 --- a/server/reflector/tasks/boot.py +++ /dev/null @@ -1,2 +0,0 @@ -import reflector.tasks.post_transcript # noqa -import reflector.tasks.worker # noqa diff --git a/server/reflector/tasks/worker.py b/server/reflector/tasks/worker.py deleted file mode 100644 index 4379a1b7..00000000 --- a/server/reflector/tasks/worker.py +++ /dev/null @@ -1,6 +0,0 @@ -from celery import Celery -from reflector.settings import settings - -celery = Celery(__name__) -celery.conf.broker_url = settings.CELERY_BROKER_URL -celery.conf.result_backend = settings.CELERY_RESULT_BACKEND diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index 5d10c181..386ada9c 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -1,5 +1,4 @@ import asyncio -from enum import StrEnum from json import loads import av @@ -41,7 +40,7 @@ class AudioStreamTrack(MediaStreamTrack): ctx = self.ctx frame = await self.track.recv() try: - await ctx.pipeline_runner.push(frame) + ctx.pipeline_runner.push(frame) except Exception as e: ctx.logger.error("Pipeline error", error=e) return frame @@ -52,19 +51,6 @@ class RtcOffer(BaseModel): type: str -class StrValue(BaseModel): - value: str - - -class PipelineEvent(StrEnum): - TRANSCRIPT = "TRANSCRIPT" - TOPIC = "TOPIC" - FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY" - STATUS = "STATUS" - FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY" - FINAL_TITLE = "FINAL_TITLE" - - async def rtc_offer_base( params: RtcOffer, request: Request, @@ -90,7 +76,7 @@ async def rtc_offer_base( # - when we receive the close event, we do nothing. # 2. or the client close the connection # and there is nothing to do because it is already closed - await ctx.pipeline_runner.flush() + ctx.pipeline_runner.flush() if close: ctx.logger.debug("Closing peer connection") await pc.close() diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index e949d645..31cbe28e 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -23,10 +23,9 @@ from reflector.ws_manager import get_ws_manager from starlette.concurrency import run_in_threadpool from ._range_requests_response import range_requests_response -from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base +from .rtc_offer import RtcOffer, rtc_offer_base router = APIRouter() -ws_manager = get_ws_manager() # ============================================================== # Transcripts list @@ -166,32 +165,17 @@ async def transcript_update( transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - values = {"events": []} + values = {} if info.name is not None: values["name"] = info.name if info.locked is not None: values["locked"] = info.locked if info.long_summary is not None: values["long_summary"] = info.long_summary - for transcript_event in transcript.events: - if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY: - transcript_event["long_summary"] = info.long_summary - break - values["events"].extend(transcript.events) if info.short_summary is not None: values["short_summary"] = info.short_summary - for transcript_event in transcript.events: - if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY: - transcript_event["short_summary"] = info.short_summary - break - values["events"].extend(transcript.events) if info.title is not None: values["title"] = info.title - for transcript_event in transcript.events: - if transcript_event["event"] == PipelineEvent.FINAL_TITLE: - transcript_event["title"] = info.title - break - values["events"].extend(transcript.events) await transcripts_controller.update(transcript, values) return transcript @@ -295,6 +279,7 @@ async def transcript_events_websocket( # connect to websocket manager # use ts:transcript_id as room id room_id = f"ts:{transcript_id}" + ws_manager = get_ws_manager() await ws_manager.add_user_to_room(room_id, websocket) try: @@ -303,9 +288,7 @@ async def transcript_events_websocket( # for now, do not send TRANSCRIPT or STATUS options - theses are live event # not necessary to be sent to the client; but keep the rest name = event.event - if name == PipelineEvent.TRANSCRIPT: - continue - if name == PipelineEvent.STATUS: + if name in ("TRANSCRIPT", "STATUS"): continue await websocket.send_json(event.model_dump(mode="json")) diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py new file mode 100644 index 00000000..3714a64d --- /dev/null +++ b/server/reflector/worker/app.py @@ -0,0 +1,11 @@ +from celery import Celery +from reflector.settings import settings + +app = Celery(__name__) +app.conf.broker_url = settings.CELERY_BROKER_URL +app.conf.result_backend = settings.CELERY_RESULT_BACKEND +app.autodiscover_tasks( + [ + "reflector.pipelines.main_live_pipeline", + ] +) diff --git a/server/reflector/tasks/post_transcript.py b/server/reflector/worker/post_transcript.py similarity index 100% rename from server/reflector/tasks/post_transcript.py rename to server/reflector/worker/post_transcript.py diff --git a/server/reflector/ws_manager.py b/server/reflector/ws_manager.py index 7650807b..a84e3361 100644 --- a/server/reflector/ws_manager.py +++ b/server/reflector/ws_manager.py @@ -11,13 +11,12 @@ broadcast messages to all connected websockets. import asyncio import json +import threading import redis.asyncio as redis from fastapi import WebSocket from reflector.settings import settings -ws_manager = None - class RedisPubSubManager: def __init__(self, host="localhost", port=6379): @@ -114,13 +113,14 @@ def get_ws_manager() -> WebsocketManager: ImportError: If the 'reflector.settings' module cannot be imported. RedisConnectionError: If there is an error connecting to the Redis server. """ - global ws_manager - if ws_manager: - return ws_manager + local = threading.local() + if hasattr(local, "ws_manager"): + return local.ws_manager pubsub_client = RedisPubSubManager( host=settings.REDIS_HOST, port=settings.REDIS_PORT, ) ws_manager = WebsocketManager(pubsub_client=pubsub_client) + local.ws_manager = ws_manager return ws_manager diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 76b56abf..d5f5f0b9 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -45,17 +45,16 @@ async def dummy_transcript(): from reflector.processors.types import AudioFile, Transcript, Word class TestAudioTranscriptProcessor(AudioTranscriptProcessor): - async def _transcript(self, data: AudioFile): - source_language = self.get_pref("audio:source_language", "en") - print("transcripting", source_language) - print("pipeline", self.pipeline) - print("prefs", self.pipeline.prefs) + _time_idx = 0 + async def _transcript(self, data: AudioFile): + i = self._time_idx + self._time_idx += 2 return Transcript( text="Hello world.", words=[ - Word(start=0.0, end=1.0, text="Hello"), - Word(start=1.0, end=2.0, text=" world."), + Word(start=i, end=i + 1, text="Hello", speaker=0), + Word(start=i + 1, end=i + 2, text=" world.", speaker=0), ], ) @@ -98,7 +97,17 @@ def ensure_casing(): @pytest.fixture def sentence_tokenize(): with patch( - "reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize" + "reflector.processors.TranscriptFinalLongSummaryProcessor.sentence_tokenize" ) as mock_sent_tokenize: mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"] yield + + +@pytest.fixture(scope="session") +def celery_enable_logging(): + return True + + +@pytest.fixture(scope="session") +def celery_config(): + return {"broker_url": "memory://", "result_backend": "rpc"} diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 50e74231..e2bfee32 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -32,7 +32,7 @@ class ThreadedUvicorn: @pytest.fixture -async def appserver(tmpdir): +async def appserver(tmpdir, celery_session_app, celery_session_worker): from reflector.settings import settings from reflector.app import app @@ -52,6 +52,13 @@ async def appserver(tmpdir): settings.DATA_DIR = DATA_DIR +@pytest.fixture(scope="session") +def celery_includes(): + return ["reflector.pipelines.main_live_pipeline"] + + +@pytest.mark.usefixtures("celery_session_app") +@pytest.mark.usefixtures("celery_session_worker") @pytest.mark.asyncio async def test_transcript_rtc_and_websocket( tmpdir, @@ -121,14 +128,20 @@ async def test_transcript_rtc_and_websocket( # XXX aiortc is long to close the connection # instead of waiting a long time, we just send a STOP client.channel.send(json.dumps({"cmd": "STOP"})) - - # wait the processing to finish - await asyncio.sleep(2) - await client.stop() # wait the processing to finish - await asyncio.sleep(2) + timeout = 20 + while True: + # fetch the transcript and check if it is ended + resp = await ac.get(f"/transcripts/{tid}") + assert resp.status_code == 200 + if resp.json()["status"] in ("ended", "error"): + break + await asyncio.sleep(1) + + if resp.json()["status"] != "ended": + raise TimeoutError("Timeout while waiting for transcript to be ended") # stop websocket task websocket_task.cancel() @@ -152,7 +165,7 @@ async def test_transcript_rtc_and_websocket( ev = events[eventnames.index("TOPIC")] assert ev["data"]["id"] assert ev["data"]["summary"] == "LLM SUMMARY" - assert ev["data"]["transcript"].startswith("Hello world.") + assert ev["data"]["text"].startswith("Hello world.") assert ev["data"]["timestamp"] == 0.0 assert "FINAL_LONG_SUMMARY" in eventnames @@ -169,23 +182,21 @@ async def test_transcript_rtc_and_websocket( # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] - assert statuses == ["recording", "processing", "ended"] + assert statuses.index("recording") < statuses.index("processing") + assert statuses.index("processing") < statuses.index("ended") # ensure the last event received is ended assert events[-1]["event"] == "STATUS" assert events[-1]["data"]["value"] == "ended" - # check that transcript status in model is updated - resp = await ac.get(f"/transcripts/{tid}") - assert resp.status_code == 200 - assert resp.json()["status"] == "ended" - # check that audio/mp3 is available resp = await ac.get(f"/transcripts/{tid}/audio/mp3") assert resp.status_code == 200 assert resp.headers["Content-Type"] == "audio/mpeg" +@pytest.mark.usefixtures("celery_session_app") +@pytest.mark.usefixtures("celery_session_worker") @pytest.mark.asyncio async def test_transcript_rtc_and_websocket_and_fr( tmpdir, @@ -265,6 +276,18 @@ async def test_transcript_rtc_and_websocket_and_fr( await client.stop() # wait the processing to finish + timeout = 20 + while True: + # fetch the transcript and check if it is ended + resp = await ac.get(f"/transcripts/{tid}") + assert resp.status_code == 200 + if resp.json()["status"] == "ended": + break + await asyncio.sleep(1) + + if resp.json()["status"] != "ended": + raise TimeoutError("Timeout while waiting for transcript to be ended") + await asyncio.sleep(2) # stop websocket task @@ -289,7 +312,7 @@ async def test_transcript_rtc_and_websocket_and_fr( ev = events[eventnames.index("TOPIC")] assert ev["data"]["id"] assert ev["data"]["summary"] == "LLM SUMMARY" - assert ev["data"]["transcript"].startswith("Hello world.") + assert ev["data"]["text"].startswith("Hello world.") assert ev["data"]["timestamp"] == 0.0 assert "FINAL_LONG_SUMMARY" in eventnames @@ -306,7 +329,8 @@ async def test_transcript_rtc_and_websocket_and_fr( # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] - assert statuses == ["recording", "processing", "ended"] + assert statuses.index("recording") < statuses.index("processing") + assert statuses.index("processing") < statuses.index("ended") # ensure the last event received is ended assert events[-1]["event"] == "STATUS" From d8a842f099091ad1ad19b934a4ff1fadb3003a95 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 27 Oct 2023 20:00:07 +0200 Subject: [PATCH 16/41] server: full diarization processor implementation based on gokul app --- server/reflector/app.py | 3 + .../reflector/pipelines/main_live_pipeline.py | 30 +++++---- server/reflector/pipelines/runner.py | 6 +- server/reflector/processors/__init__.py | 2 +- .../reflector/processors/audio_diarization.py | 65 ------------------- .../processors/audio_diarization_auto.py | 34 ++++++++++ .../processors/audio_diarization_base.py | 28 ++++++++ .../processors/audio_diarization_modal.py | 36 ++++++++++ .../processors/audio_transcript_auto.py | 34 ++-------- server/reflector/processors/types.py | 2 +- server/reflector/settings.py | 10 +++ .../tools/start_post_main_live_pipeline.py | 14 ++++ server/reflector/views/transcripts.py | 29 ++++++++- server/reflector/worker/app.py | 1 + server/tests/test_transcripts_rtc_ws.py | 2 + 15 files changed, 186 insertions(+), 110 deletions(-) delete mode 100644 server/reflector/processors/audio_diarization.py create mode 100644 server/reflector/processors/audio_diarization_auto.py create mode 100644 server/reflector/processors/audio_diarization_base.py create mode 100644 server/reflector/processors/audio_diarization_modal.py create mode 100644 server/reflector/tools/start_post_main_live_pipeline.py diff --git a/server/reflector/app.py b/server/reflector/app.py index 758faf69..c2e3bf7e 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -64,6 +64,9 @@ app.include_router(transcripts_router, prefix="/v1") app.include_router(user_router, prefix="/v1") add_pagination(app) +# prepare celery +from reflector.worker import app as celery_app # noqa + # simpler openapi id def use_route_names_as_operation_ids(app: FastAPI) -> None: diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 4159c889..87e2ff46 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -13,10 +13,12 @@ It is directly linked to our data model. import asyncio from contextlib import asynccontextmanager +from datetime import timedelta from pathlib import Path from celery import shared_task from pydantic import BaseModel +from reflector.app import app from reflector.db.transcripts import ( Transcript, TranscriptFinalLongSummary, @@ -29,7 +31,7 @@ from reflector.db.transcripts import ( from reflector.pipelines.runner import PipelineRunner from reflector.processors import ( AudioChunkerProcessor, - AudioDiarizationProcessor, + AudioDiarizationAutoProcessor, AudioFileWriterProcessor, AudioMergeProcessor, AudioTranscriptAutoProcessor, @@ -45,6 +47,7 @@ from reflector.processors import ( from reflector.processors.types import AudioDiarizationInput from reflector.processors.types import TitleSummary as TitleSummaryProcessorType from reflector.processors.types import Transcript as TranscriptProcessorType +from reflector.settings import settings from reflector.ws_manager import WebsocketManager, get_ws_manager @@ -174,7 +177,7 @@ class PipelineMainBase(PipelineRunner): async with self.transaction(): transcript = await self.get_transcript() if not transcript.title: - transcripts_controller.update( + await transcripts_controller.update( transcript, { "title": final_title.title, @@ -238,19 +241,13 @@ class PipelineMainLive(PipelineMainBase): AudioFileWriterProcessor(path=transcript.audio_mp3_filename), AudioChunkerProcessor(), AudioMergeProcessor(), - AudioTranscriptAutoProcessor.as_threaded(), + AudioTranscriptAutoProcessor.get_instance().as_threaded(), TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), BroadcastProcessor( processors=[ TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), - TranscriptFinalLongSummaryProcessor.as_threaded( - callback=self.on_long_summary - ), - TranscriptFinalShortSummaryProcessor.as_threaded( - callback=self.on_short_summary - ), ] ), ] @@ -277,7 +274,7 @@ class PipelineMainDiarization(PipelineMainBase): # add a customised logger to the context self.prepare() processors = [ - AudioDiarizationProcessor(), + AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic), BroadcastProcessor( processors=[ TranscriptFinalLongSummaryProcessor.as_threaded( @@ -307,8 +304,19 @@ class PipelineMainDiarization(PipelineMainBase): for topic in transcript.topics ] + # we need to create an url to be used for diarization + # we can't use the audio_mp3_filename because it's not accessible + # from the diarization processor + from reflector.views.transcripts import create_access_token + + token = create_access_token( + {"sub": transcript.user_id}, + expires_delta=timedelta(minutes=15), + ) + path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id) + url = f"{settings.BASE_URL}{path}?token={token}" audio_diarization_input = AudioDiarizationInput( - audio_filename=transcript.audio_mp3_filename, + audio_url=url, topics=topics, ) diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index 0575cf96..583cdcb6 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -55,7 +55,11 @@ class PipelineRunner(BaseModel): """ Start the pipeline synchronously (for non-asyncio apps) """ - asyncio.run(self.run()) + loop = asyncio.get_event_loop() + if not loop: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.run()) def push(self, data): """ diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py index 01a3a174..1c88d6c5 100644 --- a/server/reflector/processors/__init__.py +++ b/server/reflector/processors/__init__.py @@ -1,5 +1,5 @@ from .audio_chunker import AudioChunkerProcessor # noqa: F401 -from .audio_diarization import AudioDiarizationProcessor # noqa: F401 +from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401 from .audio_file_writer import AudioFileWriterProcessor # noqa: F401 from .audio_merge import AudioMergeProcessor # noqa: F401 from .audio_transcript import AudioTranscriptProcessor # noqa: F401 diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py deleted file mode 100644 index 8db8e8e5..00000000 --- a/server/reflector/processors/audio_diarization.py +++ /dev/null @@ -1,65 +0,0 @@ -from reflector.processors.base import Processor -from reflector.processors.types import AudioDiarizationInput, TitleSummary - - -class AudioDiarizationProcessor(Processor): - INPUT_TYPE = AudioDiarizationInput - OUTPUT_TYPE = TitleSummary - - async def _push(self, data: AudioDiarizationInput): - # Gather diarization data - diarization = [ - {"start": 0.0, "stop": 4.9, "speaker": 2}, - {"start": 5.6, "stop": 6.7, "speaker": 2}, - {"start": 7.3, "stop": 8.9, "speaker": 2}, - {"start": 7.3, "stop": 7.9, "speaker": 0}, - {"start": 9.4, "stop": 11.2, "speaker": 2}, - {"start": 9.7, "stop": 10.0, "speaker": 0}, - {"start": 10.0, "stop": 10.1, "speaker": 0}, - {"start": 11.7, "stop": 16.1, "speaker": 2}, - {"start": 11.8, "stop": 12.1, "speaker": 1}, - {"start": 16.4, "stop": 21.0, "speaker": 2}, - {"start": 21.1, "stop": 22.6, "speaker": 2}, - {"start": 24.7, "stop": 31.9, "speaker": 2}, - {"start": 32.0, "stop": 32.8, "speaker": 1}, - {"start": 33.4, "stop": 37.8, "speaker": 2}, - {"start": 37.9, "stop": 40.3, "speaker": 0}, - {"start": 39.2, "stop": 40.4, "speaker": 2}, - {"start": 40.7, "stop": 41.4, "speaker": 0}, - {"start": 41.6, "stop": 45.7, "speaker": 2}, - {"start": 46.4, "stop": 53.1, "speaker": 2}, - {"start": 53.6, "stop": 56.5, "speaker": 2}, - {"start": 54.9, "stop": 75.4, "speaker": 1}, - {"start": 57.3, "stop": 58.0, "speaker": 2}, - {"start": 65.7, "stop": 66.0, "speaker": 2}, - {"start": 75.8, "stop": 78.8, "speaker": 1}, - {"start": 79.0, "stop": 82.6, "speaker": 1}, - {"start": 83.2, "stop": 83.3, "speaker": 1}, - {"start": 84.5, "stop": 94.3, "speaker": 1}, - {"start": 95.1, "stop": 100.7, "speaker": 1}, - {"start": 100.7, "stop": 102.0, "speaker": 0}, - {"start": 100.7, "stop": 101.8, "speaker": 1}, - {"start": 102.0, "stop": 103.0, "speaker": 1}, - {"start": 103.0, "stop": 103.7, "speaker": 0}, - {"start": 103.7, "stop": 103.8, "speaker": 1}, - {"start": 103.8, "stop": 113.9, "speaker": 0}, - {"start": 114.7, "stop": 117.0, "speaker": 0}, - {"start": 117.0, "stop": 117.4, "speaker": 1}, - ] - - # now reapply speaker to topics (if any) - # topics is a list[BaseModel] with an attribute words - # words is a list[BaseModel] with text, start and speaker attribute - - print("IN DIARIZATION PROCESSOR", data) - - # mutate in place - for topic in data.topics: - for word in topic.transcript.words: - for d in diarization: - if d["start"] <= word.start <= d["stop"]: - word.speaker = d["speaker"] - - # emit them - for topic in data.topics: - await self.emit(topic) diff --git a/server/reflector/processors/audio_diarization_auto.py b/server/reflector/processors/audio_diarization_auto.py new file mode 100644 index 00000000..1de19b45 --- /dev/null +++ b/server/reflector/processors/audio_diarization_auto.py @@ -0,0 +1,34 @@ +import importlib + +from reflector.processors.base import Processor +from reflector.settings import settings + + +class AudioDiarizationAutoProcessor(Processor): + _registry = {} + + @classmethod + def register(cls, name, kclass): + cls._registry[name] = kclass + + @classmethod + def get_instance(cls, name: str | None = None, **kwargs): + if name is None: + name = settings.DIARIZATION_BACKEND + + if name not in cls._registry: + module_name = f"reflector.processors.audio_diarization_{name}" + importlib.import_module(module_name) + + # gather specific configuration for the processor + # search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy` + config = {} + name_upper = name.upper() + settings_prefix = "DIARIZATION_" + config_prefix = f"{settings_prefix}{name_upper}_" + for key, value in settings: + if key.startswith(config_prefix): + config_name = key[len(settings_prefix) :].lower() + config[config_name] = value + + return cls._registry[name](**config | kwargs) diff --git a/server/reflector/processors/audio_diarization_base.py b/server/reflector/processors/audio_diarization_base.py new file mode 100644 index 00000000..2ad7e4bf --- /dev/null +++ b/server/reflector/processors/audio_diarization_base.py @@ -0,0 +1,28 @@ +from reflector.processors.base import Processor +from reflector.processors.types import AudioDiarizationInput, TitleSummary + + +class AudioDiarizationBaseProcessor(Processor): + INPUT_TYPE = AudioDiarizationInput + OUTPUT_TYPE = TitleSummary + + async def _push(self, data: AudioDiarizationInput): + diarization = await self._diarize(data) + + # now reapply speaker to topics (if any) + # topics is a list[BaseModel] with an attribute words + # words is a list[BaseModel] with text, start and speaker attribute + + # mutate in place + for topic in data.topics: + for word in topic.transcript.words: + for d in diarization: + if d["start"] <= word.start <= d["end"]: + word.speaker = d["speaker"] + + # emit them + for topic in data.topics: + await self.emit(topic) + + async def _diarize(self, data: AudioDiarizationInput): + raise NotImplementedError diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py new file mode 100644 index 00000000..b71dbcc9 --- /dev/null +++ b/server/reflector/processors/audio_diarization_modal.py @@ -0,0 +1,36 @@ +import httpx +from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor +from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor +from reflector.processors.types import AudioDiarizationInput, TitleSummary +from reflector.settings import settings + + +class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor): + INPUT_TYPE = AudioDiarizationInput + OUTPUT_TYPE = TitleSummary + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.diarization_url = settings.DIARIZATION_URL + "/diarize" + self.headers = { + "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}", + } + + async def _diarize(self, data: AudioDiarizationInput): + # Gather diarization data + params = { + "audio_file_url": data.audio_url, + "timestamp": 0, + } + async with httpx.AsyncClient() as client: + response = await client.post( + self.diarization_url, + headers=self.headers, + params=params, + timeout=None, + ) + response.raise_for_status() + return response.json()["text"] + + +AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor) diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py index f223a52d..fc1f0b5e 100644 --- a/server/reflector/processors/audio_transcript_auto.py +++ b/server/reflector/processors/audio_transcript_auto.py @@ -1,8 +1,6 @@ import importlib from reflector.processors.audio_transcript import AudioTranscriptProcessor -from reflector.processors.base import Pipeline, Processor -from reflector.processors.types import AudioFile from reflector.settings import settings @@ -14,7 +12,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor): cls._registry[name] = kclass @classmethod - def get_instance(cls, name): + def get_instance(cls, name: str | None = None, **kwargs): + if name is None: + name = settings.TRANSCRIPT_BACKEND if name not in cls._registry: module_name = f"reflector.processors.audio_transcript_{name}" importlib.import_module(module_name) @@ -30,30 +30,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor): config_name = key[len(settings_prefix) :].lower() config[config_name] = value - return cls._registry[name](**config) - - def __init__(self, **kwargs): - self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND) - super().__init__(**kwargs) - - def set_pipeline(self, pipeline: Pipeline): - super().set_pipeline(pipeline) - self.processor.set_pipeline(pipeline) - - def connect(self, processor: Processor): - self.processor.connect(processor) - - def disconnect(self, processor: Processor): - self.processor.disconnect(processor) - - def on(self, callback): - self.processor.on(callback) - - def off(self, callback): - self.processor.off(callback) - - async def _push(self, data: AudioFile): - return await self.processor._push(data) - - async def _flush(self): - return await self.processor._flush() + return cls._registry[name](**config | kwargs) diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index 3ec21491..b67f84b9 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -385,5 +385,5 @@ class TranslationLanguages(BaseModel): class AudioDiarizationInput(BaseModel): - audio_filename: Path + audio_url: str topics: list[TitleSummary] diff --git a/server/reflector/settings.py b/server/reflector/settings.py index d7cc2c33..021d509f 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -89,6 +89,10 @@ class Settings(BaseSettings): # LLM Modal configuration LLM_MODAL_API_KEY: str | None = None + # Diarization + DIARIZATION_BACKEND: str = "modal" + DIARIZATION_URL: str | None = None + # Sentry SENTRY_DSN: str | None = None @@ -121,5 +125,11 @@ class Settings(BaseSettings): REDIS_HOST: str = "localhost" REDIS_PORT: int = 6379 + # Secret key + SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21" + + # Current hosting/domain + BASE_URL: str = "http://localhost:1250" + settings = Settings() diff --git a/server/reflector/tools/start_post_main_live_pipeline.py b/server/reflector/tools/start_post_main_live_pipeline.py new file mode 100644 index 00000000..859f03a4 --- /dev/null +++ b/server/reflector/tools/start_post_main_live_pipeline.py @@ -0,0 +1,14 @@ +import argparse + +from reflector.app import celery_app # noqa +from reflector.pipelines.main_live_pipeline import task_pipeline_main_post + +parser = argparse.ArgumentParser() +parser.add_argument("transcript_id", type=str) +parser.add_argument("--delay", action="store_true") +args = parser.parse_args() + +if args.delay: + task_pipeline_main_post.delay(args.transcript_id) +else: + task_pipeline_main_post(args.transcript_id) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 31cbe28e..f83bc6de 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timedelta from typing import Annotated, Optional import reflector.auth as auth @@ -9,8 +9,10 @@ from fastapi import ( Request, WebSocket, WebSocketDisconnect, + status, ) from fastapi_pagination import Page, paginate +from jose import jwt from pydantic import BaseModel, Field from reflector.db.transcripts import ( AudioWaveform, @@ -27,6 +29,18 @@ from .rtc_offer import RtcOffer, rtc_offer_base router = APIRouter() +ALGORITHM = "HS256" +DOWNLOAD_EXPIRE_MINUTES = 60 + + +def create_access_token(data: dict, expires_delta: timedelta): + to_encode = data.copy() + expire = datetime.utcnow() + expires_delta + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + # ============================================================== # Transcripts list # ============================================================== @@ -198,8 +212,21 @@ async def transcript_get_audio_mp3( request: Request, transcript_id: str, user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], + token: str | None = None, ): user_id = user["sub"] if user else None + if not user_id and token: + unauthorized_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + user_id: str = payload.get("sub") + except jwt.JWTError: + raise unauthorized_exception + transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py index 3714a64d..e1000364 100644 --- a/server/reflector/worker/app.py +++ b/server/reflector/worker/app.py @@ -4,6 +4,7 @@ from reflector.settings import settings app = Celery(__name__) app.conf.broker_url = settings.CELERY_BROKER_URL app.conf.result_backend = settings.CELERY_RESULT_BACKEND +app.conf.broker_connection_retry_on_startup = True app.autodiscover_tasks( [ "reflector.pipelines.main_live_pipeline", diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index e2bfee32..5a9a404b 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -102,6 +102,7 @@ async def test_transcript_rtc_and_websocket( print("Test websocket: DISCONNECTED") websocket_task = asyncio.get_event_loop().create_task(websocket_task()) + print("Test websocket: TASK CREATED", websocket_task) # create stream client import argparse @@ -243,6 +244,7 @@ async def test_transcript_rtc_and_websocket_and_fr( print("Test websocket: DISCONNECTED") websocket_task = asyncio.get_event_loop().create_task(websocket_task()) + print("Test websocket: TASK CREATED", websocket_task) # create stream client import argparse From e405ccb8f38c6033cce17119047cc77ae84a8eac Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 27 Oct 2023 20:08:00 +0200 Subject: [PATCH 17/41] server: started updating documentation --- README.md | 11 +++++++++-- docker-compose.yml | 14 +++++++++++--- server/docker-compose.yml | 28 ++++++++++++++++++---------- server/runserver.sh | 9 ++++++++- 4 files changed, 46 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 22651cc6..150c4575 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ The project architecture consists of three primary components: * **Front-End**: NextJS React project hosted on Vercel, located in `www/`. * **Back-End**: Python server that offers an API and data persistence, found in `server/`. -* **AI Models**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations. +* **GPU implementation**: Providing services such as speech-to-text transcription, topic generation, automated summaries, and translations. It also uses https://github.com/fief-dev for authentication, and Vercel for deployment and configuration of the front-end. @@ -120,6 +120,9 @@ TRANSCRIPT_MODAL_API_KEY= LLM_BACKEND=modal LLM_URL=https://monadical-sas--reflector-llm-web.modal.run LLM_MODAL_API_KEY= +TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run +ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run +DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run AUTH_BACKEND=fief AUTH_FIEF_URL=https://auth.reflector.media/reflector-local @@ -138,6 +141,10 @@ Use: poetry run python3 -m reflector.app ``` +And start the background worker + +celery -A reflector.worker.app worker --loglevel=info + #### Using docker Use: @@ -161,4 +168,4 @@ poetry run python -m reflector.tools.process path/to/audio.wav ## AI Models -*(Documentation for this section is pending.)* \ No newline at end of file +*(Documentation for this section is pending.)* diff --git a/docker-compose.yml b/docker-compose.yml index 934baaac..9e6519af 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,10 +5,19 @@ services: context: server ports: - 1250:1250 - environment: - LLM_URL: "${LLM_URL}" volumes: - model-cache:/root/.cache + environment: ENTRYPOINT=server + worker: + build: + context: server + volumes: + - model-cache:/root/.cache + environment: ENTRYPOINT=worker + redis: + image: redis:7.2 + ports: + - 6379:6379 web: build: context: www @@ -17,4 +26,3 @@ services: volumes: model-cache: - diff --git a/server/docker-compose.yml b/server/docker-compose.yml index 4e5a21e8..c8432816 100644 --- a/server/docker-compose.yml +++ b/server/docker-compose.yml @@ -1,15 +1,23 @@ version: "3.9" services: - # server: - # build: - # context: . - # ports: - # - 1250:1250 - # environment: - # LLM_URL: "${LLM_URL}" - # MIN_TRANSCRIPT_LENGTH: "${MIN_TRANSCRIPT_LENGTH}" - # volumes: - # - model-cache:/root/.cache + server: + build: + context: . + ports: + - 1250:1250 + volumes: + - model-cache:/root/.cache + environment: + ENTRYPOINT: server + REDIS_HOST: redis + worker: + build: + context: . + volumes: + - model-cache:/root/.cache + environment: + ENTRYPOINT: worker + REDIS_HOST: redis redis: image: redis:7.2 ports: diff --git a/server/runserver.sh b/server/runserver.sh index 38eafe09..b0c3f138 100755 --- a/server/runserver.sh +++ b/server/runserver.sh @@ -4,4 +4,11 @@ if [ -f "/venv/bin/activate" ]; then source /venv/bin/activate fi alembic upgrade head -python -m reflector.app + +if [ "${ENTRYPOINT}" = "server" ]; then + python -m reflector.app +elif [ "${ENTRYPOINT}" = "worker" ]; then + celery -A reflector.worker.app worker --loglevel=info +else + echo "Unknown command" +fi From d0057ae2c4cc29167cca12ff65ed0415d7e64ee0 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 1 Nov 2023 10:28:15 +0100 Subject: [PATCH 18/41] server: add missing python-jose --- server/poetry.lock | 67 ++++++++++++++++++++++++++++++++++++++++++- server/pyproject.toml | 1 + 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/server/poetry.lock b/server/poetry.lock index 8783625b..e72ade57 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1110,6 +1110,24 @@ idna = ["idna (>=2.1,<4.0)"] trio = ["trio (>=0.14,<0.23)"] wmi = ["wmi (>=1.5.1,<2.0.0)"] +[[package]] +name = "ecdsa" +version = "0.18.0" +description = "ECDSA cryptographic signature library (pure python)" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "ecdsa-0.18.0-py2.py3-none-any.whl", hash = "sha256:80600258e7ed2f16b9aa1d7c295bd70194109ad5a30fdee0eaeefef1d4c559dd"}, + {file = "ecdsa-0.18.0.tar.gz", hash = "sha256:190348041559e21b22a1d65cee485282ca11a6f81d503fddb84d5017e9ed1e49"}, +] + +[package.dependencies] +six = ">=1.9.0" + +[package.extras] +gmpy = ["gmpy"] +gmpy2 = ["gmpy2"] + [[package]] name = "fastapi" version = "0.100.1" @@ -2348,6 +2366,17 @@ files = [ {file = "protobuf-4.24.4.tar.gz", hash = "sha256:5a70731910cd9104762161719c3d883c960151eea077134458503723b60e3667"}, ] +[[package]] +name = "pyasn1" +version = "0.5.0" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"}, + {file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"}, +] + [[package]] name = "pycparser" version = "2.21" @@ -2754,6 +2783,28 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-jose" +version = "3.3.0" +description = "JOSE implementation in Python" +optional = false +python-versions = "*" +files = [ + {file = "python-jose-3.3.0.tar.gz", hash = "sha256:55779b5e6ad599c6336191246e95eb2293a9ddebd555f796a65f838f07e5d78a"}, + {file = "python_jose-3.3.0-py2.py3-none-any.whl", hash = "sha256:9b1376b023f8b298536eedd47ae1089bcdb848f1535ab30555cd92002d78923a"}, +] + +[package.dependencies] +cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"cryptography\""} +ecdsa = "!=0.15" +pyasn1 = "*" +rsa = "*" + +[package.extras] +cryptography = ["cryptography (>=3.4.0)"] +pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"] +pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] + [[package]] name = "pyyaml" version = "6.0.1" @@ -3069,6 +3120,20 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = false +python-versions = ">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + +[package.dependencies] +pyasn1 = ">=0.1.3" + [[package]] name = "s3transfer" version = "0.6.2" @@ -4078,4 +4143,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "07e42e7512fd5d51b656207a05092c53905c15e6a5ce548e015cdc05bd1baa7d" +content-hash = "cfefbd402bde7585caa42c1a889be0496d956e285bb05db9e1e7ae5e485e91fe" diff --git a/server/pyproject.toml b/server/pyproject.toml index c8614006..7681af39 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -35,6 +35,7 @@ protobuf = "^4.24.3" profanityfilter = "^2.0.6" celery = "^5.3.4" redis = "^5.0.1" +python-jose = {extras = ["cryptography"], version = "^3.3.0"} [tool.poetry.group.dev.dependencies] From 4da890b95fc2952e3ad8b679036e9a88632ea09a Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 1 Nov 2023 11:55:46 +0100 Subject: [PATCH 19/41] server: add dummy diarization and fixes instanciation --- .../reflector/pipelines/main_live_pipeline.py | 25 ++++++++++++++++--- server/reflector/pipelines/runner.py | 7 ++---- ...arization_base.py => audio_diarization.py} | 2 +- .../processors/audio_diarization_auto.py | 7 +++--- .../processors/audio_diarization_modal.py | 4 +-- .../processors/audio_transcript_auto.py | 3 +-- server/tests/conftest.py | 25 ++++++++++++++++++- server/tests/test_transcripts_rtc_ws.py | 2 ++ 8 files changed, 57 insertions(+), 18 deletions(-) rename server/reflector/processors/{audio_diarization_base.py => audio_diarization.py} (95%) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 87e2ff46..88e1bffd 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -28,6 +28,7 @@ from reflector.db.transcripts import ( TranscriptTopic, transcripts_controller, ) +from reflector.logger import logger from reflector.pipelines.runner import PipelineRunner from reflector.processors import ( AudioChunkerProcessor, @@ -241,7 +242,7 @@ class PipelineMainLive(PipelineMainBase): AudioFileWriterProcessor(path=transcript.audio_mp3_filename), AudioChunkerProcessor(), AudioMergeProcessor(), - AudioTranscriptAutoProcessor.get_instance().as_threaded(), + AudioTranscriptAutoProcessor.as_threaded(), TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), @@ -255,11 +256,18 @@ class PipelineMainLive(PipelineMainBase): pipeline.options = self pipeline.set_pref("audio:source_language", transcript.source_language) pipeline.set_pref("audio:target_language", transcript.target_language) + pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info( + "Pipeline main live created", + transcript_id=self.transcript_id, + ) return pipeline async def on_ended(self): # when the pipeline ends, connect to the post pipeline + logger.info("Pipeline main live ended", transcript_id=self.transcript_id) + logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id) task_pipeline_main_post.delay(transcript_id=self.transcript_id) @@ -274,7 +282,7 @@ class PipelineMainDiarization(PipelineMainBase): # add a customised logger to the context self.prepare() processors = [ - AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic), + AudioDiarizationAutoProcessor(callback=self.on_topic), BroadcastProcessor( processors=[ TranscriptFinalLongSummaryProcessor.as_threaded( @@ -313,7 +321,10 @@ class PipelineMainDiarization(PipelineMainBase): {"sub": transcript.user_id}, expires_delta=timedelta(minutes=15), ) - path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id) + path = app.url_path_for( + "transcript_get_audio_mp3", + transcript_id=transcript.id, + ) url = f"{settings.BASE_URL}{path}?token={token}" audio_diarization_input = AudioDiarizationInput( audio_url=url, @@ -322,6 +333,10 @@ class PipelineMainDiarization(PipelineMainBase): # as tempting to use pipeline.push, prefer to use the runner # to let the start just do one job. + pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info( + "Pipeline main post created", transcript_id=self.transcript_id + ) self.push(audio_diarization_input) self.flush() @@ -330,5 +345,9 @@ class PipelineMainDiarization(PipelineMainBase): @shared_task def task_pipeline_main_post(transcript_id: str): + logger.info( + "Starting main post pipeline", + transcript_id=transcript_id, + ) runner = PipelineMainDiarization(transcript_id=transcript_id) runner.start_sync() diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index 583cdcb6..a1e137a7 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -55,11 +55,8 @@ class PipelineRunner(BaseModel): """ Start the pipeline synchronously (for non-asyncio apps) """ - loop = asyncio.get_event_loop() - if not loop: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self.run()) + coro = self.run() + asyncio.run(coro) def push(self, data): """ diff --git a/server/reflector/processors/audio_diarization_base.py b/server/reflector/processors/audio_diarization.py similarity index 95% rename from server/reflector/processors/audio_diarization_base.py rename to server/reflector/processors/audio_diarization.py index 2ad7e4bf..d69f4b80 100644 --- a/server/reflector/processors/audio_diarization_base.py +++ b/server/reflector/processors/audio_diarization.py @@ -2,7 +2,7 @@ from reflector.processors.base import Processor from reflector.processors.types import AudioDiarizationInput, TitleSummary -class AudioDiarizationBaseProcessor(Processor): +class AudioDiarizationProcessor(Processor): INPUT_TYPE = AudioDiarizationInput OUTPUT_TYPE = TitleSummary diff --git a/server/reflector/processors/audio_diarization_auto.py b/server/reflector/processors/audio_diarization_auto.py index 1de19b45..0e7bfc5c 100644 --- a/server/reflector/processors/audio_diarization_auto.py +++ b/server/reflector/processors/audio_diarization_auto.py @@ -1,18 +1,17 @@ import importlib -from reflector.processors.base import Processor +from reflector.processors.audio_diarization import AudioDiarizationProcessor from reflector.settings import settings -class AudioDiarizationAutoProcessor(Processor): +class AudioDiarizationAutoProcessor(AudioDiarizationProcessor): _registry = {} @classmethod def register(cls, name, kclass): cls._registry[name] = kclass - @classmethod - def get_instance(cls, name: str | None = None, **kwargs): + def __new__(cls, name: str | None = None, **kwargs): if name is None: name = settings.DIARIZATION_BACKEND diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py index b71dbcc9..52be7c5d 100644 --- a/server/reflector/processors/audio_diarization_modal.py +++ b/server/reflector/processors/audio_diarization_modal.py @@ -1,11 +1,11 @@ import httpx +from reflector.processors.audio_diarization import AudioDiarizationProcessor from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor -from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor from reflector.processors.types import AudioDiarizationInput, TitleSummary from reflector.settings import settings -class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor): +class AudioDiarizationModalProcessor(AudioDiarizationProcessor): INPUT_TYPE = AudioDiarizationInput OUTPUT_TYPE = TitleSummary diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py index fc1f0b5e..ac79ced0 100644 --- a/server/reflector/processors/audio_transcript_auto.py +++ b/server/reflector/processors/audio_transcript_auto.py @@ -11,8 +11,7 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor): def register(cls, name, kclass): cls._registry[name] = kclass - @classmethod - def get_instance(cls, name: str | None = None, **kwargs): + def __new__(cls, name: str | None = None, **kwargs): if name is None: name = settings.TRANSCRIPT_BACKEND if name not in cls._registry: diff --git a/server/tests/conftest.py b/server/tests/conftest.py index d5f5f0b9..aafca9fd 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -60,12 +60,35 @@ async def dummy_transcript(): with patch( "reflector.processors.audio_transcript_auto" - ".AudioTranscriptAutoProcessor.get_instance" + ".AudioTranscriptAutoProcessor.__new__" ) as mock_audio: mock_audio.return_value = TestAudioTranscriptProcessor() yield +@pytest.fixture +async def dummy_diarization(): + from reflector.processors.audio_diarization import AudioDiarizationProcessor + + class TestAudioDiarizationProcessor(AudioDiarizationProcessor): + _time_idx = 0 + + async def _diarize(self, data): + i = self._time_idx + self._time_idx += 2 + return [ + {"start": i, "end": i + 1, "speaker": 0}, + {"start": i + 1, "end": i + 2, "speaker": 1}, + ] + + with patch( + "reflector.processors.audio_diarization_auto" + ".AudioDiarizationAutoProcessor.__new__" + ) as mock_audio: + mock_audio.return_value = TestAudioDiarizationProcessor() + yield + + @pytest.fixture async def dummy_llm(): from reflector.llm.base import LLM diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 5a9a404b..8f8cac71 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -65,6 +65,7 @@ async def test_transcript_rtc_and_websocket( dummy_llm, dummy_transcript, dummy_processors, + dummy_diarization, ensure_casing, appserver, sentence_tokenize, @@ -204,6 +205,7 @@ async def test_transcript_rtc_and_websocket_and_fr( dummy_llm, dummy_transcript, dummy_processors, + dummy_diarization, ensure_casing, appserver, sentence_tokenize, From 3e7031d031567d46d77e316a4f40e51cf8dcaf85 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 1 Nov 2023 11:58:15 +0100 Subject: [PATCH 20/41] server: do not remove empty or recording transcripts by default We should have the possibility to delete or hide them later --- server/reflector/db/transcripts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 61a2c380..89025d53 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -167,8 +167,8 @@ class TranscriptController: self, user_id: str | None = None, order_by: str | None = None, - filter_empty: bool | None = True, - filter_recording: bool | None = True, + filter_empty: bool | None = False, + filter_recording: bool | None = False, ) -> list[Transcript]: """ Get all transcripts From 7fca7ae287b9131f21a3bb06cfbf114fd3441dc6 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Wed, 1 Nov 2023 12:00:33 +0100 Subject: [PATCH 21/41] ci: add redis service required for celery --- .github/workflows/test_server.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/test_server.yml b/.github/workflows/test_server.yml index 71bed5d0..9f3b9a6a 100644 --- a/.github/workflows/test_server.yml +++ b/.github/workflows/test_server.yml @@ -11,6 +11,11 @@ on: jobs: pytest: runs-on: ubuntu-latest + services: + redis: + image: redis:6 + ports: + - 6379:6379 steps: - uses: actions/checkout@v3 - name: Install poetry From dbf3c9fd2cfb845a3161394a2460a8f73fb83c2c Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 16:40:48 +0100 Subject: [PATCH 22/41] www: fix path --- www/app/[domain]/transcripts/topicList.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx index ef4a2889..56b02e6e 100644 --- a/www/app/[domain]/transcripts/topicList.tsx +++ b/www/app/[domain]/transcripts/topicList.tsx @@ -7,7 +7,7 @@ import { import { formatTime } from "../../lib/time"; import ScrollToBottom from "./scrollToBottom"; import { Topic } from "./webSocketTypes"; -import { generateHighContrastColor } from "../lib/utils"; +import { generateHighContrastColor } from "../../lib/utils"; type TopicListProps = { topics: Topic[]; From f5ce3dd75eb48cd8ececd4a34aea8afc887271f8 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 16:41:06 +0100 Subject: [PATCH 23/41] www: remove auth changes (will be done by sara PR) --- .../transcripts/[transcriptId]/record/page.tsx | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx index 51a318a4..8e31327c 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx @@ -14,8 +14,6 @@ import DisconnectedIndicator from "../../disconnectedIndicator"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import { faGear } from "@fortawesome/free-solid-svg-icons"; import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock"; -import { featRequireLogin } from "../../../../../app/lib/utils"; -import { useFiefIsAuthenticated } from "@fief/fief/nextjs/react"; type TranscriptDetails = { params: { @@ -38,23 +36,16 @@ const TranscriptRecord = (details: TranscriptDetails) => { } }, []); - const isAuthenticated = useFiefIsAuthenticated(); const api = getApi(); - const [transcriptId, setTranscriptId] = useState(""); - const transcript = useTranscript(api, transcriptId); - const webRTC = useWebRTC(stream, transcriptId, api); - const webSockets = useWebSockets(transcriptId); + const transcript = useTranscript(api, details.params.transcriptId); + const webRTC = useWebRTC(stream, details.params.transcriptId, api); + const webSockets = useWebSockets(details.params.transcriptId); const { audioDevices, getAudioStream } = useAudioDevice(); const [hasRecorded, setHasRecorded] = useState(false); const [transcriptStarted, setTranscriptStarted] = useState(false); - useEffect(() => { - if (featRequireLogin() && !isAuthenticated) return; - setTranscriptId(details.params.transcriptId); - }, [api, details.params.transcriptId, isAuthenticated]); - useEffect(() => { if (!transcriptStarted && webSockets.transcriptText.length !== 0) setTranscriptStarted(true); From bba5643237141a711c7a47aa58a1149d2dc07311 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 16:43:10 +0100 Subject: [PATCH 24/41] www: update openapi --- www/app/api/apis/DefaultApi.ts | 61 +++------------------------------- 1 file changed, 5 insertions(+), 56 deletions(-) diff --git a/www/app/api/apis/DefaultApi.ts b/www/app/api/apis/DefaultApi.ts index d51d42ca..5bb2e7e9 100644 --- a/www/app/api/apis/DefaultApi.ts +++ b/www/app/api/apis/DefaultApi.ts @@ -42,10 +42,6 @@ import { UpdateTranscriptToJSON, } from "../models"; -export interface RtcOfferRequest { - rtcOffer: RtcOffer; -} - export interface V1TranscriptDeleteRequest { transcriptId: any; } @@ -56,6 +52,7 @@ export interface V1TranscriptGetRequest { export interface V1TranscriptGetAudioMp3Request { transcriptId: any; + token?: any; } export interface V1TranscriptGetAudioWaveformRequest { @@ -132,58 +129,6 @@ export class DefaultApi extends runtime.BaseAPI { return await response.value(); } - /** - * Rtc Offer - */ - async rtcOfferRaw( - requestParameters: RtcOfferRequest, - initOverrides?: RequestInit | runtime.InitOverrideFunction, - ): Promise> { - if ( - requestParameters.rtcOffer === null || - requestParameters.rtcOffer === undefined - ) { - throw new runtime.RequiredError( - "rtcOffer", - "Required parameter requestParameters.rtcOffer was null or undefined when calling rtcOffer.", - ); - } - - const queryParameters: any = {}; - - const headerParameters: runtime.HTTPHeaders = {}; - - headerParameters["Content-Type"] = "application/json"; - - const response = await this.request( - { - path: `/offer`, - method: "POST", - headers: headerParameters, - query: queryParameters, - body: RtcOfferToJSON(requestParameters.rtcOffer), - }, - initOverrides, - ); - - if (this.isJsonMime(response.headers.get("content-type"))) { - return new runtime.JSONApiResponse(response); - } else { - return new runtime.TextApiResponse(response) as any; - } - } - - /** - * Rtc Offer - */ - async rtcOffer( - requestParameters: RtcOfferRequest, - initOverrides?: RequestInit | runtime.InitOverrideFunction, - ): Promise { - const response = await this.rtcOfferRaw(requestParameters, initOverrides); - return await response.value(); - } - /** * Transcript Delete */ @@ -325,6 +270,10 @@ export class DefaultApi extends runtime.BaseAPI { const queryParameters: any = {}; + if (requestParameters.token !== undefined) { + queryParameters["token"] = requestParameters.token; + } + const headerParameters: runtime.HTTPHeaders = {}; if (this.configuration && this.configuration.accessToken) { From 19b5ba2c4c40c08f8cff57e71a7f62829436504b Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 16:53:57 +0100 Subject: [PATCH 25/41] server: add diarization logger information --- .../processors/audio_diarization_modal.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py index 52be7c5d..3c8d1b45 100644 --- a/server/reflector/processors/audio_diarization_modal.py +++ b/server/reflector/processors/audio_diarization_modal.py @@ -22,15 +22,21 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor): "audio_file_url": data.audio_url, "timestamp": 0, } + self.logger.info("Diarization started", audio_file_url=data.audio_url) async with httpx.AsyncClient() as client: - response = await client.post( - self.diarization_url, - headers=self.headers, - params=params, - timeout=None, - ) - response.raise_for_status() - return response.json()["text"] + try: + response = await client.post( + self.diarization_url, + headers=self.headers, + params=params, + timeout=None, + ) + response.raise_for_status() + self.logger.info("Diarization finished") + return response.json()["text"] + except Exception: + self.logger.exception("Diarization failed after retrying") + raise AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor) From 057c636c56c0529086f89811cada698c98251962 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 16:56:30 +0100 Subject: [PATCH 26/41] server: move logging to base implementation, not specialization --- .../reflector/processors/audio_diarization.py | 8 ++++++- .../processors/audio_diarization_modal.py | 22 +++++++------------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py index d69f4b80..82c6a553 100644 --- a/server/reflector/processors/audio_diarization.py +++ b/server/reflector/processors/audio_diarization.py @@ -7,7 +7,13 @@ class AudioDiarizationProcessor(Processor): OUTPUT_TYPE = TitleSummary async def _push(self, data: AudioDiarizationInput): - diarization = await self._diarize(data) + try: + self.logger.info("Diarization started", audio_file_url=data.audio_url) + diarization = await self._diarize(data) + self.logger.info("Diarization finished") + except Exception: + self.logger.exception("Diarization failed after retrying") + raise # now reapply speaker to topics (if any) # topics is a list[BaseModel] with an attribute words diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py index 3c8d1b45..52be7c5d 100644 --- a/server/reflector/processors/audio_diarization_modal.py +++ b/server/reflector/processors/audio_diarization_modal.py @@ -22,21 +22,15 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor): "audio_file_url": data.audio_url, "timestamp": 0, } - self.logger.info("Diarization started", audio_file_url=data.audio_url) async with httpx.AsyncClient() as client: - try: - response = await client.post( - self.diarization_url, - headers=self.headers, - params=params, - timeout=None, - ) - response.raise_for_status() - self.logger.info("Diarization finished") - return response.json()["text"] - except Exception: - self.logger.exception("Diarization failed after retrying") - raise + response = await client.post( + self.diarization_url, + headers=self.headers, + params=params, + timeout=None, + ) + response.raise_for_status() + return response.json()["text"] AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor) From 2e738e9f17dc8bec18aa9fae153808d5401ddf60 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 17:31:47 +0100 Subject: [PATCH 27/41] server: remove unused/old post_transcript.py --- server/reflector/worker/post_transcript.py | 170 --------------------- 1 file changed, 170 deletions(-) delete mode 100644 server/reflector/worker/post_transcript.py diff --git a/server/reflector/worker/post_transcript.py b/server/reflector/worker/post_transcript.py deleted file mode 100644 index 1cfbc664..00000000 --- a/server/reflector/worker/post_transcript.py +++ /dev/null @@ -1,170 +0,0 @@ -from reflector.logger import logger -from reflector.processors import ( - Pipeline, - Processor, - TranscriptFinalLongSummaryProcessor, - TranscriptFinalShortSummaryProcessor, - TranscriptFinalTitleProcessor, -) -from reflector.processors.base import BroadcastProcessor -from reflector.processors.types import ( - FinalLongSummary, - FinalShortSummary, - FinalTitle, - TitleSummary, -) -from reflector.processors.types import Transcript as ProcessorTranscript -from reflector.tasks.worker import celery -from reflector.views.rtc_offer import PipelineEvent, TranscriptionContext -from reflector.views.transcripts import Transcript, transcripts_controller - - -class TranscriptAudioDiarizationProcessor(Processor): - INPUT_TYPE = Transcript - OUTPUT_TYPE = TitleSummary - - async def _push(self, data: Transcript): - # Gather diarization data - diarization = [ - {"start": 0.0, "stop": 4.9, "speaker": 2}, - {"start": 5.6, "stop": 6.7, "speaker": 2}, - {"start": 7.3, "stop": 8.9, "speaker": 2}, - {"start": 7.3, "stop": 7.9, "speaker": 0}, - {"start": 9.4, "stop": 11.2, "speaker": 2}, - {"start": 9.7, "stop": 10.0, "speaker": 0}, - {"start": 10.0, "stop": 10.1, "speaker": 0}, - {"start": 11.7, "stop": 16.1, "speaker": 2}, - {"start": 11.8, "stop": 12.1, "speaker": 1}, - {"start": 16.4, "stop": 21.0, "speaker": 2}, - {"start": 21.1, "stop": 22.6, "speaker": 2}, - {"start": 24.7, "stop": 31.9, "speaker": 2}, - {"start": 32.0, "stop": 32.8, "speaker": 1}, - {"start": 33.4, "stop": 37.8, "speaker": 2}, - {"start": 37.9, "stop": 40.3, "speaker": 0}, - {"start": 39.2, "stop": 40.4, "speaker": 2}, - {"start": 40.7, "stop": 41.4, "speaker": 0}, - {"start": 41.6, "stop": 45.7, "speaker": 2}, - {"start": 46.4, "stop": 53.1, "speaker": 2}, - {"start": 53.6, "stop": 56.5, "speaker": 2}, - {"start": 54.9, "stop": 75.4, "speaker": 1}, - {"start": 57.3, "stop": 58.0, "speaker": 2}, - {"start": 65.7, "stop": 66.0, "speaker": 2}, - {"start": 75.8, "stop": 78.8, "speaker": 1}, - {"start": 79.0, "stop": 82.6, "speaker": 1}, - {"start": 83.2, "stop": 83.3, "speaker": 1}, - {"start": 84.5, "stop": 94.3, "speaker": 1}, - {"start": 95.1, "stop": 100.7, "speaker": 1}, - {"start": 100.7, "stop": 102.0, "speaker": 0}, - {"start": 100.7, "stop": 101.8, "speaker": 1}, - {"start": 102.0, "stop": 103.0, "speaker": 1}, - {"start": 103.0, "stop": 103.7, "speaker": 0}, - {"start": 103.7, "stop": 103.8, "speaker": 1}, - {"start": 103.8, "stop": 113.9, "speaker": 0}, - {"start": 114.7, "stop": 117.0, "speaker": 0}, - {"start": 117.0, "stop": 117.4, "speaker": 1}, - ] - - # now reapply speaker to topics (if any) - # topics is a list[BaseModel] with an attribute words - # words is a list[BaseModel] with text, start and speaker attribute - - # mutate in place - for topic in data.topics: - for word in topic.words: - for d in diarization: - if d["start"] <= word.start <= d["stop"]: - word.speaker = d["speaker"] - - topics = data.topics[:] - - await transcripts_controller.update( - data, - { - "topics": [topic.model_dump(mode="json") for topic in data.topics], - }, - ) - - # emit them - for topic in topics: - transcript = ProcessorTranscript(words=topic.words) - await self.emit( - TitleSummary( - title=topic.title, - summary=topic.summary, - timestamp=topic.timestamp, - duration=0, - transcript=transcript, - ) - ) - - -@celery.task(name="post_transcript") -async def post_transcript_pipeline(transcript_id: str): - # get transcript - transcript = await transcripts_controller.get_by_id(transcript_id) - if not transcript: - logger.error("Transcript not found", transcript_id=transcript_id) - return - - ctx = TranscriptionContext(logger=logger.bind(transcript_id=transcript_id)) - event_callback = None - event_callback_args = None - - async def on_final_short_summary(summary: FinalShortSummary): - ctx.logger.info("FinalShortSummary", final_short_summary=summary) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_SHORT_SUMMARY, - args=event_callback_args, - data=summary, - ) - - async def on_final_long_summary(summary: FinalLongSummary): - ctx.logger.info("FinalLongSummary", final_summary=summary) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_LONG_SUMMARY, - args=event_callback_args, - data=summary, - ) - - async def on_final_title(title: FinalTitle): - ctx.logger.info("FinalTitle", final_title=title) - - # send to callback (eg. websocket) - if event_callback: - await event_callback( - event=PipelineEvent.FINAL_TITLE, - args=event_callback_args, - data=title, - ) - - ctx.logger.info("Starting pipeline (diarization)") - ctx.pipeline = Pipeline( - TranscriptAudioDiarizationProcessor(), - BroadcastProcessor( - processors=[ - TranscriptFinalTitleProcessor.as_threaded(), - TranscriptFinalLongSummaryProcessor.as_threaded(), - TranscriptFinalShortSummaryProcessor.as_threaded(), - ] - ), - ) - - await ctx.pipeline.push(transcript) - await ctx.pipeline.flush() - - -if __name__ == "__main__": - import argparse - import asyncio - - parser = argparse.ArgumentParser() - parser.add_argument("transcript_id", type=str) - args = parser.parse_args() - - asyncio.run(post_transcript_pipeline(args.transcript_id)) From 239fae6189dbd7864bfbe46d3ab311d8324c7755 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 19:02:02 +0100 Subject: [PATCH 28/41] hotfix/server: add migration script to migrate transcript field to text --- server/migrations/versions/9920ecfe2735_.py | 81 +++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 server/migrations/versions/9920ecfe2735_.py diff --git a/server/migrations/versions/9920ecfe2735_.py b/server/migrations/versions/9920ecfe2735_.py new file mode 100644 index 00000000..da718af0 --- /dev/null +++ b/server/migrations/versions/9920ecfe2735_.py @@ -0,0 +1,81 @@ +"""empty message + +Revision ID: 9920ecfe2735 +Revises: 99365b0cd87b +Create Date: 2023-11-02 18:55:17.019498 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import json +from sqlalchemy.sql import table, column +from sqlalchemy import select + + +# revision identifiers, used by Alembic. +revision: str = "9920ecfe2735" +down_revision: Union[str, None] = "99365b0cd87b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # bind the engine + bind = op.get_bind() + + # Reflect the table + transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON)) + + # Select all rows from the transcript table + results = bind.execute(select([transcript.c.id, transcript.c.topics])) + + for row in results: + transcript_id = row["id"] + topics_json = row["topics"] + + # Process each topic in the topics JSON array + updated_topics = [] + for topic in topics_json: + if "transcript" in topic: + # Rename key 'transcript' to 'text' + topic["text"] = topic.pop("transcript") + updated_topics.append(topic) + + # Update the transcript table + bind.execute( + transcript.update() + .where(transcript.c.id == transcript_id) + .values(topics=updated_topics) + ) + + +def downgrade() -> None: + # bind the engine + bind = op.get_bind() + + # Reflect the table + transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON)) + + # Select all rows from the transcript table + results = bind.execute(select([transcript.c.id, transcript.c.topics])) + + for row in results: + transcript_id = row["id"] + topics_json = row["topics"] + + # Process each topic in the topics JSON array + updated_topics = [] + for topic in topics_json: + if "text" in topic: + # Rename key 'text' back to 'transcript' + topic["transcript"] = topic.pop("text") + updated_topics.append(topic) + + # Update the transcript table + bind.execute( + transcript.update() + .where(transcript.c.id == transcript_id) + .values(topics=updated_topics) + ) From 3424550ea9f477d78a7b789b2b98fd6d210453cd Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 19:06:30 +0100 Subject: [PATCH 29/41] hotfix/server: add id in GetTranscriptTopic for the frontend to work --- server/reflector/views/transcripts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index f83bc6de..f724bcdc 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -117,6 +117,7 @@ class GetTranscriptSegmentTopic(BaseModel): class GetTranscriptTopic(BaseModel): + id: str title: str summary: str timestamp: float @@ -149,6 +150,7 @@ class GetTranscriptTopic(BaseModel): for segment in transcript.as_segments() ] return cls( + id=topic.id, title=topic.title, summary=topic.summary, timestamp=topic.timestamp, From c87c30d33952709a19ff6444b4fb788893f790bb Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 19:09:13 +0100 Subject: [PATCH 30/41] hotfix/server: add follow_redirect on modal --- server/reflector/llm/llm_modal.py | 1 + server/reflector/processors/audio_diarization_modal.py | 1 + server/reflector/processors/audio_transcript_modal.py | 1 + server/reflector/processors/transcript_translator.py | 1 + 4 files changed, 4 insertions(+) diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py index 220730e5..4b81c5a0 100644 --- a/server/reflector/llm/llm_modal.py +++ b/server/reflector/llm/llm_modal.py @@ -47,6 +47,7 @@ class ModalLLM(LLM): json=json_payload, timeout=self.timeout, retry_timeout=60 * 5, + follow_redirects=True, ) response.raise_for_status() text = response.json()["text"] diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py index 52be7c5d..53de2501 100644 --- a/server/reflector/processors/audio_diarization_modal.py +++ b/server/reflector/processors/audio_diarization_modal.py @@ -28,6 +28,7 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor): headers=self.headers, params=params, timeout=None, + follow_redirects=True, ) response.raise_for_status() return response.json()["text"] diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py index 23c9d74e..0ca4710f 100644 --- a/server/reflector/processors/audio_transcript_modal.py +++ b/server/reflector/processors/audio_transcript_modal.py @@ -41,6 +41,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor): timeout=self.timeout, headers=self.headers, params=json_payload, + follow_redirects=True, ) self.logger.debug( diff --git a/server/reflector/processors/transcript_translator.py b/server/reflector/processors/transcript_translator.py index 77b8f5be..905ea423 100644 --- a/server/reflector/processors/transcript_translator.py +++ b/server/reflector/processors/transcript_translator.py @@ -50,6 +50,7 @@ class TranscriptTranslatorProcessor(Processor): headers=self.headers, params=json_payload, timeout=self.timeout, + follow_redirects=True, ) response.raise_for_status() result = response.json()["text"] From 37f6fe634544bbf38345d290afe21872ee7016ab Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 19:17:34 +0100 Subject: [PATCH 31/41] server: rename migration script for readability --- ...20ecfe2735_.py => 9920ecfe2735_rename_transcript_to_text.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename server/migrations/versions/{9920ecfe2735_.py => 9920ecfe2735_rename_transcript_to_text.py} (97%) diff --git a/server/migrations/versions/9920ecfe2735_.py b/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py similarity index 97% rename from server/migrations/versions/9920ecfe2735_.py rename to server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py index da718af0..90ff85ac 100644 --- a/server/migrations/versions/9920ecfe2735_.py +++ b/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py @@ -1,4 +1,4 @@ -"""empty message +"""Migration transcript to text field in transcripts table Revision ID: 9920ecfe2735 Revises: 99365b0cd87b From 9642d0fd1e201fc804d61a35cbeb0c5a5217e4ab Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 19:40:45 +0100 Subject: [PATCH 32/41] hotfix/server: fix duplication of topics --- server/reflector/db/transcripts.py | 6 +++--- server/reflector/pipelines/main_live_pipeline.py | 9 +++++++-- server/reflector/processors/types.py | 6 +++++- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 89025d53..5d190bdc 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -106,9 +106,9 @@ class Transcript(BaseModel): return ev def upsert_topic(self, topic: TranscriptTopic): - existing_topic = next((t for t in self.topics if t.id == topic.id), None) - if existing_topic: - existing_topic.update_from(topic) + index = next((i for i, t in enumerate(self.topics) if t.id == topic.id), None) + if index is not None: + self.topics[index] = topic else: self.topics.append(topic) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 88e1bffd..bf11bdf3 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -46,7 +46,9 @@ from reflector.processors import ( TranscriptTranslatorProcessor, ) from reflector.processors.types import AudioDiarizationInput -from reflector.processors.types import TitleSummary as TitleSummaryProcessorType +from reflector.processors.types import ( + TitleSummaryWithId as TitleSummaryWithIdProcessorType, +) from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.settings import settings from reflector.ws_manager import WebsocketManager, get_ws_manager @@ -163,6 +165,8 @@ class PipelineMainBase(PipelineRunner): text=data.transcript.text, words=data.transcript.words, ) + if isinstance(data, TitleSummaryWithIdProcessorType): + topic.id = data.id async with self.transaction(): transcript = await self.get_transcript() await transcripts_controller.upsert_topic(transcript, topic) @@ -302,7 +306,8 @@ class PipelineMainDiarization(PipelineMainBase): # XXX translation is lost when converting our data model to the processor model transcript = await self.get_transcript() topics = [ - TitleSummaryProcessorType( + TitleSummaryWithIdProcessorType( + id=topic.id, title=topic.title, summary=topic.summary, timestamp=topic.timestamp, diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index b67f84b9..312f5433 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -167,6 +167,10 @@ class TitleSummary(BaseModel): return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}" +class TitleSummaryWithId(TitleSummary): + id: str + + class FinalLongSummary(BaseModel): long_summary: str duration: float @@ -386,4 +390,4 @@ class TranslationLanguages(BaseModel): class AudioDiarizationInput(BaseModel): audio_url: str - topics: list[TitleSummary] + topics: list[TitleSummaryWithId] From eb76cd9bcd4540562ca3e658466904571a638406 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 19:54:44 +0100 Subject: [PATCH 33/41] server/www: rename topic text field to transcript This aleviate the current issue with vercel deployment --- ...27dcb099_rename_back_text_to_transcript.py | 80 +++++++++++++++++++ .../9920ecfe2735_rename_transcript_to_text.py | 1 - server/reflector/db/transcripts.py | 2 +- .../reflector/pipelines/main_live_pipeline.py | 2 +- server/reflector/views/transcripts.py | 4 +- server/tests/test_transcripts_rtc_ws.py | 4 +- www/app/[domain]/transcripts/topicList.tsx | 2 +- .../[domain]/transcripts/webSocketTypes.ts | 2 +- www/app/api/models/GetTranscriptTopic.ts | 17 +++- 9 files changed, 101 insertions(+), 13 deletions(-) create mode 100644 server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py diff --git a/server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py b/server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py new file mode 100644 index 00000000..dffe6fa1 --- /dev/null +++ b/server/migrations/versions/38a927dcb099_rename_back_text_to_transcript.py @@ -0,0 +1,80 @@ +"""rename back text to transcript + +Revision ID: 38a927dcb099 +Revises: 9920ecfe2735 +Create Date: 2023-11-02 19:53:09.116240 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, column +from sqlalchemy import select + + +# revision identifiers, used by Alembic. +revision: str = '38a927dcb099' +down_revision: Union[str, None] = '9920ecfe2735' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # bind the engine + bind = op.get_bind() + + # Reflect the table + transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON)) + + # Select all rows from the transcript table + results = bind.execute(select([transcript.c.id, transcript.c.topics])) + + for row in results: + transcript_id = row["id"] + topics_json = row["topics"] + + # Process each topic in the topics JSON array + updated_topics = [] + for topic in topics_json: + if "text" in topic: + # Rename key 'text' back to 'transcript' + topic["transcript"] = topic.pop("text") + updated_topics.append(topic) + + # Update the transcript table + bind.execute( + transcript.update() + .where(transcript.c.id == transcript_id) + .values(topics=updated_topics) + ) + + +def downgrade() -> None: + # bind the engine + bind = op.get_bind() + + # Reflect the table + transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON)) + + # Select all rows from the transcript table + results = bind.execute(select([transcript.c.id, transcript.c.topics])) + + for row in results: + transcript_id = row["id"] + topics_json = row["topics"] + + # Process each topic in the topics JSON array + updated_topics = [] + for topic in topics_json: + if "transcript" in topic: + # Rename key 'transcript' to 'text' + topic["text"] = topic.pop("transcript") + updated_topics.append(topic) + + # Update the transcript table + bind.execute( + transcript.update() + .where(transcript.c.id == transcript_id) + .values(topics=updated_topics) + ) diff --git a/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py b/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py index 90ff85ac..caecaefd 100644 --- a/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py +++ b/server/migrations/versions/9920ecfe2735_rename_transcript_to_text.py @@ -9,7 +9,6 @@ from typing import Sequence, Union from alembic import op import sqlalchemy as sa -import json from sqlalchemy.sql import table, column from sqlalchemy import select diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 5d190bdc..6ac2e32a 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -63,7 +63,7 @@ class TranscriptTopic(BaseModel): summary: str timestamp: float duration: float | None = 0 - text: str | None = None + transcript: str | None = None words: list[ProcessorWord] = [] diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index bf11bdf3..477b0ce9 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -162,7 +162,7 @@ class PipelineMainBase(PipelineRunner): title=data.title, summary=data.summary, timestamp=data.timestamp, - text=data.transcript.text, + transcript=data.transcript.text, words=data.transcript.words, ) if isinstance(data, TitleSummaryWithIdProcessorType): diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index f724bcdc..77c3e149 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -121,7 +121,7 @@ class GetTranscriptTopic(BaseModel): title: str summary: str timestamp: float - text: str + transcript: str segments: list[GetTranscriptSegmentTopic] = [] @classmethod @@ -154,7 +154,7 @@ class GetTranscriptTopic(BaseModel): title=topic.title, summary=topic.summary, timestamp=topic.timestamp, - text=text, + transcript=text, segments=segments, ) diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 8f8cac71..413c8b24 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -167,7 +167,7 @@ async def test_transcript_rtc_and_websocket( ev = events[eventnames.index("TOPIC")] assert ev["data"]["id"] assert ev["data"]["summary"] == "LLM SUMMARY" - assert ev["data"]["text"].startswith("Hello world.") + assert ev["data"]["transcript"].startswith("Hello world.") assert ev["data"]["timestamp"] == 0.0 assert "FINAL_LONG_SUMMARY" in eventnames @@ -316,7 +316,7 @@ async def test_transcript_rtc_and_websocket_and_fr( ev = events[eventnames.index("TOPIC")] assert ev["data"]["id"] assert ev["data"]["summary"] == "LLM SUMMARY" - assert ev["data"]["text"].startswith("Hello world.") + assert ev["data"]["transcript"].startswith("Hello world.") assert ev["data"]["timestamp"] == 0.0 assert "FINAL_LONG_SUMMARY" in eventnames diff --git a/www/app/[domain]/transcripts/topicList.tsx b/www/app/[domain]/transcripts/topicList.tsx index 56b02e6e..e7454f79 100644 --- a/www/app/[domain]/transcripts/topicList.tsx +++ b/www/app/[domain]/transcripts/topicList.tsx @@ -132,7 +132,7 @@ export function TopicList({ ))} ) : ( - <>{topic.text} + <>{topic.transcript} )} )} diff --git a/www/app/[domain]/transcripts/webSocketTypes.ts b/www/app/[domain]/transcripts/webSocketTypes.ts index abc67b33..112e7cc0 100644 --- a/www/app/[domain]/transcripts/webSocketTypes.ts +++ b/www/app/[domain]/transcripts/webSocketTypes.ts @@ -9,7 +9,7 @@ export type Topic = { title: string; summary: string; id: string; - text: string; + transcript: string; segments: SegmentTopic[]; }; diff --git a/www/app/api/models/GetTranscriptTopic.ts b/www/app/api/models/GetTranscriptTopic.ts index 7a7d4c90..460b8b39 100644 --- a/www/app/api/models/GetTranscriptTopic.ts +++ b/www/app/api/models/GetTranscriptTopic.ts @@ -19,6 +19,12 @@ import { exists, mapValues } from "../runtime"; * @interface GetTranscriptTopic */ export interface GetTranscriptTopic { + /** + * + * @type {any} + * @memberof GetTranscriptTopic + */ + id: any | null; /** * * @type {any} @@ -42,7 +48,7 @@ export interface GetTranscriptTopic { * @type {any} * @memberof GetTranscriptTopic */ - text: any | null; + transcript: any | null; /** * * @type {any} @@ -56,10 +62,11 @@ export interface GetTranscriptTopic { */ export function instanceOfGetTranscriptTopic(value: object): boolean { let isInstance = true; + isInstance = isInstance && "id" in value; isInstance = isInstance && "title" in value; isInstance = isInstance && "summary" in value; isInstance = isInstance && "timestamp" in value; - isInstance = isInstance && "text" in value; + isInstance = isInstance && "transcript" in value; return isInstance; } @@ -76,10 +83,11 @@ export function GetTranscriptTopicFromJSONTyped( return json; } return { + id: json["id"], title: json["title"], summary: json["summary"], timestamp: json["timestamp"], - text: json["text"], + transcript: json["transcript"], segments: !exists(json, "segments") ? undefined : json["segments"], }; } @@ -94,10 +102,11 @@ export function GetTranscriptTopicToJSON( return null; } return { + id: value.id, title: value.title, summary: value.summary, timestamp: value.timestamp, - text: value.text, + transcript: value.transcript, segments: value.segments, }; } From aee959369f3a79859fa2cf992f6c7a1178c13264 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 20:08:21 +0100 Subject: [PATCH 34/41] hotfix/server: correctly load old topic --- server/reflector/views/transcripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 77c3e149..7d79bf72 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -129,7 +129,7 @@ class GetTranscriptTopic(BaseModel): if not topic.words: # In previous version, words were missing # Just output a segment with speaker 0 - text = topic.text + text = topic.transcript segments = [ GetTranscriptSegmentTopic( text=topic.text, From 6f3e3741e788dbe0487d6e2a5e4f536428f2464b Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 20:12:57 +0100 Subject: [PATCH 35/41] hotfix/server: fix crash --- server/reflector/views/transcripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 7d79bf72..e3668ecb 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -132,7 +132,7 @@ class GetTranscriptTopic(BaseModel): text = topic.transcript segments = [ GetTranscriptSegmentTopic( - text=topic.text, + text=topic.transcript, start=topic.timestamp, speaker=0, ) From c5893e0391a1d7b6bcd7dd9c7b0dc745e03e8a24 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 20:34:53 +0100 Subject: [PATCH 36/41] hotfix/server: do not pass a token for diarization/mp3 if there is no user When decoding the token, if it is invalid (sub cannot be None), it just fail --- server/reflector/pipelines/main_live_pipeline.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 477b0ce9..b2bc51ea 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -322,15 +322,19 @@ class PipelineMainDiarization(PipelineMainBase): # from the diarization processor from reflector.views.transcripts import create_access_token - token = create_access_token( - {"sub": transcript.user_id}, - expires_delta=timedelta(minutes=15), - ) path = app.url_path_for( "transcript_get_audio_mp3", transcript_id=transcript.id, ) - url = f"{settings.BASE_URL}{path}?token={token}" + url = f"{settings.BASE_URL}{path}" + if transcript.user_id: + # we pass token only if the user_id is set + # otherwise, the audio is public + token = create_access_token( + {"sub": transcript.user_id}, + expires_delta=timedelta(minutes=15), + ) + url += f"?token={token}" audio_diarization_input = AudioDiarizationInput( audio_url=url, topics=topics, From b9149d6e68116a9525e3495332928b455e09b7b1 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 21:00:14 +0100 Subject: [PATCH 37/41] server: ensure retry works even with 303 redirection --- server/tests/test_retry_decorator.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/server/tests/test_retry_decorator.py b/server/tests/test_retry_decorator.py index 22729eac..c60a490f 100644 --- a/server/tests/test_retry_decorator.py +++ b/server/tests/test_retry_decorator.py @@ -1,3 +1,4 @@ +import asyncio import pytest import httpx from reflector.utils.retry import ( @@ -8,6 +9,31 @@ from reflector.utils.retry import ( ) +@pytest.mark.asyncio +async def test_retry_redirect(httpx_mock): + async def custom_response(request: httpx.Request): + if request.url.path == "/hello": + await asyncio.sleep(1) + return httpx.Response( + status_code=303, headers={"location": "https://test_url/redirected"} + ) + elif request.url.path == "/redirected": + return httpx.Response(status_code=200, json={"hello": "world"}) + else: + raise Exception("Unexpected path") + + httpx_mock.add_callback(custom_response) + async with httpx.AsyncClient() as client: + # timeout should not triggered, as it will end up ok + # even though the first request is a 303 and took more that 0.5 + resp = await retry(client.get)( + "https://test_url/hello", + retry_timeout=0.5, + follow_redirects=True, + ) + assert resp.json() == {"hello": "world"} + + @pytest.mark.asyncio async def test_retry_httpx(httpx_mock): # this code should be force a retry From da926b79a00042875868c73a8b2aacfc1a548434 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 3 Nov 2023 10:34:14 +0100 Subject: [PATCH 38/41] Update README.md - trigger vercel build after outage --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 150c4575..cb75c76b 100644 --- a/README.md +++ b/README.md @@ -169,3 +169,4 @@ poetry run python -m reflector.tools.process path/to/audio.wav ## AI Models *(Documentation for this section is pending.)* + From 5f773a1e82bdf1d4afdb81d3c0042cb63609e0ed Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 3 Nov 2023 10:36:16 +0100 Subject: [PATCH 39/41] Update forbidden.tsx - edit to trigger vercel build --- www/pages/forbidden.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/www/pages/forbidden.tsx b/www/pages/forbidden.tsx index 31a746fc..ada3d424 100644 --- a/www/pages/forbidden.tsx +++ b/www/pages/forbidden.tsx @@ -1,7 +1,7 @@ import type { NextPage } from "next"; const Forbidden: NextPage = () => { - return

Sorry, you are not authorized to access this page.

; + return

Sorry, you are not authorized to access this page

; }; export default Forbidden; From 22b1ce9fd208f3603c2923181aa6fb1a211dd455 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 3 Nov 2023 11:22:10 +0100 Subject: [PATCH 40/41] www: fix build with text->transcript and duplication of topics --- www/app/[domain]/transcripts/useWebSockets.ts | 38 +++++++++++++------ .../[domain]/transcripts/webSocketTypes.ts | 15 +------- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts index 8196749e..5610c2a4 100644 --- a/www/app/[domain]/transcripts/useWebSockets.ts +++ b/www/app/[domain]/transcripts/useWebSockets.ts @@ -56,7 +56,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { timestamp: 10, summary: "This is test topic 1", title: "Topic 1: Introduction to Quantum Mechanics", - text: "A brief overview of quantum mechanics and its principles.", + transcript: + "A brief overview of quantum mechanics and its principles.", segments: [ { speaker: 1, @@ -96,7 +97,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { timestamp: 20, summary: "This is test topic 2", title: "Topic 2: Machine Learning Algorithms", - text: "Understanding the different types of machine learning algorithms.", + transcript: + "Understanding the different types of machine learning algorithms.", segments: [ { speaker: 1, @@ -115,7 +117,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { timestamp: 30, summary: "This is test topic 3", title: "Topic 3: Mental Health Awareness", - text: "Ways to improve mental health and reduce stigma.", + transcript: "Ways to improve mental health and reduce stigma.", segments: [ { speaker: 1, @@ -134,7 +136,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { timestamp: 40, summary: "This is test topic 4", title: "Topic 4: Basics of Productivity", - text: "Tips and tricks to increase daily productivity.", + transcript: "Tips and tricks to increase daily productivity.", segments: [ { speaker: 1, @@ -153,7 +155,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { timestamp: 50, summary: "This is test topic 5", title: "Topic 5: Future of Aviation", - text: "Exploring the advancements and possibilities in aviation.", + transcript: + "Exploring the advancements and possibilities in aviation.", segments: [ { speaker: 1, @@ -182,7 +185,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 1", title: "Topic 1: Introduction to Quantum Mechanics, a brief overview of quantum mechanics and its principles.", - text: "A brief overview of quantum mechanics and its principles.", + transcript: + "A brief overview of quantum mechanics and its principles.", segments: [ { speaker: 1, @@ -202,7 +206,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 2", title: "Topic 2: Machine Learning Algorithms, understanding the different types of machine learning algorithms.", - text: "Understanding the different types of machine learning algorithms.", + transcript: + "Understanding the different types of machine learning algorithms.", segments: [ { speaker: 1, @@ -222,7 +227,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 3", title: "Topic 3: Mental Health Awareness, ways to improve mental health and reduce stigma.", - text: "Ways to improve mental health and reduce stigma.", + transcript: "Ways to improve mental health and reduce stigma.", segments: [ { speaker: 1, @@ -242,7 +247,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 4", title: "Topic 4: Basics of Productivity, tips and tricks to increase daily productivity.", - text: "Tips and tricks to increase daily productivity.", + transcript: "Tips and tricks to increase daily productivity.", segments: [ { speaker: 1, @@ -262,7 +267,8 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { summary: "This is test topic 5", title: "Topic 5: Future of Aviation, exploring the advancements and possibilities in aviation.", - text: "Exploring the advancements and possibilities in aviation.", + transcript: + "Exploring the advancements and possibilities in aviation.", segments: [ { speaker: 1, @@ -308,7 +314,17 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => { break; case "TOPIC": - setTopics((prevTopics) => [...prevTopics, message.data]); + setTopics((prevTopics) => { + const topic = message.data as Topic; + const index = prevTopics.findIndex( + (prevTopic) => prevTopic.id === topic.id, + ); + if (index >= 0) { + prevTopics[index] = topic; + return prevTopics; + } + return [...prevTopics, topic]; + }); console.debug("TOPIC event:", message.data); break; diff --git a/www/app/[domain]/transcripts/webSocketTypes.ts b/www/app/[domain]/transcripts/webSocketTypes.ts index 112e7cc0..edd35eb6 100644 --- a/www/app/[domain]/transcripts/webSocketTypes.ts +++ b/www/app/[domain]/transcripts/webSocketTypes.ts @@ -1,17 +1,6 @@ -export type SegmentTopic = { - speaker: number; - start: number; - text: string; -}; +import { GetTranscriptTopic } from "../../api"; -export type Topic = { - timestamp: number; - title: string; - summary: string; - id: string; - transcript: string; - segments: SegmentTopic[]; -}; +export type Topic = GetTranscriptTopic; export type Transcript = { text: string; From 907f4be67ab8f115603678fde3f3e7a064bc6f8c Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 2 Nov 2023 11:46:45 +0100 Subject: [PATCH 41/41] www: fix mp3 download while authenticated --- .../transcripts/[transcriptId]/page.tsx | 3 ++ www/app/[domain]/transcripts/recorder.tsx | 15 ++++-- www/app/[domain]/transcripts/useMp3.ts | 54 ++++++++++++++----- www/package.json | 2 +- www/yarn.lock | 8 +-- 5 files changed, 60 insertions(+), 22 deletions(-) diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index 3e30b97f..8e2184a0 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -4,6 +4,7 @@ import getApi from "../../../lib/getApi"; import useTranscript from "../useTranscript"; import useTopics from "../useTopics"; import useWaveform from "../useWaveform"; +import useMp3 from "../useMp3"; import { TopicList } from "../topicList"; import Recorder from "../recorder"; import { Topic } from "../webSocketTypes"; @@ -31,6 +32,7 @@ export default function TranscriptDetails(details: TranscriptDetails) { const waveform = useWaveform(api, transcriptId); const useActiveTopic = useState(null); const requireLogin = featureEnabled("requireLogin"); + const mp3 = useMp3(api, transcriptId); useEffect(() => { if (requireLogin && !isAuthenticated) return; @@ -70,6 +72,7 @@ export default function TranscriptDetails(details: TranscriptDetails) { waveform={waveform?.waveform} isPastMeeting={true} transcriptId={transcript?.response?.id} + mp3Blob={mp3.blob} /> )} diff --git a/www/app/[domain]/transcripts/recorder.tsx b/www/app/[domain]/transcripts/recorder.tsx index 401b6c9e..d50c90e3 100644 --- a/www/app/[domain]/transcripts/recorder.tsx +++ b/www/app/[domain]/transcripts/recorder.tsx @@ -29,6 +29,7 @@ type RecorderProps = { waveform?: AudioWaveform | null; isPastMeeting: boolean; transcriptId?: string | null; + mp3Blob?: Blob | null; }; export default function Recorder(props: RecorderProps) { @@ -107,11 +108,7 @@ export default function Recorder(props: RecorderProps) { if (waveformRef.current) { const _wavesurfer = WaveSurfer.create({ container: waveformRef.current, - url: props.transcriptId - ? `${process.env.NEXT_PUBLIC_API_URL}/v1/transcripts/${props.transcriptId}/audio/mp3` - : undefined, peaks: props.waveform?.data, - hideScrollbar: true, autoCenter: true, barWidth: 2, @@ -145,6 +142,10 @@ export default function Recorder(props: RecorderProps) { if (props.isPastMeeting) _wavesurfer.toggleInteraction(true); + if (props.mp3Blob) { + _wavesurfer.loadBlob(props.mp3Blob); + } + setWavesurfer(_wavesurfer); return () => { @@ -156,6 +157,12 @@ export default function Recorder(props: RecorderProps) { } }, []); + useEffect(() => { + if (!wavesurfer) return; + if (!props.mp3Blob) return; + wavesurfer.loadBlob(props.mp3Blob); + }, [props.mp3Blob]); + useEffect(() => { topicsRef.current = props.topics; if (!isRecording) renderMarkers(); diff --git a/www/app/[domain]/transcripts/useMp3.ts b/www/app/[domain]/transcripts/useMp3.ts index 8bccf903..b7677180 100644 --- a/www/app/[domain]/transcripts/useMp3.ts +++ b/www/app/[domain]/transcripts/useMp3.ts @@ -1,36 +1,64 @@ -import { useEffect, useState } from "react"; +import { useContext, useEffect, useState } from "react"; import { DefaultApi, - V1TranscriptGetAudioMp3Request, + // V1TranscriptGetAudioMp3Request, } from "../../api/apis/DefaultApi"; import {} from "../../api"; import { useError } from "../../(errors)/errorContext"; +import { DomainContext } from "../domainContext"; type Mp3Response = { url: string | null; + blob: Blob | null; loading: boolean; error: Error | null; }; const useMp3 = (api: DefaultApi, id: string): Mp3Response => { const [url, setUrl] = useState(null); + const [blob, setBlob] = useState(null); const [loading, setLoading] = useState(false); const [error, setErrorState] = useState(null); const { setError } = useError(); + const { api_url } = useContext(DomainContext); const getMp3 = (id: string) => { - if (!id) throw new Error("Transcript ID is required to get transcript Mp3"); + if (!id) return; setLoading(true); - const requestParameters: V1TranscriptGetAudioMp3Request = { - transcriptId: id, - }; - api - .v1TranscriptGetAudioMp3(requestParameters) - .then((result) => { - setUrl(result); - setLoading(false); - console.debug("Transcript Mp3 loaded:", result); + // XXX Current API interface does not output a blob, we need to to is manually + // const requestParameters: V1TranscriptGetAudioMp3Request = { + // transcriptId: id, + // }; + // api + // .v1TranscriptGetAudioMp3(requestParameters) + // .then((result) => { + // setUrl(result); + // setLoading(false); + // console.debug("Transcript Mp3 loaded:", result); + // }) + // .catch((err) => { + // setError(err); + // setErrorState(err); + // }); + const localUrl = `${api_url}/v1/transcripts/${id}/audio/mp3`; + if (localUrl == url) return; + const headers = new Headers(); + + if (api.configuration.configuration.accessToken) { + headers.set("Authorization", api.configuration.configuration.accessToken); + } + + fetch(localUrl, { + method: "GET", + headers, + }) + .then((response) => { + setUrl(localUrl); + response.blob().then((blob) => { + setBlob(blob); + setLoading(false); + }); }) .catch((err) => { setError(err); @@ -42,7 +70,7 @@ const useMp3 = (api: DefaultApi, id: string): Mp3Response => { getMp3(id); }, [id]); - return { url, loading, error }; + return { url, blob, loading, error }; }; export default useMp3; diff --git a/www/package.json b/www/package.json index edbc0790..55c7df73 100644 --- a/www/package.json +++ b/www/package.json @@ -35,7 +35,7 @@ "supports-color": "^9.4.0", "tailwindcss": "^3.3.2", "typescript": "^5.1.6", - "wavesurfer.js": "^7.0.3" + "wavesurfer.js": "^7.4.2" }, "main": "index.js", "repository": "https://github.com/Monadical-SAS/reflector-ui.git", diff --git a/www/yarn.lock b/www/yarn.lock index a67822be..8ec03382 100644 --- a/www/yarn.lock +++ b/www/yarn.lock @@ -2638,10 +2638,10 @@ watchpack@2.4.0: glob-to-regexp "^0.4.1" graceful-fs "^4.1.2" -wavesurfer.js@^7.0.3: - version "7.0.3" - resolved "https://registry.npmjs.org/wavesurfer.js/-/wavesurfer.js-7.0.3.tgz" - integrity sha512-gJ3P+Bd3Q4E8qETjjg0pneaVqm2J7jegG2Cc6vqEF5YDDKQ3m8sKsvVfgVhJkacKkO9jFAGDu58Hw4zLr7xD0A== +wavesurfer.js@^7.4.2: + version "7.4.2" + resolved "https://registry.yarnpkg.com/wavesurfer.js/-/wavesurfer.js-7.4.2.tgz#59f5c87193d4eeeb199858688ddac1ad7ba86b3a" + integrity sha512-4pNQ1porOCUBYBmd2F1TqVuBnB2wBPipaw2qI920zYLuPnada0Rd1CURgh8HRuPGKxijj2iyZDFN2UZwsaEuhA== wcwidth@>=1.0.1, wcwidth@^1.0.1: version "1.0.1"