server: move out profanity filter to transcript, and implement segmentation

This commit is contained in:
2023-10-19 21:05:13 +02:00
committed by Mathieu Virbel
parent 0d9f66c097
commit b323254376
6 changed files with 78 additions and 19 deletions

View File

@@ -1,6 +1,4 @@
from profanityfilter import ProfanityFilter
from prometheus_client import Counter, Histogram from prometheus_client import Counter, Histogram
from reflector.processors.base import Processor from reflector.processors.base import Processor
from reflector.processors.types import AudioFile, Transcript from reflector.processors.types import AudioFile, Transcript
@@ -40,8 +38,6 @@ class AudioTranscriptProcessor(Processor):
self.m_transcript_call = self.m_transcript_call.labels(name) self.m_transcript_call = self.m_transcript_call.labels(name)
self.m_transcript_success = self.m_transcript_success.labels(name) self.m_transcript_success = self.m_transcript_success.labels(name)
self.m_transcript_failure = self.m_transcript_failure.labels(name) self.m_transcript_failure = self.m_transcript_failure.labels(name)
self.profanity_filter = ProfanityFilter()
self.profanity_filter.set_censor("*")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
async def _push(self, data: AudioFile): async def _push(self, data: AudioFile):
@@ -60,9 +56,3 @@ class AudioTranscriptProcessor(Processor):
async def _transcript(self, data: AudioFile): async def _transcript(self, data: AudioFile):
raise NotImplementedError raise NotImplementedError
def filter_profanity(self, text: str) -> str:
"""
Return ``text`` after passing it through the profanity filter.

Delegates to ``self.profanity_filter.censor``; the filter is created in
``__init__`` and configured there with ``"*"`` as its censor character.
"""
return self.profanity_filter.censor(text)

View File

@@ -48,10 +48,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
text = result["text"][source_language]
text = self.filter_profanity(text)
transcript = Transcript( transcript = Transcript(
text=text,
words=[ words=[
Word( Word(
text=word["text"], text=word["text"],

View File

@@ -30,7 +30,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
ts = data.timestamp ts = data.timestamp
for segment in segments: for segment in segments:
transcript.text += segment.text
for word in segment.words: for word in segment.words:
transcript.words.append( transcript.words.append(
Word( Word(

View File

@@ -36,7 +36,6 @@ class TranscriptLinerProcessor(Processor):
# cut to the next . # cut to the next .
partial = Transcript(words=[]) partial = Transcript(words=[])
for word in self.transcript.words[:]: for word in self.transcript.words[:]:
partial.text += word.text
partial.words.append(word) partial.words.append(word)
if not self.is_sentence_terminated(word.text): if not self.is_sentence_terminated(word.text):
continue continue

View File

@@ -2,8 +2,12 @@ import io
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from profanityfilter import ProfanityFilter
from pydantic import BaseModel, PrivateAttr from pydantic import BaseModel, PrivateAttr
profanity_filter = ProfanityFilter()
profanity_filter.set_censor("*")
class AudioFile(BaseModel): class AudioFile(BaseModel):
name: str name: str
@@ -43,13 +47,29 @@ class Word(BaseModel):
text: str text: str
start: float start: float
end: float end: float
speaker: int = 0
# A run of consecutive words grouped together by Transcript.as_segments().
# NOTE: documented with comments rather than a class docstring so the
# generated model schema is left untouched.
class TranscriptSegment(BaseModel):
# Concatenation of the `text` of every word placed in the segment.
text: str
# `start` of the first word of the segment (same unit as Word.start).
start: float
# Speaker id copied from the segment's first word; defaults to 0.
speaker: int = 0
class Transcript(BaseModel): class Transcript(BaseModel):
text: str = ""
translation: str | None = None translation: str | None = None
words: list[Word] = None words: list[Word] = None
@property
def raw_text(self):
    """Return the uncensored transcript text.

    The result is the plain concatenation of every word's ``text``, in
    order, with no separator added between words.

    ``words`` defaults to ``None`` on the model, so ``None`` is treated like
    an empty list: a transcript without words yields ``""`` instead of
    raising ``TypeError``.
    """
    return "".join(word.text for word in self.words or [])
@property
def text(self):
    """Return the transcript text with profanity censored.

    The uncensored ``raw_text`` is run through the module-level
    ``profanity_filter`` and surrounding whitespace is stripped from the
    result.
    """
    uncensored = self.raw_text
    censored = profanity_filter.censor(uncensored)
    return censored.strip()
@property @property
def human_timestamp(self): def human_timestamp(self):
minutes = int(self.timestamp / 60) minutes = int(self.timestamp / 60)
@@ -74,7 +94,6 @@ class Transcript(BaseModel):
self.words = other.words self.words = other.words
else: else:
self.words.extend(other.words) self.words.extend(other.words)
self.text += other.text
def add_offset(self, offset: float): def add_offset(self, offset: float):
for word in self.words: for word in self.words:
@@ -87,6 +106,48 @@ class Transcript(BaseModel):
] ]
return Transcript(text=self.text, translation=self.translation, words=words) return Transcript(text=self.text, translation=self.translation, words=words)
def as_segments(self):
    """Group the transcript's words into a list of ``TranscriptSegment``.

    Words are appended to the current segment until one of these happens:

    * the word's speaker differs from the segment's speaker;
    * the word starts more than ``BLANK_TIME_SECS`` after the previous
      word ended;
    * the word ends with sentence punctuation (``. ; : ? ! …``) and the
      segment text is already longer than ``MAX_SEGMENT_LENGTH``
      characters.

    In the first two cases the word opens a new segment. In the last case
    the word stays in the segment it terminates, and the following word
    opens a new one.

    Returns:
        The segments in word order; an empty list when there are no words
        (including when ``words`` is ``None``).
    """
    BLANK_TIME_SECS = 2
    MAX_SEGMENT_LENGTH = 80
    SENTENCE_PUNCTUATION = (".", ";", ":", "?", "!", "…")

    segments = []
    current_segment = None
    last_word = None

    for word in self.words or []:
        starts_new_segment = current_segment is None
        if not starts_new_segment:
            # last_word is always set here: it is recorded for every
            # processed word, including the very first one.
            is_blank = word.start - last_word.end > BLANK_TIME_SECS
            starts_new_segment = (
                is_blank or word.speaker != current_segment.speaker
            )

        if starts_new_segment:
            if current_segment is not None:
                segments.append(current_segment)
            current_segment = TranscriptSegment(
                text=word.text,
                start=word.start,
                speaker=word.speaker,
            )
        else:
            current_segment.text += word.text
        last_word = word

        # Word tokens usually carry their punctuation (e.g. " world."), so
        # look at how the token ends rather than testing whether the whole
        # token is a substring of the punctuation set.
        ends_sentence = word.text.rstrip().endswith(SENTENCE_PUNCTUATION)
        if ends_sentence and len(current_segment.text) > MAX_SEGMENT_LENGTH:
            segments.append(current_segment)
            current_segment = None

    if current_segment is not None:
        segments.append(current_segment)
    return segments
class TitleSummary(BaseModel): class TitleSummary(BaseModel):
title: str title: str

View File

@@ -49,12 +49,18 @@ class TranscriptText(BaseModel):
translation: str | None translation: str | None
# One transcript segment attached to a topic (see TranscriptTopic.segments).
# handle_rtc_event builds these from Transcript.as_segments().
# NOTE: documented with comments rather than a class docstring so the
# generated model schema is left untouched.
class TranscriptSegmentTopic(BaseModel):
# Speaker id, copied from the segment's `speaker`.
speaker: int
# Segment text, copied from the segment's `text`.
text: str
# Copied from the segment's `start`.
timestamp: float
class TranscriptTopic(BaseModel): class TranscriptTopic(BaseModel):
id: str = Field(default_factory=generate_uuid4) id: str = Field(default_factory=generate_uuid4)
title: str title: str
summary: str summary: str
transcript: str | None = None
timestamp: float timestamp: float
segments: list[TranscriptSegmentTopic] = []
class TranscriptFinalShortSummary(BaseModel): class TranscriptFinalShortSummary(BaseModel):
@@ -523,8 +529,15 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
topic = TranscriptTopic( topic = TranscriptTopic(
title=data.title, title=data.title,
summary=data.summary, summary=data.summary,
transcript=data.transcript.text,
timestamp=data.timestamp, timestamp=data.timestamp,
segments=[
TranscriptSegmentTopic(
speaker=segment.speaker,
text=segment.text,
timestamp=segment.start,
)
for segment in data.transcript.as_segments()
],
) )
resp = transcript.add_event(event=event, data=topic) resp = transcript.add_event(event=event, data=topic)
transcript.upsert_topic(topic) transcript.upsert_topic(topic)