diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py index df24e58f..b662e05a 100644 --- a/server/gpu/modal/reflector_transcriber.py +++ b/server/gpu/modal/reflector_transcriber.py @@ -17,7 +17,7 @@ WHISPER_NUM_WORKERS: int = 1 # Translation Model TRANSLATION_MODEL = "facebook/m2m100_1.2B" -IMAGE_MODEL_DIR = "/root/transcription_models" +IMAGE_MODEL_DIR = f"/root/transcription_models/{TRANSLATION_MODEL}" stub = Stub(name="reflector-transcriber") diff --git a/server/poetry.lock b/server/poetry.lock index d0fcd5b3..e608e3b0 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1551,6 +1551,17 @@ files = [ {file = "ifaddr-0.2.0.tar.gz", hash = "sha256:cc0cbfcaabf765d44595825fb96a99bb12c79716b73b44330ea38ee2b0c4aed4"}, ] +[[package]] +name = "inflection" +version = "0.5.1" +description = "A port of Ruby on Rails inflector to Python" +optional = false +python-versions = ">=3.5" +files = [ + {file = "inflection-0.5.1-py2.py3-none-any.whl", hash = "sha256:f38b2b640938a4f35ade69ac3d053042959b62a0f1076a5bbaa1b9526605a8a2"}, + {file = "inflection-0.5.1.tar.gz", hash = "sha256:1a29730d366e996aaacffb2f1f1cb9593dc38e2ddd30c91250c6dde09ea9b417"}, +] + [[package]] name = "iniconfig" version = "2.0.0" @@ -2097,6 +2108,20 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "profanityfilter" +version = "2.0.6" +description = "A universal Python library for detecting and/or filtering profane words." +optional = false +python-versions = "*" +files = [ + {file = "profanityfilter-2.0.6-py2.py3-none-any.whl", hash = "sha256:1706c080c2364f5bfe217b2330dc35d90e02e4afa0a00ed52d5673c410b45b64"}, + {file = "profanityfilter-2.0.6.tar.gz", hash = "sha256:ca701e22799526696963415fc36d5e943c168f1917e3c83881ffda6bf5240a30"}, +] + +[package.dependencies] +inflection = "*" + [[package]] name = "prometheus-client" version = "0.17.1" @@ -3744,4 +3769,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "b67b094055950a6b39da80dc7ca26b2a0e1c778f174016a00185d7219a3348b5" +content-hash = "a85cb09a0e4b68b29c4272d550e618d2e24ace5f16b707f29e8ac4ce915c1fae" diff --git a/server/pyproject.toml b/server/pyproject.toml index 9e7fb03a..ffe790f2 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -32,6 +32,7 @@ transformers = "^4.32.1" prometheus-fastapi-instrumentator = "^6.1.0" sentencepiece = "^0.1.99" protobuf = "^4.24.3" +profanityfilter = "^2.0.6" [tool.poetry.group.dev.dependencies] diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index 63cc1c50..5eb2f15d 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -6,11 +6,12 @@ from typing import TypeVar import nltk from prometheus_client import Counter, Histogram +from transformers import GenerationConfig + from reflector.llm.llm_params import TaskParams from reflector.logger import logger as reflector_logger from reflector.settings import settings from reflector.utils.retry import retry -from transformers import GenerationConfig T = TypeVar("T", bound="LLM") @@ -214,6 +215,9 @@ class LLM: # Change ( ABC ), [ ABC ], etc. ==> (ABC), [ABC], etc. pattern = r"(?<=[\[\{\(])\s+|\s+(?=[\]\}\)])" title = re.sub(pattern, "", modified_title) + # Irrespective of casing changes, the starting letter + # of title is always upper-cased + title = title[0].upper() + title[1:] except Exception as e: reflector_logger.info( f"Failed to ensure casing on {title=} " f"with exception : {str(e)}" @@ -221,6 +225,29 @@ class LLM: return title + def trim_title(self, title: str) -> str: + """ + List of manual trimming to the title. + + Longer titles are prone to run into A prefix of phrases that don't + really add any descriptive information and in some cases, this + behaviour can be repeated for several consecutive topics. Trim the + titles to maintain quality of titles. + """ + phrases_to_remove = ["Discussing", "Discussion on", "Discussion about"] + try: + pattern = ( + r"\b(?:" + + "|".join(re.escape(phrase) for phrase in phrases_to_remove) + + r")\b" + ) + title = re.sub(pattern, "", title, flags=re.IGNORECASE) + except Exception as e: + reflector_logger.info( + f"Failed to trim {title=} " f"with exception : {str(e)}" + ) + return title + async def _generate( self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs ) -> str: diff --git a/server/reflector/llm/llm_params.py b/server/reflector/llm/llm_params.py index 59eea7c1..e43956cc 100644 --- a/server/reflector/llm/llm_params.py +++ b/server/reflector/llm/llm_params.py @@ -39,7 +39,7 @@ class FinalLongSummaryParams(LLMTaskParams): def __init__(self, **kwargs): super().__init__(**kwargs) self._gen_cfg = GenerationConfig( - max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3 + max_new_tokens=1000, num_beams=3, do_sample=True, temperature=0.3 ) self._instruct = """ Take the key ideas and takeaways from the text and create a short @@ -65,7 +65,7 @@ class FinalShortSummaryParams(LLMTaskParams): def __init__(self, **kwargs): super().__init__(**kwargs) self._gen_cfg = GenerationConfig( - max_new_tokens=1300, num_beams=3, do_sample=True, temperature=0.3 + max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3 ) self._instruct = """ Take the key ideas and takeaways from the text and create a short @@ -116,7 +116,7 @@ class TopicParams(LLMTaskParams): def __init__(self, **kwargs): super().__init__(**kwargs) self._gen_cfg = GenerationConfig( - max_new_tokens=550, num_beams=6, do_sample=True, temperature=0.9 + max_new_tokens=500, num_beams=6, do_sample=True, temperature=0.9 ) self._instruct = """ Create a JSON object as response.The JSON object must have 2 fields: diff --git a/server/reflector/processors/audio_transcript.py b/server/reflector/processors/audio_transcript.py index 3f9dc85b..f029b587 100644 --- a/server/reflector/processors/audio_transcript.py +++ b/server/reflector/processors/audio_transcript.py @@ -1,4 +1,6 @@ +from profanityfilter import ProfanityFilter from prometheus_client import Counter, Histogram + from reflector.processors.base import Processor from reflector.processors.types import AudioFile, Transcript @@ -38,6 +40,8 @@ class AudioTranscriptProcessor(Processor): self.m_transcript_call = self.m_transcript_call.labels(name) self.m_transcript_success = self.m_transcript_success.labels(name) self.m_transcript_failure = self.m_transcript_failure.labels(name) + self.profanity_filter = ProfanityFilter() + self.profanity_filter.set_censor("*") super().__init__(*args, **kwargs) async def _push(self, data: AudioFile): @@ -56,3 +60,9 @@ class AudioTranscriptProcessor(Processor): async def _transcript(self, data: AudioFile): raise NotImplementedError + + def filter_profanity(self, text: str) -> str: + """ + Remove censored words from the transcript + """ + return self.profanity_filter.censor(text) diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py index 2ecdc2ec..13e0bd44 100644 --- a/server/reflector/processors/audio_transcript_modal.py +++ b/server/reflector/processors/audio_transcript_modal.py @@ -15,6 +15,7 @@ API will be a POST request to TRANSCRIPT_URL: from time import monotonic import httpx + from reflector.processors.audio_transcript import AudioTranscriptProcessor from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word @@ -86,7 +87,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor): if source_language != target_language and target_language in result["text"]: translation = result["text"][target_language] text = result["text"][source_language] - + text = self.filter_profanity(text) transcript = Transcript( text=text, translation=translation, diff --git a/server/reflector/processors/transcript_final_title.py b/server/reflector/processors/transcript_final_title.py index a3360d17..0a8aead8 100644 --- a/server/reflector/processors/transcript_final_title.py +++ b/server/reflector/processors/transcript_final_title.py @@ -60,6 +60,8 @@ class TranscriptFinalTitleProcessor(Processor): accumulated_titles = ".".join([chunk.title for chunk in self.chunks]) title_result = await self.get_title(accumulated_titles) + final_title = self.llm.trim_title(title_result["title"]) + final_title = self.llm.ensure_casing(final_title) - final_title = FinalTitle(title=title_result["title"]) + final_title = FinalTitle(title=final_title) await self.emit(final_title) diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py index dfd2a432..43bf9762 100644 --- a/server/reflector/processors/transcript_topic_detector.py +++ b/server/reflector/processors/transcript_topic_detector.py @@ -54,9 +54,11 @@ class TranscriptTopicDetectorProcessor(Processor): text = self.transcript.text self.logger.info(f"Topic detector got {len(text)} length transcript") topic_result = await self.get_topic(text=text) + title = self.llm.trim_title(topic_result["title"]) + title = self.llm.ensure_casing(title) summary = TitleSummary( - title=self.llm.ensure_casing(topic_result["title"]), + title=title, summary=topic_result["summary"], timestamp=self.transcript.timestamp, duration=self.transcript.duration, diff --git a/server/tests/conftest.py b/server/tests/conftest.py index d0b3a26f..27417bcb 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -30,7 +30,7 @@ def dummy_processors(): "reflector.processors.transcript_final_short_summary.TranscriptFinalShortSummaryProcessor.get_short_summary" ) as mock_short_summary: mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"} - mock_title.return_value = {"title": "LLM FINAL TITLE"} + mock_title.return_value = {"title": "LLM TITLE"} mock_long_summary.return_value = "LLM LONG SUMMARY" mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"} diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index d6816192..d691dacb 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -185,7 +185,7 @@ async def test_transcript_rtc_and_websocket( assert "FINAL_TITLE" in eventnames ev = events[eventnames.index("FINAL_TITLE")] - assert ev["data"]["title"] == "LLM FINAL TITLE" + assert ev["data"]["title"] == "LLM TITLE" # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] @@ -330,7 +330,7 @@ async def test_transcript_rtc_and_websocket_and_fr( assert "FINAL_TITLE" in eventnames ev = events[eventnames.index("FINAL_TITLE")] - assert ev["data"]["title"] == "LLM FINAL TITLE" + assert ev["data"]["title"] == "LLM TITLE" # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]