mirror of https://github.com/Monadical-SAS/reflector.git

add profanity filter, post-process topic/title
server/poetry.lock (generated, 27 changed lines)
@@ -1551,6 +1551,17 @@ files = [
     {file = "ifaddr-0.2.0.tar.gz", hash = "sha256:cc0cbfcaabf765d44595825fb96a99bb12c79716b73b44330ea38ee2b0c4aed4"},
 ]
 
+[[package]]
+name = "inflection"
+version = "0.5.1"
+description = "A port of Ruby on Rails inflector to Python"
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "inflection-0.5.1-py2.py3-none-any.whl", hash = "sha256:f38b2b640938a4f35ade69ac3d053042959b62a0f1076a5bbaa1b9526605a8a2"},
+    {file = "inflection-0.5.1.tar.gz", hash = "sha256:1a29730d366e996aaacffb2f1f1cb9593dc38e2ddd30c91250c6dde09ea9b417"},
+]
+
 [[package]]
 name = "iniconfig"
 version = "2.0.0"
@@ -2097,6 +2108,20 @@ files = [
 dev = ["pre-commit", "tox"]
 testing = ["pytest", "pytest-benchmark"]
 
+[[package]]
+name = "profanityfilter"
+version = "2.0.6"
+description = "A universal Python library for detecting and/or filtering profane words."
+optional = false
+python-versions = "*"
+files = [
+    {file = "profanityfilter-2.0.6-py2.py3-none-any.whl", hash = "sha256:1706c080c2364f5bfe217b2330dc35d90e02e4afa0a00ed52d5673c410b45b64"},
+    {file = "profanityfilter-2.0.6.tar.gz", hash = "sha256:ca701e22799526696963415fc36d5e943c168f1917e3c83881ffda6bf5240a30"},
+]
+
+[package.dependencies]
+inflection = "*"
+
 [[package]]
 name = "prometheus-client"
 version = "0.17.1"
@@ -3744,4 +3769,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "b67b094055950a6b39da80dc7ca26b2a0e1c778f174016a00185d7219a3348b5"
+content-hash = "a85cb09a0e4b68b29c4272d550e618d2e24ace5f16b707f29e8ac4ce915c1fae"
server/pyproject.toml

@@ -32,6 +32,7 @@ transformers = "^4.32.1"
 prometheus-fastapi-instrumentator = "^6.1.0"
 sentencepiece = "^0.1.99"
 protobuf = "^4.24.3"
+profanityfilter = "^2.0.6"
 
 
 [tool.poetry.group.dev.dependencies]
@@ -6,11 +6,12 @@ from typing import TypeVar
 
 import nltk
 from prometheus_client import Counter, Histogram
+from transformers import GenerationConfig
 
 from reflector.llm.llm_params import TaskParams
 from reflector.logger import logger as reflector_logger
 from reflector.settings import settings
 from reflector.utils.retry import retry
-from transformers import GenerationConfig
+
 
 T = TypeVar("T", bound="LLM")
@@ -221,6 +222,30 @@ class LLM:
 
         return title
 
+    def trim_title(self, title: str) -> str:
+        """
+        Apply manual trimming to the title.
+
+        Longer titles often start with filler phrases such as
+        "Discussion on", "Discussion about", etc. that add no
+        descriptive information, and in some cases this behaviour
+        repeats across several consecutive topics. We want to
+        handle these cases.
+        """
+        phrases_to_remove = ["Discussion on", "Discussion about"]
+        try:
+            pattern = (
+                r"\b(?:"
+                + "|".join(re.escape(phrase) for phrase in phrases_to_remove)
+                + r")\b"
+            )
+            title = re.sub(pattern, "", title, flags=re.IGNORECASE)
+        except Exception as e:
+            reflector_logger.info(
+                f"Failed to trim {title=} with exception: {str(e)}"
+            )
+        return title
+
 async def _generate(
     self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
 ) -> str:
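
For reference, a minimal standalone sketch of the trimming logic added above (a hypothetical free-function variant, not part of the commit):

import re

# Filler phrases stripped from generated titles, as in LLM.trim_title above.
PHRASES_TO_REMOVE = ["Discussion on", "Discussion about"]

def trim_title(title: str) -> str:
    # Builds \b(?:Discussion on|Discussion about)\b, matched case-insensitively.
    pattern = (
        r"\b(?:"
        + "|".join(re.escape(phrase) for phrase in PHRASES_TO_REMOVE)
        + r")\b"
    )
    return re.sub(pattern, "", title, flags=re.IGNORECASE)

# The match is removed but surrounding whitespace is kept, so a leading
# space can survive:
print(trim_title("Discussion on quarterly budget"))  # " quarterly budget"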
server/reflector/llm/llm_params_cod.py (new file, 150 lines)
@@ -0,0 +1,150 @@
+from typing import Optional, TypeVar
+
+from pydantic import BaseModel
+from transformers import GenerationConfig
+
+
+class TaskParams(BaseModel, arbitrary_types_allowed=True):
+    instruct: str
+    gen_cfg: Optional[GenerationConfig] = None
+    gen_schema: Optional[dict] = None
+
+
+T = TypeVar("T", bound="LLMTaskParams")
+
+
+class LLMTaskParams:
+    _registry = {}
+
+    @classmethod
+    def register(cls, task, klass) -> None:
+        cls._registry[task] = klass
+
+    @classmethod
+    def get_instance(cls, task: str) -> T:
+        return cls._registry[task]()
+
+    @property
+    def task_params(self) -> TaskParams | None:
+        """
+        Fetch the task related parameters
+        """
+        return self._get_task_params()
+
+    def _get_task_params(self) -> None:
+        pass
+
+
+class FinalLongSummaryParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
+        )
+        self._instruct = """
+        Take the key ideas and takeaways from the text and create a short
+        summary. Be sure to keep the length of the response to a minimum.
+        Do not include trivial information in the summary.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {"long_summary": {"type": "string"}},
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+class FinalShortSummaryParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=1300, num_beams=3, do_sample=True, temperature=0.3
+        )
+        self._instruct = """
+        Take the key ideas and takeaways from the text and create a short
+        summary. Be sure to keep the length of the response to a minimum.
+        Do not include trivial information in the summary.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {"short_summary": {"type": "string"}},
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+class FinalTitleParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
+        )
+        self._instruct = """
+        Combine the following individual titles into one single short title that
+        condenses the essence of all titles.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {"title": {"type": "string"}},
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+class TopicParams(LLMTaskParams):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._gen_cfg = GenerationConfig(
+            max_new_tokens=550, num_beams=6, do_sample=True, temperature=0.9
+        )
+        self._instruct = """
+        Create a JSON object as response. The JSON object must have 2 fields:
+        i) title and ii) summary.
+        For the title field, generate a very detailed and self-explanatory
+        title for the given text. Let the title be as descriptive as possible.
+        For the summary field, summarize the given text in a maximum of
+        three sentences.
+        """
+        self._schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }
+        self._task_params = TaskParams(
+            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
+        )
+
+    def _get_task_params(self) -> TaskParams:
+        """
+        Return the parameters associated with a specific LLM task
+        """
+        return self._task_params
+
+
+LLMTaskParams.register("topic", TopicParams)
+LLMTaskParams.register("final_title", FinalTitleParams)
+LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
+LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
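
A hypothetical usage sketch of the registry above (not part of the commit): task names map to parameter classes, and get_instance plus the task_params property yield a fully populated TaskParams:

# Assumes the new module is importable at this path.
from reflector.llm.llm_params_cod import LLMTaskParams

params = LLMTaskParams.get_instance("topic").task_params
print(params.instruct)            # the topic prompt text
print(params.gen_schema)          # {"type": "object", "properties": {...}}
print(params.gen_cfg.num_beams)   # 6, per TopicParams above

Because register stores classes rather than instances, each get_instance call builds a fresh GenerationConfig, so callers can mutate the returned parameters without affecting other tasks.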
server/reflector/processors/audio_transcript.py

@@ -1,4 +1,6 @@
+from profanityfilter import ProfanityFilter
 from prometheus_client import Counter, Histogram
+
 from reflector.processors.base import Processor
 from reflector.processors.types import AudioFile, Transcript
 
@@ -38,6 +40,8 @@ class AudioTranscriptProcessor(Processor):
         self.m_transcript_call = self.m_transcript_call.labels(name)
         self.m_transcript_success = self.m_transcript_success.labels(name)
         self.m_transcript_failure = self.m_transcript_failure.labels(name)
+        self.profanity_filter = ProfanityFilter()
+        self.profanity_filter.set_censor("|*|")
         super().__init__(*args, **kwargs)
 
     async def _push(self, data: AudioFile):
@@ -56,3 +60,11 @@ class AudioTranscriptProcessor(Processor):
 
     async def _transcript(self, data: AudioFile):
         raise NotImplementedError
+
+    def filter_profanity(self, text: str) -> str:
+        """
+        Remove censored words from the transcript
+        """
+        text = self.profanity_filter.censor(text)
+        text = text.replace("|*|", "")
+        return text
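
The censor-then-delete step above removes profane words outright instead of masking them: each match is first replaced with the "|*|" sentinel, then every sentinel occurrence is stripped. A minimal sketch of the same idea (assuming the profanityfilter package; depending on the library version the sentinel may be repeated once per character of the censored word, which the replace handles either way):

from profanityfilter import ProfanityFilter

pf = ProfanityFilter()
pf.set_censor("|*|")  # sentinel unlikely to appear in real transcripts

def filter_profanity(text: str) -> str:
    # censor() masks each profane word with the sentinel; deleting every
    # sentinel occurrence then drops the word entirely.
    return pf.censor(text).replace("|*|", "")

The deleted word's surrounding spaces are left in place, so filtered text can contain doubled spaces.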
@@ -15,6 +15,7 @@ API will be a POST request to TRANSCRIPT_URL:
 from time import monotonic
 
 import httpx
 
+from reflector.processors.audio_transcript import AudioTranscriptProcessor
 from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
 from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
@@ -86,7 +87,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         if source_language != target_language and target_language in result["text"]:
             translation = result["text"][target_language]
         text = result["text"][source_language]
+        text = self.filter_profanity(text)
         transcript = Transcript(
             text=text,
             translation=translation,
@@ -60,6 +60,8 @@ class TranscriptFinalTitleProcessor(Processor):
 
         accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
         title_result = await self.get_title(accumulated_titles)
+        final_title = self.llm.ensure_casing(title_result["title"])
+        final_title = self.llm.trim_title(final_title)
 
-        final_title = FinalTitle(title=title_result["title"])
+        final_title = FinalTitle(title=final_title)
         await self.emit(final_title)
@@ -55,8 +55,11 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         topic_result = await self.get_topic(text=text)
 
+        title = self.llm.ensure_casing(topic_result["title"])
+        title = self.llm.trim_title(title)
+
         summary = TitleSummary(
-            title=self.llm.ensure_casing(topic_result["title"]),
+            title=title,
             summary=topic_result["summary"],
             timestamp=self.transcript.timestamp,
             duration=self.transcript.duration,
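
Both processors now share the same title post-processing chain: ensure_casing (an existing LLM method, not shown in this diff) followed by trim_title. A hypothetical sketch of the combined effect, with a stand-in for ensure_casing that merely capitalizes the first character:

def ensure_casing(title: str) -> str:
    # Stand-in only; the real method lives elsewhere in the LLM class.
    return title[:1].upper() + title[1:]

raw = "discussion about the release schedule"
title = trim_title(ensure_casing(raw))  # trim_title from the earlier sketch
print(repr(title))  # ' the release schedule': casing runs before trimming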