diff --git a/server/poetry.lock b/server/poetry.lock
index d0fcd5b3..e608e3b0 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -1551,6 +1551,17 @@ files = [
     {file = "ifaddr-0.2.0.tar.gz", hash = "sha256:cc0cbfcaabf765d44595825fb96a99bb12c79716b73b44330ea38ee2b0c4aed4"},
 ]
 
+[[package]]
+name = "inflection"
+version = "0.5.1"
+description = "A port of Ruby on Rails inflector to Python"
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "inflection-0.5.1-py2.py3-none-any.whl", hash = "sha256:f38b2b640938a4f35ade69ac3d053042959b62a0f1076a5bbaa1b9526605a8a2"},
+    {file = "inflection-0.5.1.tar.gz", hash = "sha256:1a29730d366e996aaacffb2f1f1cb9593dc38e2ddd30c91250c6dde09ea9b417"},
+]
+
 [[package]]
 name = "iniconfig"
 version = "2.0.0"
@@ -2097,6 +2108,20 @@ files = [
 dev = ["pre-commit", "tox"]
 testing = ["pytest", "pytest-benchmark"]
 
+[[package]]
+name = "profanityfilter"
+version = "2.0.6"
+description = "A universal Python library for detecting and/or filtering profane words."
+optional = false
+python-versions = "*"
+files = [
+    {file = "profanityfilter-2.0.6-py2.py3-none-any.whl", hash = "sha256:1706c080c2364f5bfe217b2330dc35d90e02e4afa0a00ed52d5673c410b45b64"},
+    {file = "profanityfilter-2.0.6.tar.gz", hash = "sha256:ca701e22799526696963415fc36d5e943c168f1917e3c83881ffda6bf5240a30"},
+]
+
+[package.dependencies]
+inflection = "*"
+
 [[package]]
 name = "prometheus-client"
 version = "0.17.1"
@@ -3744,4 +3769,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "b67b094055950a6b39da80dc7ca26b2a0e1c778f174016a00185d7219a3348b5"
+content-hash = "a85cb09a0e4b68b29c4272d550e618d2e24ace5f16b707f29e8ac4ce915c1fae"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 9e7fb03a..ffe790f2 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -32,6 +32,7 @@ transformers = "^4.32.1"
 prometheus-fastapi-instrumentator = "^6.1.0"
 sentencepiece = "^0.1.99"
 protobuf = "^4.24.3"
+profanityfilter = "^2.0.6"
 
 
 [tool.poetry.group.dev.dependencies]
diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py
index 63cc1c50..9d6a8558 100644
--- a/server/reflector/llm/base.py
+++ b/server/reflector/llm/base.py
@@ -6,11 +6,12 @@ from typing import TypeVar
 
 import nltk
 from prometheus_client import Counter, Histogram
+from transformers import GenerationConfig
+
 from reflector.llm.llm_params import TaskParams
 from reflector.logger import logger as reflector_logger
 from reflector.settings import settings
 from reflector.utils.retry import retry
-from transformers import GenerationConfig
 
 T = TypeVar("T", bound="LLM")
 
@@ -221,6 +222,30 @@ class LLM:
 
         return title
 
+    def trim_title(self, title: str) -> str:
+        """
+        Apply manual trimming rules to the title.
+
+        Longer titles often begin with filler phrases such as
+        "Discussion on" or "Discussion about" that add no real
+        descriptive information, and in some cases the same phrase
+        repeats across several consecutive topics. Strip those
+        phrases here.
+        """
+        phrases_to_remove = ["Discussion on", "Discussion about"]
+        try:
+            pattern = (
+                r"\b(?:"
+                + "|".join(re.escape(phrase) for phrase in phrases_to_remove)
+                + r")\b"
+            )
+            title = re.sub(pattern, "", title, flags=re.IGNORECASE)
+        except Exception as e:
+            reflector_logger.info(
+                f"Failed to trim {title=} with exception: {e}"
+            )
+        return title
+
     async def _generate(
         self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
     ) -> str:
diff --git a/server/reflector/llm/llm_params_cod.py b/server/reflector/llm/llm_params_cod.py
new file mode 100644
index 00000000..59eea7c1
--- /dev/null
+++ b/server/reflector/llm/llm_params_cod.py
@@ -0,0 +1,150 @@
+from typing import Optional, TypeVar
+
+from pydantic import BaseModel
+from transformers import GenerationConfig
+
+
+class TaskParams(BaseModel, arbitrary_types_allowed=True):
+    instruct: str
+    gen_cfg: Optional[GenerationConfig] = None
+    gen_schema: Optional[dict] = None
+
+
+T = TypeVar("T", bound="LLMTaskParams")
+
+
+class LLMTaskParams:
+    _registry = {}
+
+    @classmethod
+    def register(cls, task, klass) -> None:
+        cls._registry[task] = klass
+
+    @classmethod
+    def get_instance(cls, task: str) -> T:
+        return cls._registry[task]()
+
+    @property
+    def task_params(self) -> TaskParams | None:
+        """
+        Fetch the task-related parameters
+        """
+        return self._get_task_params()
+
+    def _get_task_params(self) -> None:
+        pass
+
+
+class FinalLongSummaryParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
+        )
+        self._instruct = """
+        Take the key ideas and takeaways from the text and create a short
+        summary. Be sure to keep the length of the response to a minimum.
+        Do not include trivial information in the summary.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {"long_summary": {"type": "string"}},
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+class FinalShortSummaryParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=1300, num_beams=3, do_sample=True, temperature=0.3
+        )
+        self._instruct = """
+        Take the key ideas and takeaways from the text and create a short
+        summary. Be sure to keep the length of the response to a minimum.
+        Do not include trivial information in the summary.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {"short_summary": {"type": "string"}},
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+class FinalTitleParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
+        )
+        self._instruct = """
+        Combine the following individual titles into a single short title that
+        condenses the essence of all titles.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {"title": {"type": "string"}},
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+class TopicParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=550, num_beams=6, do_sample=True, temperature=0.9
+        )
+        self._instruct = """
+        Create a JSON object as a response. The JSON object must have 2 fields:
+        i) title and ii) summary.
+        For the title field, generate a very detailed and self-explanatory
+        title for the given text. Let the title be as descriptive as possible.
+        For the summary field, summarize the given text in a maximum of
+        three sentences.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+LLMTaskParams.register("topic", TopicParams)
+LLMTaskParams.register("final_title", FinalTitleParams)
+LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
+LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
diff --git a/server/reflector/processors/audio_transcript.py b/server/reflector/processors/audio_transcript.py
index 3f9dc85b..d1882105 100644
--- a/server/reflector/processors/audio_transcript.py
+++ b/server/reflector/processors/audio_transcript.py
@@ -1,4 +1,6 @@
+from profanityfilter import ProfanityFilter
 from prometheus_client import Counter, Histogram
+
 from reflector.processors.base import Processor
 from reflector.processors.types import AudioFile, Transcript
 
@@ -38,6 +40,8 @@
         self.m_transcript_call = self.m_transcript_call.labels(name)
         self.m_transcript_success = self.m_transcript_success.labels(name)
         self.m_transcript_failure = self.m_transcript_failure.labels(name)
+        self.profanity_filter = ProfanityFilter()
+        self.profanity_filter.set_censor("|*|")
         super().__init__(*args, **kwargs)
 
     async def _push(self, data: AudioFile):
@@ -56,3 +60,11 @@
 
     async def _transcript(self, data: AudioFile):
         raise NotImplementedError
+
+    def filter_profanity(self, text: str) -> str:
+        """
+        Remove profane words from the transcript text
+        """
+        text = self.profanity_filter.censor(text)
+        text = text.replace("|*|", "")
+        return text
diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py
index 2ecdc2ec..13e0bd44 100644
--- a/server/reflector/processors/audio_transcript_modal.py
+++ b/server/reflector/processors/audio_transcript_modal.py
@@ -15,6 +15,7 @@ API will be a POST request to TRANSCRIPT_URL:
 from time import monotonic
 
 import httpx
+
 from reflector.processors.audio_transcript import AudioTranscriptProcessor
 from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
 from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
@@ -86,7 +87,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         if source_language != target_language and target_language in result["text"]:
             translation = result["text"][target_language]
         text = result["text"][source_language]
-
+        text = self.filter_profanity(text)
         transcript = Transcript(
             text=text,
             translation=translation,
diff --git a/server/reflector/processors/transcript_final_title.py b/server/reflector/processors/transcript_final_title.py
index a3360d17..cc05337b 100644
--- a/server/reflector/processors/transcript_final_title.py
+++ b/server/reflector/processors/transcript_final_title.py
@@ -60,6 +60,8 @@ class TranscriptFinalTitleProcessor(Processor):
         accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
         title_result = await self.get_title(accumulated_titles)
 
-        final_title = FinalTitle(title=title_result["title"])
+        final_title = self.llm.ensure_casing(title_result["title"])
+        final_title = self.llm.trim_title(final_title)
+        final_title = FinalTitle(title=final_title)
 
         await self.emit(final_title)
diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py
index dfd2a432..3f7c7105 100644
--- a/server/reflector/processors/transcript_topic_detector.py
+++ b/server/reflector/processors/transcript_topic_detector.py
@@ -55,8 +55,11 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         topic_result = await self.get_topic(text=text)
 
+        title = self.llm.ensure_casing(topic_result["title"])
+        title = self.llm.trim_title(title)
+
         summary = TitleSummary(
-            title=self.llm.ensure_casing(topic_result["title"]),
+            title=title,
             summary=topic_result["summary"],
             timestamp=self.transcript.timestamp,
             duration=self.transcript.duration,
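To make the intent of the new helpers concrete, a few illustrative snippets follow. First, what the `LLM.trim_title` regex does to a generated title: the sample title is invented, and the snippet only reproduces the pattern from `base.py` rather than calling the class.

```python
import re

# Same phrase list and pattern construction as LLM.trim_title in base.py.
phrases_to_remove = ["Discussion on", "Discussion about"]
pattern = r"\b(?:" + "|".join(re.escape(p) for p in phrases_to_remove) + r")\b"

# Hypothetical title, for illustration only.
title = re.sub(pattern, "", "Discussion on Quarterly Roadmap", flags=re.IGNORECASE)
print(repr(title))  # ' Quarterly Roadmap' - the leading space is left in place
```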
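Next, a minimal sketch of how the `LLMTaskParams` registry in `llm_params_cod.py` is meant to be consumed. Only the registry API comes from the diff; the calling code here is hypothetical.

```python
from reflector.llm.llm_params_cod import LLMTaskParams

# Look up the registered parameter class for a task and instantiate it.
params = LLMTaskParams.get_instance("topic").task_params

params.instruct    # prompt text for the topic task
params.gen_schema  # JSON schema constraining the model output
params.gen_cfg     # GenerationConfig (max_new_tokens=550 for "topic")
```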
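Finally, a sketch of the censor-then-strip flow behind `AudioTranscriptProcessor.filter_profanity`. The input string is invented, and the exact output depends on the word list shipped with `profanityfilter`.

```python
from profanityfilter import ProfanityFilter

pf = ProfanityFilter()
pf.set_censor("|*|")  # same sentinel the processor registers in __init__

text = "an example transcript sentence"  # hypothetical transcript text
censored = pf.censor(text)               # profane words are rewritten with the "|*|" sentinel
cleaned = censored.replace("|*|", "")    # strip the sentinel so flagged words are removed
```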