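"""Modal-hosted LLM backend for Reflector.

Sends prompts over HTTP to a Modal deployment (warmup and generation endpoints)
and registers itself with the LLM base class under the "modal" name.
"""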
import httpx
from transformers import AutoTokenizer, GenerationConfig

from reflector.llm.base import LLM
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry


class ModalLLM(LLM):
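    """LLM backend that proxies generation requests to a Modal-hosted HTTP service."""
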
    def __init__(self, model_name: str | None = None):
        super().__init__()
        self.timeout = settings.LLM_TIMEOUT
        self.llm_url = settings.LLM_URL + "/llm"
        self.llm_warmup_url = settings.LLM_URL + "/warmup"
        self.headers = {
            "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
        }
        self._set_model_name(model_name if model_name else settings.DEFAULT_LLM)

    @property
    def supported_models(self):
        """
        List of currently supported models on this GPU platform
        """
        # TODO: Query the specific GPU platform
        # Replace this with an HTTP call
        return ["lmsys/vicuna-13b-v1.5"]

    async def _warmup(self, logger):
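        """Hit the warmup endpoint so the remote model is ready before generation."""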
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.llm_warmup_url,
                headers=self.headers,
                timeout=60 * 5,
            )
            response.raise_for_status()

    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ):
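        """Send the prompt (plus optional schema and generation config) to the LLM endpoint and return its text."""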
json_payload = {"prompt": prompt}
|
|
if gen_schema:
|
|
json_payload["gen_schema"] = gen_schema
|
|
if gen_cfg:
|
|
json_payload["gen_cfg"] = gen_cfg
|
|
async with httpx.AsyncClient() as client:
|
|
response = await retry(client.post)(
|
|
self.llm_url,
|
|
headers=self.headers,
|
|
json=json_payload,
|
|
timeout=self.timeout,
|
|
retry_timeout=60 * 5,
|
|
)
|
|
response.raise_for_status()
|
|
text = response.json()["text"]
|
|
return text
|
|
|
|
    def _set_model_name(self, model_name: str) -> bool:
        """
        Set the model name
        """
        # Abort if the model is not supported
        if model_name not in self.supported_models:
            reflector_logger.info(
                f"Attempted to change {model_name=}, but it is not supported. "
                "Setting model and tokenizer failed!"
            )
            return False
        # Abort if the model is already set
        elif hasattr(self, "model_name") and model_name == self._get_model_name():
            reflector_logger.info("No change in model. Setting model skipped.")
            return False
        # Update model name and tokenizer
        self.model_name = model_name
        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, cache_dir=settings.CACHE_DIR
        )
        reflector_logger.info(f"Model set to {model_name=}. Tokenizer updated.")
        return True

    def _get_tokenizer(self) -> AutoTokenizer:
        """
        Return the currently used LLM tokenizer
        """
        return self.llm_tokenizer

    def _get_model_name(self) -> str:
        """
        Return the current model name from the instance details
        """
        return self.model_name


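# Register this backend in the LLM registry under the "modal" name.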
LLM.register("modal", ModalLLM)
|
|
|
|
if __name__ == "__main__":
|
|
from reflector.logger import logger
|
|
|
|
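    # Minimal manual smoke test: run this module directly to exercise the backend.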
    async def main():
        llm = ModalLLM()
        prompt = llm.create_prompt(
            instruct="Complete the following task",
            text="Tell me a joke about programming.",
        )
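        # Free-form generation: only the prompt is sent.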
        result = await llm.generate(prompt=prompt, logger=logger)
        print(result)

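        # Request schema-guided output by passing a JSON schema as gen_schema.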
        gen_schema = {
            "type": "object",
            "properties": {"response": {"type": "string"}},
        }

        result = await llm.generate(prompt=prompt, gen_schema=gen_schema, logger=logger)
        print(result)

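        # Cap the number of generated tokens with a transformers GenerationConfig.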
        gen_cfg = GenerationConfig(max_new_tokens=150)
        result = await llm.generate(
            prompt=prompt, gen_cfg=gen_cfg, gen_schema=gen_schema, logger=logger
        )
        print(result)

    import asyncio

    asyncio.run(main())