diff --git a/server/env.example b/server/env.example
index 10079105..4952a937 100644
--- a/server/env.example
+++ b/server/env.example
@@ -46,38 +46,11 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
## llm backend implementation
## =======================================================
-## Using serverless modal.com (require reflector-gpu-modal deployed)
-LLM_BACKEND=modal
-LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
-LLM_MODAL_API_KEY=
-ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
-
-
-## Using OpenAI
-#LLM_BACKEND=openai
-#LLM_OPENAI_KEY=xxx
-#LLM_OPENAI_MODEL=gpt-3.5-turbo
-
-## Using GPT4ALL
-#LLM_BACKEND=openai
-#LLM_URL=http://localhost:4891/v1/completions
-#LLM_OPENAI_MODEL="GPT4All Falcon"
-
-## Default LLM MODEL NAME
-#DEFAULT_LLM=lmsys/vicuna-13b-v1.5
-
-## Cache directory to store models
-CACHE_DIR=data
-
-## =======================================================
-## Summary LLM configuration
-## =======================================================
-
## Context size for summary generation (tokens)
-SUMMARY_LLM_CONTEXT_SIZE_TOKENS=16000
-SUMMARY_LLM_URL=
-SUMMARY_LLM_API_KEY=sk-
-SUMMARY_MODEL=
+# LLM_MODEL=microsoft/phi-4
+LLM_CONTEXT_WINDOW=16000
+LLM_URL=
+LLM_API_KEY=sk-
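+
+## Example: a local OpenAI-compatible server such as vLLM or llama.cpp
+## (illustrative values only; any endpoint exposing /v1 chat completions should work)
+# LLM_URL=http://localhost:8000/v1
+# LLM_API_KEY=sk-local-example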
## =======================================================
## Diarization
diff --git a/server/gpu/modal_deployments/README.md b/server/gpu/modal_deployments/README.md
index dee4052e..f31810e1 100644
--- a/server/gpu/modal_deployments/README.md
+++ b/server/gpu/modal_deployments/README.md
@@ -3,8 +3,9 @@
This repository holds the API for the GPU implementation of the Reflector API service,
and uses [Modal.com](https://modal.com)
-- `reflector_llm.py` - LLM API
+- `reflector_diarizer.py` - Diarization API
- `reflector_transcriber.py` - Transcription API
+- `reflector_translator.py` - Translation API
## Modal.com deployment
diff --git a/server/gpu/modal_deployments/reflector_llm.py b/server/gpu/modal_deployments/reflector_llm.py
deleted file mode 100644
index f3752f5d..00000000
--- a/server/gpu/modal_deployments/reflector_llm.py
+++ /dev/null
@@ -1,213 +0,0 @@
-"""
-Reflector GPU backend - LLM
-===========================
-
-"""
-
-import json
-import os
-import threading
-from typing import Optional
-
-from modal import App, Image, Secret, asgi_app, enter, exit, method
-
-# LLM
-LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
-LLM_LOW_CPU_MEM_USAGE: bool = True
-LLM_TORCH_DTYPE: str = "bfloat16"
-LLM_MAX_NEW_TOKENS: int = 300
-
-IMAGE_MODEL_DIR = "/root/llm_models"
-
-app = App(name="reflector-llm")
-
-
-def download_llm():
- from huggingface_hub import snapshot_download
-
- print("Downloading LLM model")
- snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
- print("LLM model downloaded")
-
-
-def migrate_cache_llm():
- """
- XXX The cache for model files in Transformers v4.22.0 has been updated.
- Migrating your old cache. This is a one-time only operation. You can
- interrupt this and resume the migration later on by calling
- `transformers.utils.move_cache()`.
- """
- from transformers.utils.hub import move_cache
-
- print("Moving LLM cache")
- move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
- print("LLM cache moved")
-
-
-llm_image = (
- Image.debian_slim(python_version="3.10.8")
- .apt_install("git")
- .pip_install(
- "transformers",
- "torch",
- "sentencepiece",
- "protobuf",
- "jsonformer==0.12.0",
- "accelerate==0.21.0",
- "einops==0.6.1",
- "hf-transfer~=0.1",
- "huggingface_hub==0.16.4",
- )
- .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
- .run_function(download_llm)
- .run_function(migrate_cache_llm)
-)
-
-
-@app.cls(
- gpu="A100",
- timeout=60 * 5,
- scaledown_window=60 * 5,
- allow_concurrent_inputs=15,
- image=llm_image,
-)
-class LLM:
- @enter()
- def enter(self):
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
-
- print("Instance llm model")
- model = AutoModelForCausalLM.from_pretrained(
- LLM_MODEL,
- torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
- low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
- cache_dir=IMAGE_MODEL_DIR,
- local_files_only=True,
- )
-
- # JSONFormer doesn't yet support generation configs
- print("Instance llm generation config")
- model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
-
- # generation configuration
- gen_cfg = GenerationConfig.from_model_config(model.config)
- gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
-
- # load tokenizer
- print("Instance llm tokenizer")
- tokenizer = AutoTokenizer.from_pretrained(
- LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
- )
-
- # move model to gpu
- print("Move llm model to GPU")
- model = model.cuda()
-
- print("Warmup llm done")
- self.model = model
- self.tokenizer = tokenizer
- self.gen_cfg = gen_cfg
- self.GenerationConfig = GenerationConfig
-
- self.lock = threading.Lock()
-
- @exit()
- def exit():
- print("Exit llm")
-
- @method()
- def generate(
- self, prompt: str, gen_schema: str | None, gen_cfg: str | None
- ) -> dict:
- """
- Perform a generation action using the LLM
- """
- print(f"Generate {prompt=}")
- if gen_cfg:
- gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
- else:
- gen_cfg = self.gen_cfg
-
- # If a gen_schema is given, conform to gen_schema
- with self.lock:
- if gen_schema:
- import jsonformer
-
- print(f"Schema {gen_schema=}")
- jsonformer_llm = jsonformer.Jsonformer(
- model=self.model,
- tokenizer=self.tokenizer,
- json_schema=json.loads(gen_schema),
- prompt=prompt,
- max_string_token_length=gen_cfg.max_new_tokens,
- )
- response = jsonformer_llm()
- else:
- # If no gen_schema, perform prompt only generation
-
- # tokenize prompt
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
- self.model.device
- )
- output = self.model.generate(input_ids, generation_config=gen_cfg)
-
- # decode output
- response = self.tokenizer.decode(
- output[0].cpu(), skip_special_tokens=True
- )
- response = response[len(prompt) :]
- print(f"Generated {response=}")
- return {"text": response}
-
-
-# -------------------------------------------------------------------
-# Web API
-# -------------------------------------------------------------------
-
-
-@app.function(
- scaledown_window=60 * 10,
- timeout=60 * 5,
- allow_concurrent_inputs=45,
- secrets=[
- Secret.from_name("reflector-gpu"),
- ],
-)
-@asgi_app()
-def web():
- from fastapi import Depends, FastAPI, HTTPException, status
- from fastapi.security import OAuth2PasswordBearer
- from pydantic import BaseModel
-
- llmstub = LLM()
-
- app = FastAPI()
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
- def apikey_auth(apikey: str = Depends(oauth2_scheme)):
- if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Invalid API key",
- headers={"WWW-Authenticate": "Bearer"},
- )
-
- class LLMRequest(BaseModel):
- prompt: str
- gen_schema: Optional[dict] = None
- gen_cfg: Optional[dict] = None
-
- @app.post("/llm", dependencies=[Depends(apikey_auth)])
- def llm(
- req: LLMRequest,
- ):
- gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
- gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
- func = llmstub.generate.spawn(
- prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
- )
- result = func.get()
- return result
-
- return app
diff --git a/server/gpu/modal_deployments/reflector_llm_zephyr.py b/server/gpu/modal_deployments/reflector_llm_zephyr.py
deleted file mode 100644
index 5d9c0390..00000000
--- a/server/gpu/modal_deployments/reflector_llm_zephyr.py
+++ /dev/null
@@ -1,219 +0,0 @@
-"""
-Reflector GPU backend - LLM
-===========================
-
-"""
-
-import json
-import os
-import threading
-from typing import Optional
-
-from modal import App, Image, Secret, asgi_app, enter, exit, method
-
-# LLM
-LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-alpha"
-LLM_LOW_CPU_MEM_USAGE: bool = True
-LLM_TORCH_DTYPE: str = "bfloat16"
-LLM_MAX_NEW_TOKENS: int = 300
-
-IMAGE_MODEL_DIR = "/root/llm_models/zephyr"
-
-app = App(name="reflector-llm-zephyr")
-
-
-def download_llm():
- from huggingface_hub import snapshot_download
-
- print("Downloading LLM model")
- snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
- print("LLM model downloaded")
-
-
-def migrate_cache_llm():
- """
- XXX The cache for model files in Transformers v4.22.0 has been updated.
- Migrating your old cache. This is a one-time only operation. You can
- interrupt this and resume the migration later on by calling
- `transformers.utils.move_cache()`.
- """
- from transformers.utils.hub import move_cache
-
- print("Moving LLM cache")
- move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
- print("LLM cache moved")
-
-
-llm_image = (
- Image.debian_slim(python_version="3.10.8")
- .apt_install("git")
- .pip_install(
- "transformers==4.34.0",
- "torch",
- "sentencepiece",
- "protobuf",
- "jsonformer==0.12.0",
- "accelerate==0.21.0",
- "einops==0.6.1",
- "hf-transfer~=0.1",
- "huggingface_hub==0.16.4",
- )
- .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
- .run_function(download_llm)
- .run_function(migrate_cache_llm)
-)
-
-
-@app.cls(
- gpu="A10G",
- timeout=60 * 5,
- scaledown_window=60 * 5,
- allow_concurrent_inputs=10,
- image=llm_image,
-)
-class LLM:
- @enter()
- def enter(self):
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
-
- print("Instance llm model")
- model = AutoModelForCausalLM.from_pretrained(
- LLM_MODEL,
- torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
- low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
- cache_dir=IMAGE_MODEL_DIR,
- local_files_only=True,
- )
-
- # JSONFormer doesn't yet support generation configs
- print("Instance llm generation config")
- model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
-
- # generation configuration
- gen_cfg = GenerationConfig.from_model_config(model.config)
- gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
-
- # load tokenizer
- print("Instance llm tokenizer")
- tokenizer = AutoTokenizer.from_pretrained(
- LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
- )
- gen_cfg.pad_token_id = tokenizer.eos_token_id
- gen_cfg.eos_token_id = tokenizer.eos_token_id
- tokenizer.pad_token = tokenizer.eos_token
- model.config.pad_token_id = tokenizer.eos_token_id
-
- # move model to gpu
- print("Move llm model to GPU")
- model = model.cuda()
-
- print("Warmup llm done")
- self.model = model
- self.tokenizer = tokenizer
- self.gen_cfg = gen_cfg
- self.GenerationConfig = GenerationConfig
- self.lock = threading.Lock()
-
- @exit()
- def exit():
- print("Exit llm")
-
- @method()
- def generate(
- self, prompt: str, gen_schema: str | None, gen_cfg: str | None
- ) -> dict:
- """
- Perform a generation action using the LLM
- """
- print(f"Generate {prompt=}")
- if gen_cfg:
- gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
- gen_cfg.pad_token_id = self.tokenizer.eos_token_id
- gen_cfg.eos_token_id = self.tokenizer.eos_token_id
- else:
- gen_cfg = self.gen_cfg
-
- # If a gen_schema is given, conform to gen_schema
- with self.lock:
- if gen_schema:
- import jsonformer
-
- print(f"Schema {gen_schema=}")
- jsonformer_llm = jsonformer.Jsonformer(
- model=self.model,
- tokenizer=self.tokenizer,
- json_schema=json.loads(gen_schema),
- prompt=prompt,
- max_string_token_length=gen_cfg.max_new_tokens,
- )
- response = jsonformer_llm()
- else:
- # If no gen_schema, perform prompt only generation
-
- # tokenize prompt
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
- self.model.device
- )
- output = self.model.generate(input_ids, generation_config=gen_cfg)
-
- # decode output
- response = self.tokenizer.decode(
- output[0].cpu(), skip_special_tokens=True
- )
- response = response[len(prompt) :]
- response = {"long_summary": response}
- print(f"Generated {response=}")
- return {"text": response}
-
-
-# -------------------------------------------------------------------
-# Web API
-# -------------------------------------------------------------------
-
-
-@app.function(
- scaledown_window=60 * 10,
- timeout=60 * 5,
- allow_concurrent_inputs=30,
- secrets=[
- Secret.from_name("reflector-gpu"),
- ],
-)
-@asgi_app()
-def web():
- from fastapi import Depends, FastAPI, HTTPException, status
- from fastapi.security import OAuth2PasswordBearer
- from pydantic import BaseModel
-
- llmstub = LLM()
-
- app = FastAPI()
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
- def apikey_auth(apikey: str = Depends(oauth2_scheme)):
- if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Invalid API key",
- headers={"WWW-Authenticate": "Bearer"},
- )
-
- class LLMRequest(BaseModel):
- prompt: str
- gen_schema: Optional[dict] = None
- gen_cfg: Optional[dict] = None
-
- @app.post("/llm", dependencies=[Depends(apikey_auth)])
- def llm(
- req: LLMRequest,
- ):
- gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
- gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
- func = llmstub.generate.spawn(
- prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
- )
- result = func.get()
- return result
-
- return app
diff --git a/server/reflector/llm.py b/server/reflector/llm.py
new file mode 100644
index 00000000..eed50e4a
--- /dev/null
+++ b/server/reflector/llm.py
@@ -0,0 +1,83 @@
+from typing import Type, TypeVar
+
+from llama_index.core import Settings
+from llama_index.core.output_parsers import PydanticOutputParser
+from llama_index.core.program import LLMTextCompletionProgram
+from llama_index.core.response_synthesizers import TreeSummarize
+from llama_index.llms.openai_like import OpenAILike
+from pydantic import BaseModel
+
+T = TypeVar("T", bound=BaseModel)
+
+STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
+Based on the following analysis, provide the information in the requested JSON format:
+
+Analysis:
+{analysis}
+
+{format_instructions}
+"""
+
+
+class LLM:
+ def __init__(self, settings, temperature: float = 0.4, max_tokens: int = 2048):
+ self.settings_obj = settings
+ self.model_name = settings.LLM_MODEL
+ self.url = settings.LLM_URL
+ self.api_key = settings.LLM_API_KEY
+ self.context_window = settings.LLM_CONTEXT_WINDOW
+ self.temperature = temperature
+ self.max_tokens = max_tokens
+
+ # Configure llamaindex Settings
+ self._configure_llamaindex()
+
+ def _configure_llamaindex(self):
+ """Configure llamaindex Settings with OpenAILike LLM"""
+ Settings.llm = OpenAILike(
+ model=self.model_name,
+ api_base=self.url,
+ api_key=self.api_key,
+ context_window=self.context_window,
+ is_chat_model=True,
+ is_function_calling_model=False,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ async def get_response(
+ self, prompt: str, texts: list[str], tone_name: str | None = None
+ ) -> str:
+ """Get a text response using TreeSummarize for non-function-calling models"""
+ summarizer = TreeSummarize(verbose=False)
+ response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
+ return str(response).strip()
+
+ async def get_structured_response(
+ self,
+ prompt: str,
+ texts: list[str],
+ output_cls: Type[T],
+ tone_name: str | None = None,
+ ) -> T:
+ """Get structured output from LLM for non-function-calling models"""
+ summarizer = TreeSummarize(verbose=True)
+ response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
+
+ output_parser = PydanticOutputParser(output_cls)
+
+ program = LLMTextCompletionProgram.from_defaults(
+ output_parser=output_parser,
+ prompt_template_str=STRUCTURED_RESPONSE_PROMPT_TEMPLATE,
+ verbose=False,
+ )
+
+ format_instructions = output_parser.format(
+ "Please structure the above information in the following JSON format:"
+ )
+
+ output = await program.acall(
+ analysis=str(response), format_instructions=format_instructions
+ )
+
+ return output
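+
+
+# Illustrative usage sketch (comment only, not executed); it assumes
+# `reflector.settings.settings` provides LLM_MODEL, LLM_URL, LLM_API_KEY and
+# LLM_CONTEXT_WINDOW, that `transcript_text` is a placeholder for caller-supplied
+# text, and that the calls run inside an async context:
+#
+#   from reflector.settings import settings
+#
+#   class Topic(BaseModel):
+#       title: str
+#       summary: str
+#
+#   llm = LLM(settings=settings, temperature=0.4, max_tokens=1024)
+#   recap = await llm.get_response("Summarize briefly.", [transcript_text])
+#   topic = await llm.get_structured_response(
+#       "Extract the main topic.", [transcript_text], Topic
+#   )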
diff --git a/server/reflector/llm/__init__.py b/server/reflector/llm/__init__.py
deleted file mode 100644
index 446a20d6..00000000
--- a/server/reflector/llm/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .base import LLM # noqa: F401
-from .llm_params import LLMTaskParams # noqa: F401
diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py
deleted file mode 100644
index d26ee36e..00000000
--- a/server/reflector/llm/base.py
+++ /dev/null
@@ -1,347 +0,0 @@
-import importlib
-import json
-import re
-from typing import TypeVar
-
-import nltk
-from prometheus_client import Counter, Histogram
-from transformers import GenerationConfig
-
-from reflector.llm.llm_params import TaskParams
-from reflector.logger import logger as reflector_logger
-from reflector.settings import settings
-from reflector.utils.retry import retry
-
-T = TypeVar("T", bound="LLM")
-
-
-class LLM:
- _nltk_downloaded = False
- _registry = {}
- model_name: str
- m_generate = Histogram(
- "llm_generate",
- "Time spent in LLM.generate",
- ["backend"],
- )
- m_generate_call = Counter(
- "llm_generate_call",
- "Number of calls to LLM.generate",
- ["backend"],
- )
- m_generate_success = Counter(
- "llm_generate_success",
- "Number of successful calls to LLM.generate",
- ["backend"],
- )
- m_generate_failure = Counter(
- "llm_generate_failure",
- "Number of failed calls to LLM.generate",
- ["backend"],
- )
-
- @classmethod
- def ensure_nltk(cls):
- """
- Make sure NLTK package is installed. Searches in the cache and
- downloads only if needed.
- """
- if not cls._nltk_downloaded:
- nltk.download("punkt_tab")
- # For POS tagging
- nltk.download("averaged_perceptron_tagger_eng")
- cls._nltk_downloaded = True
-
- @classmethod
- def register(cls, name, klass):
- cls._registry[name] = klass
-
- @classmethod
- def get_instance(cls, model_name: str | None = None, name: str = None) -> T:
- """
- Return an instance depending on the settings.
- Settings used:
-
- - `LLM_BACKEND`: key of the backend
- - `LLM_URL`: url of the backend
- """
- if name is None:
- name = settings.LLM_BACKEND
- if name not in cls._registry:
- module_name = f"reflector.llm.llm_{name}"
- importlib.import_module(module_name)
- cls.ensure_nltk()
-
- return cls._registry[name](model_name)
-
- def get_model_name(self) -> str:
- """
- Get the currently set model name
- """
- return self._get_model_name()
-
- def _get_model_name(self) -> str:
- pass
-
- def set_model_name(self, model_name: str) -> bool:
- """
- Update the model name with the provided model name
- """
- return self._set_model_name(model_name)
-
- def _set_model_name(self, model_name: str) -> bool:
- raise NotImplementedError
-
- @property
- def template(self) -> str:
- """
- Return the LLM Prompt template
- """
- return """
- ### Human:
- {instruct}
-
- {text}
-
- ### Assistant:
- """
-
- def __init__(self):
- name = self.__class__.__name__
- self.m_generate = self.m_generate.labels(name)
- self.m_generate_call = self.m_generate_call.labels(name)
- self.m_generate_success = self.m_generate_success.labels(name)
- self.m_generate_failure = self.m_generate_failure.labels(name)
- self.detokenizer = nltk.tokenize.treebank.TreebankWordDetokenizer()
-
- @property
- def tokenizer(self):
- """
- Return the tokenizer instance used by LLM
- """
- return self._get_tokenizer()
-
- def _get_tokenizer(self):
- pass
-
- def has_structured_output(self):
- # whether implementation supports structured output
- # on the model side (otherwise it's prompt engineering)
- return False
-
- async def generate(
- self,
- prompt: str,
- logger: reflector_logger,
- gen_schema: dict | None = None,
- gen_cfg: GenerationConfig | None = None,
- **kwargs,
- ) -> dict:
- logger.info("LLM generate", prompt=repr(prompt))
-
- if gen_cfg:
- gen_cfg = gen_cfg.to_dict()
- self.m_generate_call.inc()
- try:
- with self.m_generate.time():
- result = await retry(self._generate)(
- prompt=prompt,
- gen_schema=gen_schema,
- gen_cfg=gen_cfg,
- logger=logger,
- **kwargs,
- )
- self.m_generate_success.inc()
-
- except Exception:
- logger.exception("Failed to call llm after retrying")
- self.m_generate_failure.inc()
- raise
-
- logger.debug("LLM result [raw]", result=repr(result))
- if isinstance(result, str):
- result = self._parse_json(result)
- logger.debug("LLM result [parsed]", result=repr(result))
-
- return result
-
- async def completion(
- self, messages: list, logger: reflector_logger, **kwargs
- ) -> dict:
- """
- Use /v1/chat/completion Open-AI compatible endpoint from the URL
- It's up to the user to validate anything or transform the result
- """
- logger.info("LLM completions", messages=messages)
-
- try:
- with self.m_generate.time():
- result = await retry(self._completion)(
- messages=messages, **{**kwargs, "logger": logger}
- )
- self.m_generate_success.inc()
- except Exception:
- logger.exception("Failed to call llm after retrying")
- self.m_generate_failure.inc()
- raise
-
- logger.debug("LLM completion result", result=repr(result))
- return result
-
- def ensure_casing(self, title: str) -> str:
- """
- LLM takes care of word casing, but in rare cases this
- can falter. This is a fallback to ensure the casing of
- topics is in a proper format.
-
- We select nouns, verbs and adjectives and check if camel
- casing is present and fix it, if not. Will not perform
- any other changes.
- """
- tokens = nltk.word_tokenize(title)
- pos_tags = nltk.pos_tag(tokens)
- camel_cased = []
-
- whitelisted_pos_tags = [
- "NN",
- "NNS",
- "NNP",
- "NNPS", # Noun POS
- "VB",
- "VBD",
- "VBG",
- "VBN",
- "VBP",
- "VBZ", # Verb POS
- "JJ",
- "JJR",
- "JJS", # Adjective POS
- ]
-
- # If at all there is an exception, do not block other reflector
- # processes. Return the LLM generated title, at the least.
- try:
- for word, pos in pos_tags:
- if pos in whitelisted_pos_tags and word[0].islower():
- camel_cased.append(word[0].upper() + word[1:])
- else:
- camel_cased.append(word)
- modified_title = self.detokenizer.detokenize(camel_cased)
-
- # Irrespective of casing changes, the starting letter
- # of title is always upper-cased
- title = modified_title[0].upper() + modified_title[1:]
- except Exception as e:
- reflector_logger.info(
- f"Failed to ensure casing on {title=} with exception : {str(e)}"
- )
-
- return title
-
- def trim_title(self, title: str) -> str:
- """
- List of manual trimming to the title.
-
- Longer titles are prone to run into A prefix of phrases that don't
- really add any descriptive information and in some cases, this
- behaviour can be repeated for several consecutive topics. Trim the
- titles to maintain quality of titles.
- """
- phrases_to_remove = ["Discussing", "Discussion on", "Discussion about"]
- try:
- pattern = (
- r"\b(?:"
- + "|".join(re.escape(phrase) for phrase in phrases_to_remove)
- + r")\b"
- )
- title = re.sub(pattern, "", title, flags=re.IGNORECASE)
- except Exception as e:
- reflector_logger.info(f"Failed to trim {title=} with exception : {str(e)}")
- return title
-
- async def _generate(
- self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
- ) -> str:
- raise NotImplementedError
-
- async def _completion(self, messages: list, **kwargs) -> dict:
- raise NotImplementedError
-
- def _parse_json(self, result: str) -> dict:
- result = result.strip()
- # try detecting code block if exist
- # starts with ```json\n, ends with ```
- # or starts with ```\n, ends with ```
- # or starts with \n```javascript\n, ends with ```
-
- regex = r"```(json|javascript|)?(.*)```"
- matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
- if matches:
- result = matches[0][1]
-
- else:
- # maybe the prompt has been started with ```json
- # so if text ends with ```, just remove it and use it as json
- if result.endswith("```"):
- result = result[:-3]
-
- return json.loads(result.strip())
-
- def text_token_threshold(self, task_params: TaskParams | None) -> int:
- """
- Choose the token size to set as the threshold to pack the LLM calls
- """
- buffer_token_size = 100
- default_output_tokens = 1000
- context_window = self.tokenizer.model_max_length
- tokens = self.tokenizer.tokenize(
- self.create_prompt(instruct=task_params.instruct, text="")
- )
- threshold = context_window - len(tokens) - buffer_token_size
- if task_params.gen_cfg:
- threshold -= task_params.gen_cfg.max_new_tokens
- else:
- threshold -= default_output_tokens
- return threshold
-
- def split_corpus(
- self,
- corpus: str,
- task_params: TaskParams,
- token_threshold: int | None = None,
- ) -> list[str]:
- """
- Split the input to the LLM due to CUDA memory limitations and LLM context window
- restrictions.
-
- Accumulate tokens from full sentences till threshold and yield accumulated
- tokens. Reset accumulation when threshold is reached and repeat process.
- """
- if not token_threshold:
- token_threshold = self.text_token_threshold(task_params=task_params)
-
- accumulated_tokens = []
- accumulated_sentences = []
- accumulated_token_count = 0
- corpus_sentences = nltk.sent_tokenize(corpus)
-
- for sentence in corpus_sentences:
- tokens = self.tokenizer.tokenize(sentence)
- if accumulated_token_count + len(tokens) <= token_threshold:
- accumulated_token_count += len(tokens)
- accumulated_tokens.extend(tokens)
- accumulated_sentences.append(sentence)
- else:
- yield "".join(accumulated_sentences)
- accumulated_token_count = len(tokens)
- accumulated_tokens = tokens
- accumulated_sentences = [sentence]
-
- if accumulated_tokens:
- yield " ".join(accumulated_sentences)
-
- def create_prompt(self, instruct: str, text: str) -> str:
- """
- Create a consumable prompt based on the prompt template
- """
- return self.template.format(instruct=instruct, text=text)
diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py
deleted file mode 100644
index 327300c7..00000000
--- a/server/reflector/llm/llm_modal.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import httpx
-from transformers import AutoTokenizer, GenerationConfig
-
-from reflector.llm.base import LLM
-from reflector.logger import logger as reflector_logger
-from reflector.settings import settings
-from reflector.utils.retry import retry
-
-
-class ModalLLM(LLM):
- def __init__(self, model_name: str | None = None):
- super().__init__()
- self.timeout = settings.LLM_TIMEOUT
- self.llm_url = settings.LLM_URL + "/llm"
- self.headers = {
- "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
- }
- self._set_model_name(model_name if model_name else settings.DEFAULT_LLM)
-
- @property
- def supported_models(self):
- """
- List of currently supported models on this GPU platform
- """
- # TODO: Query the specific GPU platform
- # Replace this with a HTTP call
- return [
- "lmsys/vicuna-13b-v1.5",
- "HuggingFaceH4/zephyr-7b-alpha",
- "NousResearch/Hermes-3-Llama-3.1-8B",
- ]
-
- async def _generate(
- self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
- ) -> str:
- json_payload = {"prompt": prompt}
- if gen_schema:
- json_payload["gen_schema"] = gen_schema
- if gen_cfg:
- json_payload["gen_cfg"] = gen_cfg
-
- # Handing over generation of the final summary to Zephyr model
- # but replacing the Vicuna model will happen after more testing
- # TODO: Create a mapping of model names and cloud deployments
- if self.model_name == "HuggingFaceH4/zephyr-7b-alpha":
- self.llm_url = settings.ZEPHYR_LLM_URL + "/llm"
-
- async with httpx.AsyncClient() as client:
- response = await retry(client.post)(
- self.llm_url,
- headers=self.headers,
- json=json_payload,
- timeout=self.timeout,
- retry_timeout=60 * 5,
- follow_redirects=True,
- logger=kwargs.get("logger", reflector_logger),
- )
- response.raise_for_status()
- text = response.json()["text"]
- return text
-
- async def _completion(self, messages: list, **kwargs) -> dict:
- # returns full api response
- kwargs.setdefault("temperature", 0.3)
- kwargs.setdefault("max_tokens", 2048)
- kwargs.setdefault("stream", False)
- kwargs.setdefault("repetition_penalty", 1)
- kwargs.setdefault("top_p", 1)
- kwargs.setdefault("top_k", -1)
- kwargs.setdefault("min_p", 0.05)
- data = {"messages": messages, "model": self.model_name, **kwargs}
-
- if self.model_name == "NousResearch/Hermes-3-Llama-3.1-8B":
- self.llm_url = settings.HERMES_3_8B_LLM_URL + "/v1/chat/completions"
-
- async with httpx.AsyncClient() as client:
- response = await retry(client.post)(
- self.llm_url,
- headers=self.headers,
- json=data,
- timeout=self.timeout,
- retry_timeout=60 * 5,
- follow_redirects=True,
- logger=kwargs.get("logger", reflector_logger),
- )
- response.raise_for_status()
- return response.json()
-
- def _set_model_name(self, model_name: str) -> bool:
- """
- Set the model name
- """
- # Abort, if the model is not supported
- if model_name not in self.supported_models:
- reflector_logger.info(
- f"Attempted to change {model_name=}, but is not supported."
- f"Setting model and tokenizer failed !"
- )
- return False
- # Abort, if the model is already set
- elif hasattr(self, "model_name") and model_name == self._get_model_name():
- reflector_logger.info("No change in model. Setting model skipped.")
- return False
- # Update model name and tokenizer
- self.model_name = model_name
- self.llm_tokenizer = AutoTokenizer.from_pretrained(
- self.model_name, cache_dir=settings.CACHE_DIR
- )
- reflector_logger.info(f"Model set to {model_name=}. Tokenizer updated.")
- return True
-
- def _get_tokenizer(self) -> AutoTokenizer:
- """
- Return the currently used LLM tokenizer
- """
- return self.llm_tokenizer
-
- def _get_model_name(self) -> str:
- """
- Return the current model name from the instance details
- """
- return self.model_name
-
-
-LLM.register("modal", ModalLLM)
-
-if __name__ == "__main__":
- from reflector.logger import logger
-
- async def main():
- llm = ModalLLM()
- prompt = llm.create_prompt(
- instruct="Complete the following task",
- text="Tell me a joke about programming.",
- )
- result = await llm.generate(prompt=prompt, logger=logger)
- print(result)
-
- gen_schema = {
- "type": "object",
- "properties": {"response": {"type": "string"}},
- }
-
- result = await llm.generate(prompt=prompt, gen_schema=gen_schema, logger=logger)
- print(result)
-
- gen_cfg = GenerationConfig(max_new_tokens=150)
- result = await llm.generate(
- prompt=prompt, gen_cfg=gen_cfg, gen_schema=gen_schema, logger=logger
- )
- print(result)
-
- import asyncio
-
- asyncio.run(main())
diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py
deleted file mode 100644
index e28211ef..00000000
--- a/server/reflector/llm/llm_openai.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import httpx
-from transformers import GenerationConfig
-
-from reflector.llm.base import LLM
-from reflector.logger import logger
-from reflector.settings import settings
-
-
-class OpenAILLM(LLM):
- def __init__(self, model_name: str | None = None, **kwargs):
- super().__init__(**kwargs)
- self.openai_key = settings.LLM_OPENAI_KEY
- self.openai_url = settings.LLM_URL
- self.openai_model = settings.LLM_OPENAI_MODEL
- self.openai_temperature = settings.LLM_OPENAI_TEMPERATURE
- self.timeout = settings.LLM_TIMEOUT
- self.max_tokens = settings.LLM_MAX_TOKENS
- logger.info(f"LLM use openai backend at {self.openai_url}")
-
- async def _generate(
- self,
- prompt: str,
- gen_schema: dict | None,
- gen_cfg: GenerationConfig | None,
- **kwargs,
- ) -> str:
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {self.openai_key}",
- }
-
- async with httpx.AsyncClient(timeout=self.timeout) as client:
- response = await client.post(
- self.openai_url,
- headers=headers,
- json={
- "model": self.openai_model,
- "prompt": prompt,
- "max_tokens": self.max_tokens,
- "temperature": self.openai_temperature,
- },
- )
- response.raise_for_status()
- result = response.json()
- return result["choices"][0]["text"]
-
-
-LLM.register("openai", OpenAILLM)
diff --git a/server/reflector/llm/llm_params.py b/server/reflector/llm/llm_params.py
deleted file mode 100644
index fbe73bd9..00000000
--- a/server/reflector/llm/llm_params.py
+++ /dev/null
@@ -1,219 +0,0 @@
-from typing import Optional, TypeVar
-
-from pydantic import BaseModel
-from transformers import GenerationConfig
-
-
-class TaskParams(BaseModel, arbitrary_types_allowed=True):
- instruct: str
- gen_cfg: Optional[GenerationConfig] = None
- gen_schema: Optional[dict] = None
-
-
-T = TypeVar("T", bound="LLMTaskParams")
-
-
-class LLMTaskParams:
- _registry = {}
-
- @classmethod
- def register(cls, task, klass) -> None:
- cls._registry[task] = klass
-
- @classmethod
- def get_instance(cls, task: str) -> T:
- return cls._registry[task]()
-
- @property
- def task_params(self) -> TaskParams | None:
- """
- Fetch the task related parameters
- """
- return self._get_task_params()
-
- def _get_task_params(self) -> None:
- pass
-
-
-class FinalLongSummaryParams(LLMTaskParams):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._gen_cfg = GenerationConfig(
- max_new_tokens=1000, num_beams=3, do_sample=True, temperature=0.3
- )
- self._instruct = """
- Take the key ideas and takeaways from the text and create a short
- summary. Be sure to keep the length of the response to a minimum.
- Do not include trivial information in the summary.
- """
- self._schema = {
- "type": "object",
- "properties": {"long_summary": {"type": "string"}},
- }
- self._task_params = TaskParams(
- instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
- )
-
- def _get_task_params(self) -> TaskParams:
- """gen_schema
- Return the parameters associated with a specific LLM task
- """
- return self._task_params
-
-
-class FinalShortSummaryParams(LLMTaskParams):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._gen_cfg = GenerationConfig(
- max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
- )
- self._instruct = """
- Take the key ideas and takeaways from the text and create a short
- summary. Be sure to keep the length of the response to a minimum.
- Do not include trivial information in the summary.
- """
- self._schema = {
- "type": "object",
- "properties": {"short_summary": {"type": "string"}},
- }
- self._task_params = TaskParams(
- instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
- )
-
- def _get_task_params(self) -> TaskParams:
- """
- Return the parameters associated with a specific LLM task
- """
- return self._task_params
-
-
-class FinalTitleParams(LLMTaskParams):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._gen_cfg = GenerationConfig(
- max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
- )
- self._instruct = """
- Combine the following individual titles into one single short title that
- condenses the essence of all titles.
- """
- self._schema = {
- "type": "object",
- "properties": {"title": {"type": "string"}},
- }
- self._task_params = TaskParams(
- instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
- )
-
- def _get_task_params(self) -> TaskParams:
- """
- Return the parameters associated with a specific LLM task
- """
- return self._task_params
-
-
-class TopicParams(LLMTaskParams):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._gen_cfg = GenerationConfig(
- max_new_tokens=500, num_beams=6, do_sample=True, temperature=0.9
- )
- self._instruct = """
- Create a JSON object as response.The JSON object must have 2 fields:
- i) title and ii) summary.
- For the title field, generate a very detailed and self-explanatory
- title for the given text. Let the title be as descriptive as possible.
- For the summary field, summarize the given text in a maximum of
- two sentences.
- """
- self._schema = {
- "type": "object",
- "properties": {
- "title": {"type": "string"},
- "summary": {"type": "string"},
- },
- }
- self._task_params = TaskParams(
- instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
- )
-
- def _get_task_params(self) -> TaskParams:
- """
- Return the parameters associated with a specific LLM task
- """
- return self._task_params
-
-
-class BulletedSummaryParams(LLMTaskParams):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._gen_cfg = GenerationConfig(
- max_new_tokens=800,
- num_beams=1,
- do_sample=True,
- temperature=0.2,
- early_stopping=True,
- )
- self._instruct = """
- Given a meeting transcript, extract the key things discussed in the
- form of a list.
-
- While generating the response, follow the constraints mentioned below.
-
- Summary constraints:
- i) Do not add new content, except to fix spelling or punctuation.
- ii) Do not add any prefixes or numbering in the response.
- iii) The summarization should be as information dense as possible.
- iv) Do not add any additional sections like Note, Conclusion, etc. in
- the response.
-
- Response format:
- i) The response should be in the form of a bulleted list.
- ii) Iteratively merge all the relevant paragraphs together to keep the
- number of paragraphs to a minimum.
- iii) Remove any unfinished sentences from the final response.
- iv) Do not include narrative or reporting clauses.
- v) Use "*" as the bullet icon.
- """
- self._task_params = TaskParams(
- instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
- )
-
- def _get_task_params(self) -> TaskParams:
- """gen_schema
- Return the parameters associated with a specific LLM task
- """
- return self._task_params
-
-
-class MergedSummaryParams(LLMTaskParams):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._gen_cfg = GenerationConfig(
- max_new_tokens=600,
- num_beams=1,
- do_sample=True,
- temperature=0.2,
- early_stopping=True,
- )
- self._instruct = """
- Given the key points of a meeting, summarize the points to describe the
- meeting in the form of paragraphs.
- """
- self._task_params = TaskParams(
- instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
- )
-
- def _get_task_params(self) -> TaskParams:
- """gen_schema
- Return the parameters associated with a specific LLM task
- """
- return self._task_params
-
-
-LLMTaskParams.register("topic", TopicParams)
-LLMTaskParams.register("final_title", FinalTitleParams)
-LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
-LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
-LLMTaskParams.register("bullet_summary", BulletedSummaryParams)
-LLMTaskParams.register("merged_summary", MergedSummaryParams)
diff --git a/server/reflector/llm/openai_llm.py b/server/reflector/llm/openai_llm.py
deleted file mode 100644
index 90f31869..00000000
--- a/server/reflector/llm/openai_llm.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import httpx
-from transformers import AutoTokenizer
-
-from reflector.logger import logger
-
-
-def apply_gen_config(payload: dict, gen_cfg) -> None:
- """Apply generation config overrides to the payload."""
- config_mapping = {
- "temperature": "temperature",
- "max_new_tokens": "max_tokens",
- "max_tokens": "max_tokens",
- "top_p": "top_p",
- "frequency_penalty": "frequency_penalty",
- "presence_penalty": "presence_penalty",
- }
-
- for cfg_attr, payload_key in config_mapping.items():
- value = getattr(gen_cfg, cfg_attr, None)
- if value is not None:
- payload[payload_key] = value
- if cfg_attr == "max_new_tokens": # Handle max_new_tokens taking precedence
- break
-
-
-class OpenAILLM:
- def __init__(self, config_prefix: str, settings):
- self.config_prefix = config_prefix
- self.settings_obj = settings
- self.model_name = getattr(settings, f"{config_prefix}_MODEL")
- self.url = getattr(settings, f"{config_prefix}_LLM_URL")
- self.api_key = getattr(settings, f"{config_prefix}_LLM_API_KEY")
-
- timeout = getattr(settings, f"{config_prefix}_LLM_TIMEOUT", 300)
- self.temperature = getattr(settings, f"{config_prefix}_LLM_TEMPERATURE", 0.7)
- self.max_tokens = getattr(settings, f"{config_prefix}_LLM_MAX_TOKENS", 1024)
- self.client = httpx.AsyncClient(timeout=timeout)
-
- # Use a tokenizer that approximates OpenAI token counting
- tokenizer_name = getattr(settings, f"{config_prefix}_TOKENIZER", "gpt2")
- try:
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
- except Exception:
- logger.debug(
- f"Failed to load tokenizer '{tokenizer_name}', falling back to default 'gpt2' tokenizer"
- )
- self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
-
- async def generate(
- self, prompt: str, gen_schema=None, gen_cfg=None, logger=None
- ) -> str:
- if logger:
- logger.debug(
- "OpenAI LLM generate",
- prompt=repr(prompt[:100] + "..." if len(prompt) > 100 else prompt),
- )
-
- messages = [{"role": "user", "content": prompt}]
- result = await self.completion(
- messages, gen_schema=gen_schema, gen_cfg=gen_cfg, logger=logger
- )
- return result["choices"][0]["message"]["content"]
-
- async def completion(
- self, messages: list, gen_schema=None, gen_cfg=None, logger=None, **kwargs
- ) -> dict:
- if logger:
- logger.info("OpenAI LLM completion", messages_count=len(messages))
-
- payload = {
- "model": self.model_name,
- "messages": messages,
- "temperature": self.temperature,
- "max_tokens": self.max_tokens,
- }
-
- # Apply generation config overrides
- if gen_cfg:
- apply_gen_config(payload, gen_cfg)
-
- # Apply structured output schema
- if gen_schema:
- payload["response_format"] = {
- "type": "json_schema",
- "json_schema": {"name": "response", "schema": gen_schema},
- }
-
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {self.api_key}",
- }
-
- url = f"{self.url.rstrip('/')}/chat/completions"
-
- if logger:
- logger.debug(
- "OpenAI API request", url=url, payload_keys=list(payload.keys())
- )
-
- response = await self.client.post(url, json=payload, headers=headers)
- response.raise_for_status()
-
- result = response.json()
-
- if logger:
- logger.debug(
- "OpenAI API response",
- status_code=response.status_code,
- choices_count=len(result.get("choices", [])),
- )
-
- return result
-
- async def __aenter__(self):
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self.client.aclose()
diff --git a/server/reflector/processors/summary/summary_builder.py b/server/reflector/processors/summary/summary_builder.py
index c6744183..ccbae48a 100644
--- a/server/reflector/processors/summary/summary_builder.py
+++ b/server/reflector/processors/summary/summary_builder.py
@@ -12,15 +12,9 @@ from textwrap import dedent
from typing import Type, TypeVar
import structlog
-from llama_index.core import Settings
-from llama_index.core.output_parsers import PydanticOutputParser
-from llama_index.core.program import LLMTextCompletionProgram
-from llama_index.core.response_synthesizers import TreeSummarize
-from llama_index.llms.openai_like import OpenAILike
from pydantic import BaseModel, Field
-from reflector.llm.base import LLM
-from reflector.llm.openai_llm import OpenAILLM
+from reflector.llm import LLM
from reflector.settings import settings
T = TypeVar("T", bound=BaseModel)
@@ -168,23 +162,12 @@ class SummaryBuilder:
self.summaries: list[dict[str, str]] = []
self.subjects: list[str] = []
self.transcription_type: TranscriptionType | None = None
- self.llm_instance: LLM = llm
+ self.llm: LLM = llm
self.model_name: str = llm.model_name
self.logger = logger or structlog.get_logger()
if filename:
self.read_transcript_from_file(filename)
- Settings.llm = OpenAILike(
- model=llm.model_name,
- api_base=llm.url,
- api_key=llm.api_key,
- context_window=settings.SUMMARY_LLM_CONTEXT_SIZE_TOKENS,
- is_chat_model=True,
- is_function_calling_model=llm.has_structured_output,
- temperature=llm.temperature,
- max_tokens=llm.max_tokens,
- )
-
def read_transcript_from_file(self, filename: str) -> None:
"""
Load a transcript from a text file.
@@ -202,40 +185,16 @@ class SummaryBuilder:
self.transcript = transcript
def set_llm_instance(self, llm: LLM) -> None:
- self.llm_instance = llm
+ self.llm = llm
async def _get_structured_response(
self, prompt: str, output_cls: Type[T], tone_name: str | None = None
- ) -> Type[T]:
+ ) -> T:
"""Generic function to get structured output from LLM for non-function-calling models."""
- # First, use TreeSummarize to get the response
- summarizer = TreeSummarize(verbose=True)
-
- response = await summarizer.aget_response(
- prompt, [self.transcript], tone_name=tone_name
+ return await self.llm.get_structured_response(
+ prompt, [self.transcript], output_cls, tone_name=tone_name
)
- # Then, use PydanticOutputParser to structure the response
- output_parser = PydanticOutputParser(output_cls)
-
- prompt_template_str = STRUCTURED_RESPONSE_PROMPT_TEMPLATE
-
- program = LLMTextCompletionProgram.from_defaults(
- output_parser=output_parser,
- prompt_template_str=prompt_template_str,
- verbose=False,
- )
-
- format_instructions = output_parser.format(
- "Please structure the above information in the following JSON format:"
- )
-
- output = await program.acall(
- analysis=str(response), format_instructions=format_instructions
- )
-
- return output
-
# ----------------------------------------------------------------------------
# Participants
# ----------------------------------------------------------------------------
@@ -354,19 +313,18 @@ class SummaryBuilder:
async def generate_subject_summaries(self) -> None:
"""Generate detailed summaries for each extracted subject."""
assert self.transcript is not None
- summarizer = TreeSummarize(verbose=False)
summaries = []
for subject in self.subjects:
detailed_prompt = DETAILED_SUBJECT_PROMPT_TEMPLATE.format(subject=subject)
- detailed_response = await summarizer.aget_response(
+ detailed_response = await self.llm.get_response(
detailed_prompt, [self.transcript], tone_name="Topic assistant"
)
paragraph_prompt = PARAGRAPH_SUMMARY_PROMPT
- paragraph_response = await summarizer.aget_response(
+ paragraph_response = await self.llm.get_response(
paragraph_prompt, [str(detailed_response)], tone_name="Topic summarizer"
)
@@ -377,7 +335,6 @@ class SummaryBuilder:
async def generate_recap(self) -> None:
"""Generate a quick recap from the subject summaries."""
- summarizer = TreeSummarize(verbose=True)
summaries_text = "\n\n".join(
[
@@ -388,7 +345,7 @@ class SummaryBuilder:
recap_prompt = RECAP_PROMPT
- recap_response = await summarizer.aget_response(
+ recap_response = await self.llm.get_response(
recap_prompt, [summaries_text], tone_name="Recap summarizer"
)
@@ -483,7 +440,7 @@ if __name__ == "__main__":
async def main():
# build the summary
- llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
+ llm = LLM(settings=settings)
sm = SummaryBuilder(llm=llm, filename=args.transcript)
if args.subjects:
diff --git a/server/reflector/processors/transcript_final_summary.py b/server/reflector/processors/transcript_final_summary.py
index 9cfc4a00..0b4a594c 100644
--- a/server/reflector/processors/transcript_final_summary.py
+++ b/server/reflector/processors/transcript_final_summary.py
@@ -1,4 +1,4 @@
-from reflector.llm.openai_llm import OpenAILLM
+from reflector.llm import LLM
from reflector.processors.base import Processor
from reflector.processors.summary.summary_builder import SummaryBuilder
from reflector.processors.types import FinalLongSummary, FinalShortSummary, TitleSummary
@@ -17,7 +17,7 @@ class TranscriptFinalSummaryProcessor(Processor):
super().__init__(**kwargs)
self.transcript = transcript
self.chunks: list[TitleSummary] = []
- self.llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
+ self.llm = LLM(settings=settings)
self.builder = None
async def _push(self, data: TitleSummary):
diff --git a/server/reflector/processors/transcript_final_title.py b/server/reflector/processors/transcript_final_title.py
index 4b486c08..75b62b5a 100644
--- a/server/reflector/processors/transcript_final_title.py
+++ b/server/reflector/processors/transcript_final_title.py
@@ -1,67 +1,72 @@
-from reflector.llm import LLM, LLMTaskParams
+from textwrap import dedent
+
+from reflector.llm import LLM
from reflector.processors.base import Processor
from reflector.processors.types import FinalTitle, TitleSummary
+from reflector.settings import settings
+from reflector.utils.text import clean_title
+
+TITLE_PROMPT = dedent(
+ """
+ Generate a concise title for this meeting based on the following topic titles.
+ Ignore casual conversation, greetings, or administrative matters.
+
+ The title must:
+ - Be maximum 10 words
+ - Use noun phrases when possible (e.g., "Q1 Budget Review" not "Reviewing the Q1 Budget")
+ - Avoid generic terms like "Team Meeting" or "Discussion"
+
+ If multiple unrelated topics were discussed, prioritize the most significant one,
+ or create a compound title (e.g., "Product Launch and Budget Planning").
+
+
+ {titles}
+
+
+ Do not explain, just output the meeting title as a single line.
+ """
+).strip()
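+
+# Illustrative example of the `{titles}` payload assembled in `_flush()` below,
+# one bullet per detected topic (hypothetical titles):
+#   - Q1 Budget Review
+#   - Hiring Plan for the Data Team
+#   - Office Relocation Logistics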
class TranscriptFinalTitleProcessor(Processor):
"""
- Assemble all summary into a line-based json
+ Generate a final title from topic titles using LlamaIndex
"""
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = FinalTitle
- TASK = "final_title"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.chunks: list[TitleSummary] = []
- self.llm = LLM.get_instance()
- self.params = LLMTaskParams.get_instance(self.TASK).task_params
+ self.llm = LLM(settings=settings, temperature=0.5, max_tokens=200)
async def _push(self, data: TitleSummary):
self.chunks.append(data)
- async def get_title(self, text: str) -> dict:
+ async def get_title(self, accumulated_titles: str) -> str:
"""
- Generate a title for the whole recording
+ Generate a title for the whole recording using LLM
"""
- chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
+ prompt = TITLE_PROMPT.format(titles=accumulated_titles)
+ response = await self.llm.get_response(
+ prompt,
+ [accumulated_titles],
+ tone_name="Title generator",
+ )
- if len(chunks) == 1:
- chunk = chunks[0]
- prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
- title_result = await self.llm.generate(
- prompt=prompt,
- gen_schema=self.params.gen_schema,
- gen_cfg=self.params.gen_cfg,
- logger=self.logger,
- )
- return title_result
- else:
- accumulated_titles = ""
- for chunk in chunks:
- prompt = self.llm.create_prompt(
- instruct=self.params.instruct, text=chunk
- )
- title_result = await self.llm.generate(
- prompt=prompt,
- gen_schema=self.params.gen_schema,
- gen_cfg=self.params.gen_cfg,
- logger=self.logger,
- )
- accumulated_titles += title_result["title"]
+ self.logger.info(f"Generated title response: {response}")
- return await self.get_title(accumulated_titles)
+ return response
async def _flush(self):
if not self.chunks:
self.logger.warning("No summary to output")
return
- accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
- title_result = await self.get_title(accumulated_titles)
- final_title = self.llm.trim_title(title_result["title"])
- final_title = self.llm.ensure_casing(final_title)
+ accumulated_titles = "\n".join([f"- {chunk.title}" for chunk in self.chunks])
+ title = await self.get_title(accumulated_titles)
+ title = clean_title(title)
- final_title = FinalTitle(title=final_title)
+ final_title = FinalTitle(title=title)
await self.emit(final_title)
diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py
index dd14ce93..e0e306ce 100644
--- a/server/reflector/processors/transcript_topic_detector.py
+++ b/server/reflector/processors/transcript_topic_detector.py
@@ -1,7 +1,41 @@
-from reflector.llm import LLM, LLMTaskParams
+from textwrap import dedent
+
+from pydantic import BaseModel, Field
+
+from reflector.llm import LLM
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, Transcript
from reflector.settings import settings
+from reflector.utils.text import clean_title
+
+TOPIC_PROMPT = dedent(
+ """
+ Analyze the following transcript segment and extract the main topic being discussed.
+ Focus on the substantive content and ignore small talk or administrative chatter.
+
+ Create a title that:
+ - Captures the specific subject matter being discussed
+ - Is descriptive and self-explanatory
+ - Uses professional language
+ - Is specific rather than generic
+
+ For the summary:
+ - Summarize the key points in maximum two sentences
+ - Focus on what was discussed, decided, or accomplished
+ - Be concise but informative
+
+
+ {text}
+
+ """
+).strip()
+
+
+class TopicResponse(BaseModel):
+ """Structured response for topic detection"""
+
+ title: str = Field(description="A descriptive title for the topic being discussed")
+ summary: str = Field(description="A concise 1-2 sentence summary of the discussion")
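+
+
+# Illustrative structured result (hypothetical values):
+#   TopicResponse(
+#       title="Q1 Budget Review",
+#       summary="The team reviewed first-quarter spending and agreed on a revised budget.",
+#   )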
class TranscriptTopicDetectorProcessor(Processor):
@@ -11,7 +45,6 @@ class TranscriptTopicDetectorProcessor(Processor):
INPUT_TYPE = Transcript
OUTPUT_TYPE = TitleSummary
- TASK = "topic"
def __init__(
self, min_transcript_length: int = int(settings.MIN_TRANSCRIPT_LENGTH), **kwargs
@@ -19,8 +52,7 @@ class TranscriptTopicDetectorProcessor(Processor):
super().__init__(**kwargs)
self.transcript = None
self.min_transcript_length = min_transcript_length
- self.llm = LLM.get_instance()
- self.params = LLMTaskParams.get_instance(self.TASK).task_params
+ self.llm = LLM(settings=settings, temperature=0.9, max_tokens=500)
async def _push(self, data: Transcript):
if self.transcript is None:
@@ -34,18 +66,15 @@ class TranscriptTopicDetectorProcessor(Processor):
return
await self.flush()
- async def get_topic(self, text: str) -> dict:
+ async def get_topic(self, text: str) -> TopicResponse:
"""
- Generate a topic and description for a transcription excerpt
+ Generate a topic and description for a transcription excerpt using LLM
"""
- prompt = self.llm.create_prompt(instruct=self.params.instruct, text=text)
- topic_result = await self.llm.generate(
- prompt=prompt,
- gen_schema=self.params.gen_schema,
- gen_cfg=self.params.gen_cfg,
- logger=self.logger,
+ prompt = TOPIC_PROMPT.format(text=text)
+ response = await self.llm.get_structured_response(
+ prompt, [text], TopicResponse, tone_name="Topic analyzer"
)
- return topic_result
+ return response
async def _flush(self):
if not self.transcript:
@@ -53,13 +82,13 @@ class TranscriptTopicDetectorProcessor(Processor):
text = self.transcript.text
self.logger.info(f"Topic detector got {len(text)} length transcript")
+
topic_result = await self.get_topic(text=text)
- title = self.llm.trim_title(topic_result["title"])
- title = self.llm.ensure_casing(title)
+ title = clean_title(topic_result.title)
summary = TitleSummary(
title=title,
- summary=topic_result["summary"],
+ summary=topic_result.summary,
timestamp=self.transcript.timestamp,
duration=self.transcript.duration,
transcript=self.transcript,
diff --git a/server/reflector/processors/transcript_translator.py b/server/reflector/processors/transcript_translator.py
index 8cb9b676..2489e563 100644
--- a/server/reflector/processors/transcript_translator.py
+++ b/server/reflector/processors/transcript_translator.py
@@ -13,14 +13,13 @@ class TranscriptTranslatorProcessor(Processor):
INPUT_TYPE = Transcript
OUTPUT_TYPE = Transcript
- TASK = "translate"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.transcript = None
self.translate_url = settings.TRANSLATE_URL
self.timeout = settings.TRANSLATE_TIMEOUT
- self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
+ self.headers = {"Authorization": f"Bearer {settings.TRANSCRIPT_MODAL_API_KEY}"}
async def _push(self, data: Transcript):
self.transcript = data
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 6cd54be5..30af270b 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -9,13 +9,14 @@ class Settings(BaseSettings):
)
# CORS
+ UI_BASE_URL: str = "http://localhost:3000"
CORS_ORIGIN: str = "*"
CORS_ALLOW_CREDENTIALS: bool = False
# Database
DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
- # local data directory (audio for no)
+ # local data directory
DATA_DIR: str = "./data"
# Audio Transcription
@@ -24,10 +25,6 @@ class Settings(BaseSettings):
TRANSCRIPT_URL: str | None = None
TRANSCRIPT_TIMEOUT: int = 90
- # Translate into the target language
- TRANSLATE_URL: str | None = None
- TRANSLATE_TIMEOUT: int = 90
-
# Audio transcription modal.com configuration
TRANSCRIPT_MODAL_API_KEY: str | None = None
@@ -40,31 +37,15 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
+ # Translate into the target language
+ TRANSLATE_URL: str | None = None
+ TRANSLATE_TIMEOUT: int = 90
+
# LLM
- # available backend: openai, modal
- LLM_BACKEND: str = "modal"
-
- # LLM common configuration
+ LLM_MODEL: str = "microsoft/phi-4"
LLM_URL: str | None = None
- LLM_HOST: str = "localhost"
- LLM_PORT: int = 7860
- LLM_OPENAI_KEY: str | None = None
- LLM_OPENAI_MODEL: str = "gpt-3.5-turbo"
- LLM_OPENAI_TEMPERATURE: float = 0.7
- LLM_TIMEOUT: int = 60 * 5 # take cold start into account
- LLM_MAX_TOKENS: int = 1024
- LLM_TEMPERATURE: float = 0.7
- ZEPHYR_LLM_URL: str | None = None
- HERMES_3_8B_LLM_URL: str | None = None
-
- # LLM Modal configuration
- LLM_MODAL_API_KEY: str | None = None
-
- # per-task cases
- SUMMARY_MODEL: str = "monadical/private/smart"
- SUMMARY_LLM_URL: str | None = None
- SUMMARY_LLM_API_KEY: str | None = None
- SUMMARY_LLM_CONTEXT_SIZE_TOKENS: int = 16000
+ LLM_API_KEY: str | None = None
+ LLM_CONTEXT_WINDOW: int = 16000
# Diarization
DIARIZATION_ENABLED: bool = True
@@ -86,12 +67,6 @@ class Settings(BaseSettings):
# if set, all anonymous record will be public
PUBLIC_MODE: bool = False
- # Default LLM model name
- DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5"
-
- # Cache directory for all model storage
- CACHE_DIR: str = "./data"
-
# Min transcript length to generate topic + summary
MIN_TRANSCRIPT_LENGTH: int = 750
@@ -116,24 +91,20 @@ class Settings(BaseSettings):
# Healthcheck
HEALTHCHECK_URL: str | None = None
- AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
- SQS_POLLING_TIMEOUT_SECONDS: int = 60
-
+ # Whereby integration
WHEREBY_API_URL: str = "https://api.whereby.dev/v1"
-
WHEREBY_API_KEY: str | None = None
-
+ WHEREBY_WEBHOOK_SECRET: str | None = None
AWS_WHEREBY_S3_BUCKET: str | None = None
AWS_WHEREBY_ACCESS_KEY_ID: str | None = None
AWS_WHEREBY_ACCESS_KEY_SECRET: str | None = None
+ AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
+ SQS_POLLING_TIMEOUT_SECONDS: int = 60
+ # Zulip integration
ZULIP_REALM: str | None = None
ZULIP_API_KEY: str | None = None
ZULIP_BOT_EMAIL: str | None = None
- UI_BASE_URL: str = "http://localhost:3000"
-
- WHEREBY_WEBHOOK_SECRET: str | None = None
-
settings = Settings()
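For context, the settings hunk above collapses the per-backend LLM options into four fields: LLM_MODEL, LLM_URL, LLM_API_KEY and LLM_CONTEXT_WINDOW. A minimal sketch of how an OpenAI-compatible client might consume them; the actual wiring lives in reflector/llm, so names and call sites here are assumptions:

    from openai import AsyncOpenAI

    from reflector.settings import settings

    # Hypothetical helper: any OpenAI-compatible endpoint reachable at LLM_URL works.
    # settings.LLM_CONTEXT_WINDOW (tokens) bounds how much text is packed into one request.
    client = AsyncOpenAI(base_url=settings.LLM_URL, api_key=settings.LLM_API_KEY)

    async def complete(prompt: str) -> str:
        response = await client.chat.completions.create(
            model=settings.LLM_MODEL,
            messages=[{"role": "user", "content": prompt}],
        )
        return response.choices[0].message.content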
diff --git a/server/reflector/utils/text.py b/server/reflector/utils/text.py
new file mode 100644
index 00000000..ea5dd7e3
--- /dev/null
+++ b/server/reflector/utils/text.py
@@ -0,0 +1,33 @@
+def clean_title(title: str) -> str:
+ """
+ Clean and format a title string for consistent capitalization.
+
+ Rules:
+ - Strip surrounding quotes (single or double)
+ - Capitalize the first word
+ - Capitalize words longer than 3 characters
+ - Keep words with 3 or fewer characters lowercase (except first word)
+
+ Args:
+ title: The title string to clean
+
+ Returns:
+ The cleaned title with consistent capitalization
+
+ Examples:
+ >>> clean_title("hello world")
+ "Hello World"
+ >>> clean_title("meeting with the team")
+ "Meeting With the Team"
+ >>> clean_title("'Title with quotes'")
+ "Title With Quotes"
+ """
+ title = title.strip("\"'")
+ words = title.split()
+ if words:
+ words = [
+ word.capitalize() if i == 0 or len(word) > 3 else word.lower()
+ for i, word in enumerate(words)
+ ]
+ title = " ".join(words)
+ return title
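A worked example of the casing rules, which also explains the updated assertions in the tests below (the mocked "LLM Title" becomes "Llm Title"):

    >>> clean_title("LLM Title")
    'Llm Title'
    >>> clean_title("LLM TITLE")
    'Llm Title'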
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index 2707daed..434e7dea 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -37,8 +37,12 @@ def dummy_processors():
"reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation"
) as mock_translate,
):
- mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
- mock_title.return_value = {"title": "LLM TITLE"}
+ from reflector.processors.transcript_topic_detector import TopicResponse
+
+ mock_topic.return_value = TopicResponse(
+ title="LLM TITLE", summary="LLM SUMMARY"
+ )
+ mock_title.return_value = "LLM Title"
mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = "LLM SHORT SUMMARY"
mock_translate.return_value = "Bonjour le monde"
@@ -103,14 +107,15 @@ async def dummy_diarization():
@pytest.fixture
async def dummy_llm():
- from reflector.llm.base import LLM
+ from reflector.llm import LLM
class TestLLM(LLM):
def __init__(self):
self.model_name = "DUMMY MODEL"
self.llm_tokenizer = "DUMMY TOKENIZER"
- with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
+ # LLM no longer exposes get_instance, so mock the constructor instead
+ with patch("reflector.llm.LLM") as mock_llm:
mock_llm.return_value = TestLLM()
yield
@@ -129,22 +134,19 @@ async def dummy_storage():
async def _get_file_url(self, *args, **kwargs):
return "http://fake_server/audio.mp3"
- with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
- mock_storage.return_value = DummyStorage()
- yield
+ async def _get_file(self, *args, **kwargs):
+ from pathlib import Path
+ test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
+ return test_mp3.read_bytes()
-@pytest.fixture
-def nltk():
- with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
- mock_nltk.return_value = "NLTK PACKAGE"
- yield
-
-
-@pytest.fixture
-def ensure_casing():
- with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
- mock_casing.return_value = "LLM TITLE"
+ dummy = DummyStorage()
+ with (
+ patch("reflector.storage.base.Storage.get_instance") as mock_storage,
+ patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
+ ):
+ mock_storage.return_value = dummy
+ mock_get_transcripts.return_value = dummy
yield
diff --git a/server/tests/test_processors_broadcast.py b/server/tests/test_processors_broadcast.py
index 5480de36..197ebac4 100644
--- a/server/tests/test_processors_broadcast.py
+++ b/server/tests/test_processors_broadcast.py
@@ -2,7 +2,7 @@ import pytest
@pytest.mark.asyncio
-async def test_processor_broadcast(nltk):
+async def test_processor_broadcast():
from reflector.processors.base import BroadcastProcessor, Pipeline, Processor
class TestProcessor(Processor):
diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py
index 0124d5ba..a1787c49 100644
--- a/server/tests/test_processors_pipeline.py
+++ b/server/tests/test_processors_pipeline.py
@@ -3,11 +3,9 @@ import pytest
@pytest.mark.asyncio
async def test_basic_process(
- nltk,
dummy_transcript,
dummy_llm,
dummy_processors,
- ensure_casing,
):
# goal is to start the server, and send rtc audio to it
# validate the events received
@@ -16,8 +14,8 @@ async def test_basic_process(
from reflector.settings import settings
from reflector.tools.process import process_audio_file
- # use an LLM test backend
- settings.LLM_BACKEND = "test"
+ # LLM_BACKEND no longer exists in settings, so no LLM test backend is selected here
+ # (previously: settings.LLM_BACKEND = "test")
settings.TRANSCRIPT_BACKEND = "whisper"
# event callback
diff --git a/server/tests/test_transcripts_process.py b/server/tests/test_transcripts_process.py
index ef105c85..e4973acb 100644
--- a/server/tests/test_transcripts_process.py
+++ b/server/tests/test_transcripts_process.py
@@ -10,7 +10,6 @@ from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_process(
tmpdir,
- ensure_casing,
dummy_llm,
dummy_processors,
dummy_diarization,
@@ -69,7 +68,7 @@ async def test_transcript_process(
transcript = resp.json()
assert transcript["status"] == "ended"
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
- assert transcript["title"] == "LLM TITLE"
+ assert transcript["title"] == "Llm Title"
# check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics")
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 24b70d6f..a8406337 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -69,8 +69,6 @@ async def test_transcript_rtc_and_websocket(
dummy_diarization,
dummy_storage,
fake_mp3_upload,
- ensure_casing,
- nltk,
appserver,
):
# goal: start the server, exchange RTC, receive websocket events
@@ -185,7 +183,7 @@ async def test_transcript_rtc_and_websocket(
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
- assert ev["data"]["title"] == "LLM TITLE"
+ assert ev["data"]["title"] == "Llm Title"
assert "WAVEFORM" in eventnames
ev = events[eventnames.index("WAVEFORM")]
@@ -228,8 +226,6 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_diarization,
dummy_storage,
fake_mp3_upload,
- ensure_casing,
- nltk,
appserver,
):
# goal: start the server, exchange RTC, receive websocket events
@@ -353,7 +349,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
- assert ev["data"]["title"] == "LLM TITLE"
+ assert ev["data"]["title"] == "Llm Title"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
diff --git a/server/tests/test_transcripts_upload.py b/server/tests/test_transcripts_upload.py
index fab21321..1bc82386 100644
--- a/server/tests/test_transcripts_upload.py
+++ b/server/tests/test_transcripts_upload.py
@@ -10,7 +10,6 @@ from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_upload_file(
tmpdir,
- ensure_casing,
dummy_llm,
dummy_processors,
dummy_diarization,
@@ -53,7 +52,7 @@ async def test_transcript_upload_file(
transcript = resp.json()
assert transcript["status"] == "ended"
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
- assert transcript["title"] == "LLM TITLE"
+ assert transcript["title"] == "Llm Title"
# check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics")
diff --git a/server/tests/test_utils_text.py b/server/tests/test_utils_text.py
new file mode 100644
index 00000000..8b86fe0a
--- /dev/null
+++ b/server/tests/test_utils_text.py
@@ -0,0 +1,21 @@
+import pytest
+
+from reflector.utils.text import clean_title
+
+
+@pytest.mark.parametrize(
+ "input_title,expected",
+ [
+ ("hello world", "Hello World"),
+ ("HELLO WORLD", "Hello World"),
+ ("hello WORLD", "Hello World"),
+ ("the quick brown fox", "The Quick Brown fox"),
+ ("discussion about API design", "Discussion About api Design"),
+ ("Q1 2024 budget review", "Q1 2024 Budget Review"),
+ ("'Title with quotes'", "Title With Quotes"),
+ ("'title with quotes'", "Title With Quotes"),
+ ("MiXeD CaSe WoRdS", "Mixed Case Words"),
+ ],
+)
+def test_clean_title(input_title, expected):
+ assert clean_title(input_title) == expected