Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-21 12:49:06 +00:00)
feat: new summary using phi-4 and llama-index (#519)
* feat: add litellm backend implementation
* refactor: improve generate/completion methods for base LLM
* refactor: remove tokenizer logic
* style: apply code formatting
* fix: remove hallucinations from LLM responses
* refactor: comprehensive LLM and summarization rework
* chore: remove debug code
* feat: add structured output support to LiteLLM
* refactor: apply self-review improvements
* docs: add model structured output comments
* docs: update model structured output comments
* style: apply linting and formatting fixes
* fix: resolve type logic bug
* refactor: apply PR review feedback
* refactor: apply additional PR review feedback
* refactor: apply final PR review feedback
* fix: improve schema passing for LLMs without structured output
* feat: add PR comments and logger improvements
* docs: update README and add HTTP logging
* feat: improve HTTP logging
* feat: add summary chunking functionality
* fix: resolve title generation runtime issues
* refactor: apply self-review improvements
* style: apply linting and formatting
* feat: implement LiteLLM class structure
* style: apply linting and formatting fixes
* docs: fix model name in env template
* chore: remove older litellm class
* chore: format
* refactor: simplify OpenAILLM
* refactor: OpenAILLM tokenizer
* refactor: self-review
* refactor: self-review
* refactor: self-review
* chore: format
* chore: remove LLM_USE_STRUCTURED_OUTPUT from envs
* chore: roll back migration lint changes
* chore: roll back migration lint changes
* fix: make summary llm configuration optional for the tests
* fix: add missing f-string
* fix: tweak the prompt for summary title
* feat: try llamaindex for summarization
* fix: complete refactor of summary builder using llamaindex and structured output when possible
* fix: separate prompt as constant
* fix: typings
* fix: enhance prompt to avoid mentioning other subjects while summarizing one
* fix: various changes after self-review
* fix: address feedback from Igor's review

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
@@ -51,7 +51,7 @@ async def main() -> NoReturn:

     logger.info(f"Cancelling {len(tasks)} outstanding tasks")
     await asyncio.gather(*tasks, return_exceptions=True)
-    logger.info(f'{"Flushing metrics"}')
+    logger.info(f"{'Flushing metrics'}")
     loop.stop()

     signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)

@@ -17,6 +17,7 @@ T = TypeVar("T", bound="LLM")
 class LLM:
     _nltk_downloaded = False
     _registry = {}
+    model_name: str
     m_generate = Histogram(
         "llm_generate",
         "Time spent in LLM.generate",

@@ -69,6 +70,7 @@ class LLM:
         module_name = f"reflector.llm.llm_{name}"
         importlib.import_module(module_name)
         cls.ensure_nltk()

         return cls._registry[name](model_name)

     def get_model_name(self) -> str:

@@ -121,6 +123,11 @@ class LLM:
     def _get_tokenizer(self):
         pass

+    def has_structured_output(self):
+        # whether implementation supports structured output
+        # on the model side (otherwise it's prompt engineering)
+        return False
+
     async def generate(
         self,
         prompt: str,

@@ -140,6 +147,7 @@ class LLM:
                 prompt=prompt,
                 gen_schema=gen_schema,
                 gen_cfg=gen_cfg,
+                logger=logger,
                 **kwargs,
             )
             self.m_generate_success.inc()

@@ -167,7 +175,9 @@ class LLM:

         try:
             with self.m_generate.time():
-                result = await retry(self._completion)(messages=messages, **kwargs)
+                result = await retry(self._completion)(
+                    messages=messages, **{**kwargs, "logger": logger}
+                )
             self.m_generate_success.inc()
         except Exception:
             logger.exception("Failed to call llm after retrying")

@@ -253,9 +263,7 @@ class LLM:
     ) -> str:
         raise NotImplementedError

-    async def _completion(
-        self, messages: list, logger: reflector_logger, **kwargs
-    ) -> dict:
+    async def _completion(self, messages: list, **kwargs) -> dict:
         raise NotImplementedError

     def _parse_json(self, result: str) -> dict:

@@ -31,7 +31,7 @@ class ModalLLM(LLM):

     async def _generate(
         self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
-    ):
+    ) -> str:
         json_payload = {"prompt": prompt}
         if gen_schema:
             json_payload["gen_schema"] = gen_schema

@@ -52,12 +52,14 @@ class ModalLLM(LLM):
             timeout=self.timeout,
             retry_timeout=60 * 5,
             follow_redirects=True,
+            logger=kwargs.get("logger", reflector_logger),
         )
         response.raise_for_status()
         text = response.json()["text"]
         return text

     async def _completion(self, messages: list, **kwargs) -> dict:
         # returns full api response
         kwargs.setdefault("temperature", 0.3)
         kwargs.setdefault("max_tokens", 2048)
         kwargs.setdefault("stream", False)

@@ -78,6 +80,7 @@ class ModalLLM(LLM):
             timeout=self.timeout,
             retry_timeout=60 * 5,
             follow_redirects=True,
+            logger=kwargs.get("logger", reflector_logger),
         )
         response.raise_for_status()
         return response.json()

server/reflector/llm/openai_llm.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import httpx
from transformers import AutoTokenizer
from reflector.logger import logger


def apply_gen_config(payload: dict, gen_cfg) -> None:
    """Apply generation config overrides to the payload."""
    config_mapping = {
        "temperature": "temperature",
        "max_new_tokens": "max_tokens",
        "max_tokens": "max_tokens",
        "top_p": "top_p",
        "frequency_penalty": "frequency_penalty",
        "presence_penalty": "presence_penalty",
    }

    for cfg_attr, payload_key in config_mapping.items():
        value = getattr(gen_cfg, cfg_attr, None)
        if value is not None:
            payload[payload_key] = value
            if cfg_attr == "max_new_tokens":  # Handle max_new_tokens taking precedence
                break


class OpenAILLM:
    def __init__(self, config_prefix: str, settings):
        self.config_prefix = config_prefix
        self.settings_obj = settings
        self.model_name = getattr(settings, f"{config_prefix}_MODEL")
        self.url = getattr(settings, f"{config_prefix}_LLM_URL")
        self.api_key = getattr(settings, f"{config_prefix}_LLM_API_KEY")

        timeout = getattr(settings, f"{config_prefix}_LLM_TIMEOUT", 300)
        self.temperature = getattr(settings, f"{config_prefix}_LLM_TEMPERATURE", 0.7)
        self.max_tokens = getattr(settings, f"{config_prefix}_LLM_MAX_TOKENS", 1024)
        self.client = httpx.AsyncClient(timeout=timeout)

        # Use a tokenizer that approximates OpenAI token counting
        tokenizer_name = getattr(settings, f"{config_prefix}_TOKENIZER", "gpt2")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        except Exception:
            logger.debug(
                f"Failed to load tokenizer '{tokenizer_name}', falling back to default 'gpt2' tokenizer"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    async def generate(
        self, prompt: str, gen_schema=None, gen_cfg=None, logger=None
    ) -> str:
        if logger:
            logger.debug(
                "OpenAI LLM generate",
                prompt=repr(prompt[:100] + "..." if len(prompt) > 100 else prompt),
            )

        messages = [{"role": "user", "content": prompt}]
        result = await self.completion(
            messages, gen_schema=gen_schema, gen_cfg=gen_cfg, logger=logger
        )
        return result["choices"][0]["message"]["content"]

    async def completion(
        self, messages: list, gen_schema=None, gen_cfg=None, logger=None, **kwargs
    ) -> dict:
        if logger:
            logger.info("OpenAI LLM completion", messages_count=len(messages))

        payload = {
            "model": self.model_name,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

        # Apply generation config overrides
        if gen_cfg:
            apply_gen_config(payload, gen_cfg)

        # Apply structured output schema
        if gen_schema:
            payload["response_format"] = {
                "type": "json_schema",
                "json_schema": {"name": "response", "schema": gen_schema},
            }

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        url = f"{self.url.rstrip('/')}/chat/completions"

        if logger:
            logger.debug(
                "OpenAI API request", url=url, payload_keys=list(payload.keys())
            )

        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()

        result = response.json()

        if logger:
            logger.debug(
                "OpenAI API response",
                status_code=response.status_code,
                choices_count=len(result.get("choices", [])),
            )

        return result

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.client.aclose()

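For orientation, here is a minimal usage sketch of the new OpenAILLM wrapper. It is illustrative only and not part of the commit: the prompt strings and the example schema are placeholders, while the constructor arguments, the generate()/completion() signatures, and the SUMMARY settings prefix come from the diff above and the Settings change below.

# Illustrative sketch, not from the repository: exercising OpenAILLM as added
# in this commit. Prompts and the JSON schema are placeholders.
import asyncio

from reflector.llm.openai_llm import OpenAILLM
from reflector.settings import settings


async def demo() -> None:
    # The async context manager closes the underlying httpx.AsyncClient on exit.
    async with OpenAILLM(config_prefix="SUMMARY", settings=settings) as llm:
        # Plain completion: generate() wraps the prompt in a single user message.
        text = await llm.generate("Summarize: the team agreed to ship on Friday.")

        # Structured output: gen_schema is forwarded as an OpenAI-style
        # json_schema response_format by completion().
        title_schema = {
            "type": "object",
            "properties": {"title": {"type": "string"}},
            "required": ["title"],
        }
        structured = await llm.generate(
            "Propose a short meeting title.", gen_schema=title_schema
        )
        print(text)
        print(structured)


if __name__ == "__main__":
    asyncio.run(demo())
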
@@ -16,7 +16,7 @@ import functools
 from contextlib import asynccontextmanager

 import boto3
-from celery import chord, group, shared_task
+from celery import chord, group, shared_task, current_task
 from pydantic import BaseModel
 from reflector.db.meetings import meeting_consent_controller, meetings_controller
 from reflector.db.recordings import recordings_controller

@@ -111,16 +111,29 @@ def get_transcript(func):
     Decorator to fetch the transcript from the database from the first argument
     """

     @functools.wraps(func)
     async def wrapper(**kwargs):
         transcript_id = kwargs.pop("transcript_id")
         transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
         if not transcript:
             raise Exception("Transcript {transcript_id} not found")

+        # Enhanced logger with Celery task context
         tlogger = logger.bind(transcript_id=transcript.id)
+        if current_task:
+            tlogger = tlogger.bind(
+                task_id=current_task.request.id,
+                task_name=current_task.name,
+                worker_hostname=current_task.request.hostname,
+                task_retries=current_task.request.retries,
+                transcript_id=transcript_id,
+            )

         try:
-            return await func(transcript=transcript, logger=tlogger, **kwargs)
+            result = await func(transcript=transcript, logger=tlogger, **kwargs)
+            return result
         except Exception as exc:
-            tlogger.error("Pipeline error", exc_info=exc)
+            tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc)
             raise

     return wrapper

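As an aside (not part of the diff), a hypothetical pipeline step shows what the decorator hands to the wrapped coroutine; only get_transcript, the injected transcript/logger keywords, and the transcript_id calling convention come from the code above.

# Hypothetical consumer of the get_transcript decorator; the function name and
# body are illustrative only.
@get_transcript
async def pipeline_example_step(transcript, logger, **kwargs):
    # transcript is the row loaded via transcripts_controller.get_by_id();
    # logger already carries transcript_id and, when available, Celery task context.
    logger.info("Example step running", transcript_id=transcript.id)


# Callers only supply the id; the decorator resolves the record:
# await pipeline_example_step(transcript_id="some-transcript-id")
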
File diff suppressed because it is too large
@@ -1,7 +1,8 @@
-from reflector.llm import LLM
+from reflector.llm.openai_llm import OpenAILLM
 from reflector.processors.base import Processor
 from reflector.processors.summary.summary_builder import SummaryBuilder
 from reflector.processors.types import FinalLongSummary, FinalShortSummary, TitleSummary
+from reflector.settings import settings


 class TranscriptFinalSummaryProcessor(Processor):

@@ -16,14 +17,14 @@ class TranscriptFinalSummaryProcessor(Processor):
         super().__init__(**kwargs)
         self.transcript = transcript
         self.chunks: list[TitleSummary] = []
-        self.llm = LLM.get_instance(model_name="NousResearch/Hermes-3-Llama-3.1-8B")
+        self.llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
         self.builder = None

     async def _push(self, data: TitleSummary):
         self.chunks.append(data)

     async def get_summary_builder(self, text) -> SummaryBuilder:
-        builder = SummaryBuilder(self.llm)
+        builder = SummaryBuilder(self.llm, logger=self.logger)
         builder.set_transcript(text)
         await builder.identify_participants()
         await builder.generate_summary()

@@ -49,7 +49,7 @@ class TranscriptFinalTitleProcessor(Processor):
                 gen_cfg=self.params.gen_cfg,
                 logger=self.logger,
             )
-            accumulated_titles += title_result["summary"]
+            accumulated_titles += title_result["title"]

         return await self.get_title(accumulated_titles)

@@ -52,6 +52,7 @@ class TranscriptTranslatorProcessor(Processor):
             params=json_payload,
             timeout=self.timeout,
             follow_redirects=True,
+            logger=self.logger,
         )
         response.raise_for_status()
         result = response.json()["text"]

@@ -82,6 +82,12 @@ class Settings(BaseSettings):
     # LLM Modal configuration
     LLM_MODAL_API_KEY: str | None = None

+    # per-task cases
+    SUMMARY_MODEL: str = "monadical/private/smart"
+    SUMMARY_LLM_URL: str | None = None
+    SUMMARY_LLM_API_KEY: str | None = None
+    SUMMARY_LLM_CONTEXT_SIZE_TOKENS: int = 16000
+
     # Diarization
     DIARIZATION_ENABLED: bool = True
     DIARIZATION_BACKEND: str = "modal"

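For illustration only: since Settings is a pydantic BaseSettings class, the new per-task fields can be driven by environment variables. Every value below is a placeholder (microsoft/phi-4 is merely suggested by the commit title), and importing the Settings class directly from reflector.settings is an assumption.

# Hypothetical environment-based configuration of the new SUMMARY_* fields;
# only the field names come from the diff, every value is a placeholder.
import os

from reflector.settings import Settings  # assumed importable alongside `settings`

os.environ["SUMMARY_MODEL"] = "microsoft/phi-4"
os.environ["SUMMARY_LLM_URL"] = "https://llm.example.invalid/v1"
os.environ["SUMMARY_LLM_API_KEY"] = "sk-placeholder"
os.environ["SUMMARY_LLM_CONTEXT_SIZE_TOKENS"] = "16000"

settings = Settings()
print(settings.SUMMARY_MODEL, settings.SUMMARY_LLM_CONTEXT_SIZE_TOKENS)
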
@@ -126,7 +126,7 @@ class StreamClient:
         answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
         await pc.setRemoteDescription(answer)

-        self.reader = self.worker(f'{"worker"}', self.queue)
+        self.reader = self.worker(f"{'worker'}", self.queue)

     def get_reader(self):
         return self.reader

@@ -36,9 +36,13 @@ async def export_db(filename: str) -> None:
             if entry["event"] == "TRANSCRIPT":
                 yield tid, "event_transcript", idx, "text", entry["data"]["text"]
                 if entry["data"].get("translation") is not None:
-                    yield tid, "event_transcript", idx, "translation", entry[
-                        "data"
-                    ].get("translation", None)
+                    yield (
+                        tid,
+                        "event_transcript",
+                        idx,
+                        "translation",
+                        entry["data"].get("translation", None),
+                    )

 def export_transcripts(transcripts):
     for transcript in transcripts:

@@ -34,6 +34,7 @@ def retry(fn):
             ),
         )
         retry_ignore_exc_types = kwargs.pop("retry_ignore_exc_types", (Exception,))
+        retry_logger = kwargs.pop("logger", logger)

         result = None
         last_exception = None

@@ -58,17 +59,33 @@ def retry(fn):
                 if result:
                     return result
             except HTTPStatusError as e:
-                logger.exception(e)
+                retry_logger.exception(e)
                 status_code = e.response.status_code
-                logger.debug(f"HTTP status {status_code} - {e}")

+                # Log detailed error information including response body
+                try:
+                    response_text = e.response.text
+                    response_headers = dict(e.response.headers)
+                    retry_logger.error(
+                        f"HTTP {status_code} error for {e.request.method} {e.request.url}\n"
+                        f"Response headers: {response_headers}\n"
+                        f"Response body: {response_text}"
+                    )
+
+                except Exception as log_error:
+                    retry_logger.warning(
+                        f"Failed to log detailed error info: {log_error}"
+                    )
+                retry_logger.debug(f"HTTP status {status_code} - {e}")

                 if status_code in retry_httpx_status_stop:
                     message = f"HTTP status {status_code} is in retry_httpx_status_stop"
                     raise RetryHTTPException(message) from e
             except retry_ignore_exc_types as e:
-                logger.exception(e)
+                retry_logger.exception(e)
                 last_exception = e

-            logger.debug(
+            retry_logger.debug(
                 f"Retrying {fn_name} - in {retry_backoff_interval:.1f}s "
                 f"({monotonic() - start:.1f}s / {retry_timeout:.1f}s)"
             )

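Not part of the diff: a small sketch of how a caller can now thread a per-call logger through retry(), mirroring the LLM._completion call earlier in this commit. The coroutine and its arguments are hypothetical, and retry_timeout being accepted as a kwarg is an assumption based on the ModalLLM calls and the debug message above.

# Hypothetical caller of the retry() helper defined above; only the `logger`
# kwarg handling (popped into retry_logger) is taken from the diff.
from reflector.logger import logger


async def flaky_call(url: str) -> dict:
    ...  # some httpx-based request that may raise HTTPStatusError


async def run() -> dict:
    return await retry(flaky_call)(
        "https://api.example.invalid/v1/ping",
        logger=logger.bind(component="example"),  # used by retry for its own logging
        retry_timeout=60,  # assumed kwarg, matching the ModalLLM usage above
    )
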
@@ -253,9 +253,7 @@ def summarize(
     LOGGER.info("Breaking transcript into smaller chunks")
     chunks = chunk_text(transcript_text)

-    LOGGER.info(
-        f"Transcript broken into {len(chunks)} " f"chunks of at most 500 words"
-    )
+    LOGGER.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words")

     LOGGER.info(f"Writing summary text to: {output_file}")
     with open(output_file, "w") as f: