diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py
index 82c6a553..69eab5b7 100644
--- a/server/reflector/processors/audio_diarization.py
+++ b/server/reflector/processors/audio_diarization.py
@@ -1,5 +1,5 @@
 from reflector.processors.base import Processor
-from reflector.processors.types import AudioDiarizationInput, TitleSummary
+from reflector.processors.types import AudioDiarizationInput, TitleSummary, Word
 
 
 class AudioDiarizationProcessor(Processor):
@@ -19,12 +19,12 @@ class AudioDiarizationProcessor(Processor):
         # topics is a list[BaseModel] with an attribute words
         # words is a list[BaseModel] with text, start and speaker attribute
 
-        # mutate in place
-        for topic in data.topics:
-            for word in topic.transcript.words:
-                for d in diarization:
-                    if d["start"] <= word.start <= d["end"]:
-                        word.speaker = d["speaker"]
+        # create a view of words based on topics
+        # the current algorithm is using words index, we cannot use a generator
+        words = list(self.iter_words_from_topics(data.topics))
+
+        # assign speaker to words (mutate the words list)
+        self.assign_speaker(words, diarization)
 
         # emit them
         for topic in data.topics:
@@ -32,3 +32,169 @@ class AudioDiarizationProcessor(Processor):
 
     async def _diarize(self, data: AudioDiarizationInput):
         raise NotImplementedError
+
+    def assign_speaker(self, words: list[Word], diarization: list[dict]):
+        """
+        Assign a speaker to every word from raw diarization segments
+
+        The diarization is cleaned up first (overlaps removed, segments
+        without words dropped, contiguous segments of the same speaker
+        merged), then the speakers are assigned to the words
+
+        Warning: this function mutates the words and the diarization list
+        """
+        self._diarization_remove_overlap(diarization)
+        self._diarization_remove_segment_without_words(words, diarization)
+        self._diarization_merge_same_speaker(words, diarization)
+        self._diarization_assign_speaker(words, diarization)
+
+    def iter_words_from_topics(self, topics: list[TitleSummary]):
+        """
+        Yield the words of every topic, in the order of the topics
+        """
+        for topic in topics:
+            yield from topic.transcript.words
+
+    def is_word_continuation(self, word_prev, word):
+        """
+        Return True if the word is a continuation of the previous word
+
+        The word is NOT a continuation (a new sentence is starting) if the
+        previous word is ending with a punctuation (".", "?" or "!"),
+        or if the current word is starting with a capital letter
+        """
+        # is word_prev ending with a punctuation ?
+        if word_prev.text and word_prev.text[-1] in ".?!":
+            return False
+        # is word starting with a capital letter ?
+        if word.text and word.text[0].isupper():
+            return False
+        return True
+
+    def _diarization_remove_overlap(self, diarization: list[dict]):
+        """
+        Remove overlap in diarization results
+
+        When using a diarization algorithm, it's possible to have overlapping
+        segments. This function removes the overlap between a segment and
+        the next one by keeping the longest of the two (if both have the
+        same duration, the next one is kept)
+
+        Only consecutive segments are compared: the list is presumably
+        ordered by start time (TODO confirm with the diarization backends)
+
+        Warning: this function mutates the diarization list
+        """
+        diarization_idx = 0
+        while diarization_idx < len(diarization) - 1:
+            d = diarization[diarization_idx]
+            dnext = diarization[diarization_idx + 1]
+            if d["end"] > dnext["start"]:
+                # remove the shortest segment
+                if d["end"] - d["start"] > dnext["end"] - dnext["start"]:
+                    # remove next segment
+                    diarization.pop(diarization_idx + 1)
+                else:
+                    # remove current segment
+                    diarization.pop(diarization_idx)
+            else:
+                diarization_idx += 1
+
+    def _diarization_remove_segment_without_words(
+        self, words: list[Word], diarization: list[dict]
+    ):
+        """
+        Remove diarization segments without words
+
+        A segment is kept if at least one word is starting inside of it
+        (start included, end excluded) or ending inside of it (start
+        excluded, end included)
+
+        Warning: this function mutates the diarization list
+        """
+
+        def has_words(d):
+            start = d["start"]
+            end = d["end"]
+            return any(
+                start <= word.start < end or start < word.end <= end
+                for word in words
+            )
+
+        # replace the content in place, the caller keeps using the same list
+        diarization[:] = [d for d in diarization if has_words(d)]
+
+    def _diarization_merge_same_speaker(
+        self, words: list[Word], diarization: list[dict]
+    ):
+        """
+        Merge diarization contiguous segments with the same speaker
+
+        The words parameter is not used by this step
+
+        Warning: this function mutates the diarization list
+        """
+        # merge segment with same speaker
+        diarization_idx = 0
+        while diarization_idx < len(diarization) - 1:
+            d = diarization[diarization_idx]
+            dnext = diarization[diarization_idx + 1]
+            if d["speaker"] == dnext["speaker"]:
+                diarization[diarization_idx]["end"] = dnext["end"]
+                diarization.pop(diarization_idx + 1)
+            else:
+                diarization_idx += 1
+
+    def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
+        """
+        Assign speaker to words based on diarization
+
+        - a word starting inside of a segment gets the speaker of the segment
+        - a word starting before a segment (hole in the diarization) gets the
+          last speaker if it is a continuation of the previous word, or the
+          speaker of the segment if it is a new sentence
+        - the words starting after the last segment get the last speaker
+
+        If there is no diarization at all, the words are left untouched
+
+        Warning: this function mutates the words list
+        """
+        word_idx = 0
+        word_count = len(words)
+        last_speaker = None
+        for d in diarization:
+            start = d["start"]
+            end = d["end"]
+            speaker = d["speaker"]
+
+            # diarization may start after the first set of words
+            # in this case, we assign the last speaker
+            while word_idx < word_count and words[word_idx].start < start:
+                word = words[word_idx]
+                # speaker change, but what make sense for assigning the word ?
+                # If it's a new sentence, assign with the new speaker
+                # If it's a continuation, assign with the last speaker
+                # (the first word has no previous word, it can't be a continuation)
+                if word_idx > 0 and self.is_word_continuation(
+                    words[word_idx - 1], word
+                ):
+                    word.speaker = last_speaker
+                else:
+                    word.speaker = speaker
+                    last_speaker = speaker
+                word_idx += 1
+
+            # now continue to assign speaker until the word starts after the end
+            while word_idx < word_count and words[word_idx].start < end:
+                words[word_idx].speaker = speaker
+                last_speaker = speaker
+                word_idx += 1
+
+        # no speaker found (no diarization), don't overwrite the words speaker
+        if last_speaker is None:
+            return
+
+        # no more diarization available,
+        # assign last speaker to all words without speaker
+        for word in words[word_idx:]:
+            word.speaker = last_speaker
diff --git a/server/tests/test_processor_audio_diarization.py b/server/tests/test_processor_audio_diarization.py
new file mode 100644
index 00000000..00935a49
--- /dev/null
+++ b/server/tests/test_processor_audio_diarization.py
@@ -0,0 +1,148 @@
+import pytest
+from unittest import mock
+
+
+@pytest.mark.parametrize(
+    "name,diarization,expected",
+    [
+        [
+            "no overlap",
+            [
+                {"start": 0.0, "end": 1.0, "speaker": "A"},
+                {"start": 1.0, "end": 2.0, "speaker": "B"},
+            ],
+            ["A", "A", "B", "B"],
+        ],
+        [
+            "same speaker",
+            [
+                {"start": 0.0, "end": 1.0, "speaker": "A"},
+                {"start": 1.0, "end": 2.0, "speaker": "A"},
+            ],
+            ["A", "A", "A", "A"],
+        ],
+        [
+            # first segment is removed because it overlap
+            # with the second segment, and it is smaller
+            "overlap at 0.5s",
+            [
+                {"start": 0.0, "end": 1.0, "speaker": "A"},
+                {"start": 0.5, "end": 2.0, "speaker": "B"},
+            ],
+            ["B", "B", "B", "B"],
+        ],
+        [
+            "junk segment at 0.5s for 0.2s",
+            [
+                {"start": 0.0, "end": 1.0, "speaker": "A"},
+                {"start": 0.5, "end": 0.7, "speaker": "B"},
+                {"start": 1, "end": 2.0, "speaker": "B"},
+            ],
+            ["A", "A", "B", "B"],
+        ],
+        [
+            "start without diarization",
+            [
+                {"start": 0.5, "end": 1.0, "speaker": "A"},
+                {"start": 1.0, "end": 2.0, "speaker": "B"},
+            ],
+            ["A", "A", "B", "B"],
+        ],
+        [
+            "end missing diarization",
+            [
+                {"start": 0.0, "end": 1.0, "speaker": "A"},
+                {"start": 1.0, "end": 1.5, "speaker": "B"},
+            ],
+            ["A", "A", "B", "B"],
+        ],
+        [
+            "continuation of next speaker",
+            [
+                {"start": 0.0, "end": 0.9, "speaker": "A"},
+                {"start": 1.5, "end": 2.0, "speaker": "B"},
+            ],
+            ["A", "A", "B", "B"],
+        ],
+        [
+            "continuation of previous speaker",
+            [
+                {"start": 0.0, "end": 0.5, "speaker": "A"},
+                {"start": 1.0, "end": 2.0, "speaker": "B"},
+            ],
+            ["A", "A", "B", "B"],
+        ],
+        [
+            "segment without words",
+            [
+                {"start": 0.0, "end": 1.0, "speaker": "A"},
+                {"start": 1.0, "end": 2.0, "speaker": "B"},
+                {"start": 2.0, "end": 3.0, "speaker": "X"},
+            ],
+            ["A", "A", "B", "B"],
+        ],
+        [
+            "last word continuation of previous speaker",
+            [
+                {"start": 0.0, "end": 1.5, "speaker": "A"},
+                {"start": 1.8, "end": 2.0, "speaker": "B"},
+            ],
+            ["A", "A", "A", "A"],
+        ],
+    ],
+)
+@pytest.mark.asyncio
+async def test_processors_audio_diarization(event_loop, name, diarization, expected):
+    from reflector.processors.audio_diarization import AudioDiarizationProcessor
+    from reflector.processors.types import (
+        TitleSummaryWithId,
+        Transcript,
+        Word,
+        AudioDiarizationInput,
+    )
+
+    # create fake topic
+    topics = [
+        TitleSummaryWithId(
+            id="1",
+            title="Title1",
+            summary="Summary1",
+            timestamp=0.0,
+            duration=1.0,
+            transcript=Transcript(
+                words=[
+                    Word(text="Word1", start=0.0, end=0.5),
+                    Word(text="word2.", start=0.5, end=1.0),
+                ]
+            ),
+        ),
+        TitleSummaryWithId(
+            id="2",
+            title="Title2",
+            summary="Summary2",
+            timestamp=0.0,
+            duration=1.0,
+            transcript=Transcript(
+                words=[
+                    Word(text="Word3", start=1.0, end=1.5),
+                    Word(text="word4.", start=1.5, end=2.0),
+                ]
+            ),
+        ),
+    ]
+
+    diarizer = AudioDiarizationProcessor()
+    with mock.patch.object(diarizer, "_diarize") as mock_diarize:
+        mock_diarize.return_value = diarization
+
+        data = AudioDiarizationInput(
+            audio_url="https://example.com/audio.mp3",
+            topics=topics,
+        )
+        await diarizer._push(data)
+
+        # check that the speaker has been assigned to the words
+        assert topics[0].transcript.words[0].speaker == expected[0]
+        assert topics[0].transcript.words[1].speaker == expected[1]
+        assert topics[1].transcript.words[0].speaker == expected[2]
+        assert topics[1].transcript.words[1].speaker == expected[3]