72 lines
2.2 KiB
Python
72 lines
2.2 KiB
Python
"""Simple LLM helper for workbooks using Mirascope v2."""
|
||
|
||
import os
|
||
import re
|
||
from typing import TypeVar
|
||
|
||
from mirascope import llm
|
||
from pydantic import BaseModel
|
||
|
||
# Type variable for the Pydantic model class a caller wants back.
T = TypeVar("T", bound=BaseModel)

# Configure from environment (defaults match .env.example).
# `os.getenv(name, default)` only applies the default when the variable is
# unset; a variable that is set but empty (e.g. `LLM_API_URL=` in a .env file)
# would yield "" and therefore an unusable "/v1" base URL or an empty model
# name. Using `or` makes empty values fall back to the defaults as well.
_api_key = os.getenv("LLM_API_KEY", "")
_base_url = os.getenv("LLM_API_URL") or "https://litellm-notrack.app.monadical.io"
_model = os.getenv("LLM_MODEL") or "GLM-4.5-Air-FP8-dev"

# Register our LiteLLM endpoint as an OpenAI-compatible provider.
# Trailing slashes are dropped and "/v1" is appended unless the configured
# URL already ends with it, so both URL forms are accepted.
_base = _base_url.rstrip("/")
llm.register_provider(
    "openai",
    scope="litellm/",
    base_url=_base if _base.endswith("/v1") else f"{_base}/v1",
    api_key=_api_key,
)
|
||
|
||
|
||
def _sanitize_json(text: str) -> str:
|
||
"""Strip control characters (U+0000–U+001F) that break JSON parsing.
|
||
|
||
Some LLMs emit literal newlines/tabs inside JSON string values,
|
||
which is invalid per the JSON spec. Replace them with spaces.
|
||
"""
|
||
return re.sub(r"[\x00-\x1f]+", " ", text)
|
||
|
||
|
||
def _response_text(content: object) -> str:
    """Return the plain text carried by a response's ``content``.

    ``content`` may be a string or a sequence of content parts. Only real
    text is kept: plain strings and parts exposing a string ``.text``
    attribute. Parts without text are skipped, because splicing their
    ``str()`` form into the payload can never produce valid JSON.
    NOTE(review): assumes text parts expose ``.text`` and other part kinds
    do not -- confirm against the Mirascope content types.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        pieces: list[str] = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
                continue
            part_text = getattr(part, "text", None)
            if isinstance(part_text, str):
                pieces.append(part_text)
        return "".join(pieces)
    # Unknown shape: keep the original best-effort behaviour.
    return str(content)


async def llm_call(
    prompt: str,
    response_model: type[T],
    system_prompt: str = "You are a helpful assistant.",
    model: str | None = None,
) -> T:
    """Make a structured LLM call.

    The system prompt and the user prompt are joined with a blank line and
    returned as a single string from the decorated prompt function, which
    is called against ``litellm/<model>``.

    Args:
        prompt: The user prompt
        response_model: Pydantic model for structured output
        system_prompt: System instructions, prepended to ``prompt``
        model: Override the default model; ``None`` or an empty string
            falls back to the module-level ``_model``

    Returns:
        Parsed response matching the response_model schema

    Raises:
        pydantic.ValidationError: If the response's own parse fails and the
            sanitized raw text still does not validate against
            ``response_model``.
        Exception: Anything raised by the underlying call itself propagates
            unchanged.
    """
    use_model = model or _model

    # Defined per call because the model id is only known at call time.
    @llm.call(f"litellm/{use_model}", format=response_model)
    async def _call() -> str:
        return f"{system_prompt}\n\n{prompt}"

    response = await _call()
    try:
        return response.parse()
    except Exception:
        # Deliberately broad: any failure of the built-in parse triggers the
        # manual fallback. If the fallback fails too, its error is raised
        # with the original parse error attached as context.
        raw_text = _response_text(response.content)
        return response_model.model_validate_json(_sanitize_json(raw_text))
|