mirror of
https://github.com/Monadical-SAS/cubbi.git
synced 2025-12-20 04:09:06 +00:00
feat: universal model management for all standard providers (#34)
* fix: add crush plugin support too * feat: comprehensive model management for all standard providers - Add universal provider support for model fetching (OpenAI, Anthropic, Google, OpenRouter) - Add default API URLs for standard providers in config.py - Enhance model fetcher with provider-specific authentication: * Anthropic: x-api-key header + anthropic-version header * Google: x-goog-api-key header + custom response format handling * OpenAI/OpenRouter: Bearer token (unchanged) - Support Google's unique API response format (models vs data key, name vs id field) - Update CLI commands to work with all supported provider types - Enhance configure interface to include all providers (even those without API keys) - Update both OpenCode and Crush plugins to populate models for all provider types - Add comprehensive provider support detection methods
This commit is contained in:
24
cubbi/cli.py
24
cubbi/cli.py
@@ -2297,12 +2297,12 @@ def refresh_models(
|
||||
console.print(f"[red]Provider '{provider}' not found[/red]")
|
||||
return
|
||||
|
||||
if not user_config.is_provider_openai_compatible(provider):
|
||||
if not user_config.supports_model_fetching(provider):
|
||||
console.print(
|
||||
f"[red]Provider '{provider}' is not a custom OpenAI provider[/red]"
|
||||
f"[red]Provider '{provider}' does not support model fetching[/red]"
|
||||
)
|
||||
console.print(
|
||||
"Only providers with type='openai' and custom base_url are supported"
|
||||
"Only providers of supported types (openai, anthropic, google, openrouter) can refresh models"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -2328,24 +2328,24 @@ def refresh_models(
|
||||
except Exception as e:
|
||||
console.print(f"[red]Failed to refresh models for '{provider}': {e}[/red]")
|
||||
else:
|
||||
# Refresh models for all OpenAI-compatible providers
|
||||
compatible_providers = user_config.list_openai_compatible_providers()
|
||||
# Refresh models for all model-fetchable providers
|
||||
fetchable_providers = user_config.list_model_fetchable_providers()
|
||||
|
||||
if not compatible_providers:
|
||||
console.print("[yellow]No custom OpenAI providers found[/yellow]")
|
||||
if not fetchable_providers:
|
||||
console.print(
|
||||
"Add providers with type='openai' and custom base_url to refresh models"
|
||||
"[yellow]No providers with model fetching support found[/yellow]"
|
||||
)
|
||||
console.print(
|
||||
"Add providers of supported types (openai, anthropic, google, openrouter) to refresh models"
|
||||
)
|
||||
return
|
||||
|
||||
console.print(
|
||||
f"Refreshing models for {len(compatible_providers)} custom OpenAI providers..."
|
||||
)
|
||||
console.print(f"Refreshing models for {len(fetchable_providers)} providers...")
|
||||
|
||||
success_count = 0
|
||||
failed_providers = []
|
||||
|
||||
for provider_name in compatible_providers:
|
||||
for provider_name in fetchable_providers:
|
||||
try:
|
||||
provider_config = user_config.get_provider(provider_name)
|
||||
with console.status(f"Fetching models from {provider_name}..."):
|
||||
|
||||
@@ -14,6 +14,14 @@ BUILTIN_IMAGES_DIR = Path(__file__).parent / "images"
|
||||
# Dynamically loaded from images directory at runtime
|
||||
DEFAULT_IMAGES = {}
|
||||
|
||||
# Default API URLs for standard providers
|
||||
PROVIDER_DEFAULT_URLS = {
|
||||
"openai": "https://api.openai.com",
|
||||
"anthropic": "https://api.anthropic.com",
|
||||
"google": "https://generativelanguage.googleapis.com",
|
||||
"openrouter": "https://openrouter.ai/api",
|
||||
}
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
def __init__(self, config_path: Optional[Path] = None):
|
||||
|
||||
@@ -222,7 +222,7 @@ class ProviderConfigurator:
|
||||
|
||||
console.print(f"[green]Added provider '{provider_name}'[/green]")
|
||||
|
||||
if self.user_config.is_provider_openai_compatible(provider_name):
|
||||
if self.user_config.supports_model_fetching(provider_name):
|
||||
console.print("Refreshing models...")
|
||||
try:
|
||||
self._refresh_provider_models(provider_name)
|
||||
@@ -252,10 +252,10 @@ class ProviderConfigurator:
|
||||
model_id = model.get("id", str(model))
|
||||
else:
|
||||
model_id = str(model)
|
||||
console.print(f" {i+1}. {model_id}")
|
||||
console.print(f" {i + 1}. {model_id}")
|
||||
if len(value) > 10:
|
||||
console.print(
|
||||
f" ... and {len(value)-10} more ({len(value)} total)"
|
||||
f" ... and {len(value) - 10} more ({len(value)} total)"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
@@ -268,7 +268,7 @@ class ProviderConfigurator:
|
||||
while True:
|
||||
choices = ["Remove provider"]
|
||||
|
||||
if self.user_config.is_provider_openai_compatible(provider_name):
|
||||
if self.user_config.supports_model_fetching(provider_name):
|
||||
choices.append("Refresh models")
|
||||
|
||||
choices.extend(["---", "Back"])
|
||||
@@ -375,7 +375,9 @@ class ProviderConfigurator:
|
||||
for provider_name, provider_config in providers.items():
|
||||
provider_type = provider_config.get("type", "unknown")
|
||||
has_key = bool(provider_config.get("api_key"))
|
||||
if has_key:
|
||||
|
||||
# Include provider if it has an API key OR supports model fetching (might not need key)
|
||||
if has_key or self.user_config.supports_model_fetching(provider_name):
|
||||
base_url = provider_config.get("base_url")
|
||||
if base_url:
|
||||
choices.append(f"{provider_name} ({provider_type}) - {base_url}")
|
||||
@@ -383,7 +385,7 @@ class ProviderConfigurator:
|
||||
choices.append(f"{provider_name} ({provider_type})")
|
||||
|
||||
if not choices:
|
||||
console.print("[yellow]No providers with API keys configured.[/yellow]")
|
||||
console.print("[yellow]No usable providers configured.[/yellow]")
|
||||
return
|
||||
|
||||
# Add separator and cancel option
|
||||
@@ -401,7 +403,7 @@ class ProviderConfigurator:
|
||||
# Extract provider name
|
||||
provider_name = choice.split(" (")[0]
|
||||
|
||||
if self.user_config.is_provider_openai_compatible(provider_name):
|
||||
if self.user_config.supports_model_fetching(provider_name):
|
||||
model_name = self._select_model_from_list(provider_name)
|
||||
else:
|
||||
model_name = questionary.text(
|
||||
|
||||
@@ -27,17 +27,37 @@ class CrushPlugin(ToolPlugin):
|
||||
def _map_provider_to_crush_format(
|
||||
self, provider_name: str, provider_config, is_default_provider: bool = False
|
||||
) -> dict[str, Any] | None:
|
||||
# Handle standard providers without base_url
|
||||
if not provider_config.base_url:
|
||||
if provider_config.type in STANDARD_PROVIDERS:
|
||||
# Populate models for any standard provider that has models
|
||||
models_list = []
|
||||
if provider_config.models:
|
||||
for model in provider_config.models:
|
||||
model_id = model.get("id", "")
|
||||
if model_id:
|
||||
models_list.append({"id": model_id, "name": model_id})
|
||||
|
||||
provider_entry = {
|
||||
"api_key": provider_config.api_key,
|
||||
"models": models_list,
|
||||
}
|
||||
return provider_entry
|
||||
|
||||
# Handle custom providers with base_url
|
||||
models_list = []
|
||||
|
||||
# Add all models for any provider type that has models
|
||||
if provider_config.models:
|
||||
for model in provider_config.models:
|
||||
model_id = model.get("id", "")
|
||||
if model_id:
|
||||
models_list.append({"id": model_id, "name": model_id})
|
||||
|
||||
provider_entry = {
|
||||
"api_key": provider_config.api_key,
|
||||
"base_url": provider_config.base_url,
|
||||
"models": [],
|
||||
"models": models_list,
|
||||
}
|
||||
|
||||
if provider_config.type in STANDARD_PROVIDERS:
|
||||
|
||||
@@ -53,8 +53,8 @@ class OpencodePlugin(ToolPlugin):
|
||||
# Custom provider - include baseURL and name
|
||||
models_dict = {}
|
||||
|
||||
# Add all models for OpenAI-compatible providers
|
||||
if provider_config.type == "openai" and provider_config.models:
|
||||
# Add all models for any provider type that has models
|
||||
if provider_config.models:
|
||||
for model in provider_config.models:
|
||||
model_id = model.get("id", "")
|
||||
if model_id:
|
||||
@@ -77,10 +77,6 @@ class OpencodePlugin(ToolPlugin):
|
||||
elif provider_config.type == "openai":
|
||||
provider_entry["npm"] = "@ai-sdk/openai-compatible"
|
||||
provider_entry["name"] = f"OpenAI Compatible ({provider_name})"
|
||||
if models_dict:
|
||||
self.status.log(
|
||||
f"Added {len(models_dict)} models to {provider_name}"
|
||||
)
|
||||
elif provider_config.type == "google":
|
||||
provider_entry["npm"] = "@ai-sdk/google"
|
||||
provider_entry["name"] = f"Google ({provider_name})"
|
||||
@@ -93,19 +89,38 @@ class OpencodePlugin(ToolPlugin):
|
||||
provider_entry["name"] = provider_name.title()
|
||||
|
||||
config_data["provider"][provider_name] = provider_entry
|
||||
self.status.log(
|
||||
f"Added {provider_name} custom provider to OpenCode configuration"
|
||||
)
|
||||
if models_dict:
|
||||
self.status.log(
|
||||
f"Added {provider_name} custom provider with {len(models_dict)} models to OpenCode configuration"
|
||||
)
|
||||
else:
|
||||
self.status.log(
|
||||
f"Added {provider_name} custom provider to OpenCode configuration"
|
||||
)
|
||||
else:
|
||||
# Standard provider without custom URL - minimal config
|
||||
# Standard provider without custom URL
|
||||
if provider_config.type in STANDARD_PROVIDERS:
|
||||
# Populate models for any provider that has models
|
||||
models_dict = {}
|
||||
if provider_config.models:
|
||||
for model in provider_config.models:
|
||||
model_id = model.get("id", "")
|
||||
if model_id:
|
||||
models_dict[model_id] = {"name": model_id}
|
||||
|
||||
config_data["provider"][provider_name] = {
|
||||
"options": {"apiKey": provider_config.api_key},
|
||||
"models": {},
|
||||
"models": models_dict,
|
||||
}
|
||||
self.status.log(
|
||||
f"Added {provider_name} standard provider to OpenCode configuration"
|
||||
)
|
||||
|
||||
if models_dict:
|
||||
self.status.log(
|
||||
f"Added {provider_name} standard provider with {len(models_dict)} models to OpenCode configuration"
|
||||
)
|
||||
else:
|
||||
self.status.log(
|
||||
f"Added {provider_name} standard provider to OpenCode configuration"
|
||||
)
|
||||
|
||||
# Set default model
|
||||
if cubbi_config.defaults.model:
|
||||
|
||||
@@ -27,6 +27,7 @@ class ModelFetcher:
|
||||
base_url: str,
|
||||
api_key: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
provider_type: Optional[str] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Fetch models from an OpenAI-compatible /v1/models endpoint.
|
||||
|
||||
@@ -34,6 +35,7 @@ class ModelFetcher:
|
||||
base_url: Base URL of the provider (e.g., "https://api.openai.com" or "https://api.litellm.com")
|
||||
api_key: Optional API key for authentication
|
||||
headers: Optional additional headers
|
||||
provider_type: Optional provider type for authentication handling
|
||||
|
||||
Returns:
|
||||
List of model dictionaries with 'id' and 'name' keys
|
||||
@@ -46,7 +48,7 @@ class ModelFetcher:
|
||||
models_url = self._build_models_url(base_url)
|
||||
|
||||
# Prepare headers
|
||||
request_headers = self._build_headers(api_key, headers)
|
||||
request_headers = self._build_headers(api_key, headers, provider_type)
|
||||
|
||||
logger.info(f"Fetching models from {models_url}")
|
||||
|
||||
@@ -59,13 +61,22 @@ class ModelFetcher:
|
||||
# Parse JSON response
|
||||
data = response.json()
|
||||
|
||||
# Validate response structure
|
||||
if not isinstance(data, dict) or "data" not in data:
|
||||
raise ValueError(
|
||||
f"Invalid response format: expected dict with 'data' key, got {type(data)}"
|
||||
)
|
||||
# Handle provider-specific response formats
|
||||
if provider_type == "google":
|
||||
# Google uses {"models": [...]} format
|
||||
if not isinstance(data, dict) or "models" not in data:
|
||||
raise ValueError(
|
||||
f"Invalid Google response format: expected dict with 'models' key, got {type(data)}"
|
||||
)
|
||||
models_data = data["models"]
|
||||
else:
|
||||
# OpenAI-compatible format uses {"data": [...]}
|
||||
if not isinstance(data, dict) or "data" not in data:
|
||||
raise ValueError(
|
||||
f"Invalid response format: expected dict with 'data' key, got {type(data)}"
|
||||
)
|
||||
models_data = data["data"]
|
||||
|
||||
models_data = data["data"]
|
||||
if not isinstance(models_data, list):
|
||||
raise ValueError(
|
||||
f"Invalid models data: expected list, got {type(models_data)}"
|
||||
@@ -77,7 +88,14 @@ class ModelFetcher:
|
||||
if not isinstance(model_item, dict):
|
||||
continue
|
||||
|
||||
model_id = model_item.get("id", "")
|
||||
# Handle provider-specific model ID fields
|
||||
if provider_type == "google":
|
||||
# Google uses "name" field (e.g., "models/gemini-1.5-pro")
|
||||
model_id = model_item.get("name", "")
|
||||
else:
|
||||
# OpenAI-compatible uses "id" field
|
||||
model_id = model_item.get("id", "")
|
||||
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
@@ -144,12 +162,14 @@ class ModelFetcher:
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
additional_headers: Optional[Dict[str, str]] = None,
|
||||
provider_type: Optional[str] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Build request headers.
|
||||
|
||||
Args:
|
||||
api_key: Optional API key for authentication
|
||||
additional_headers: Optional additional headers
|
||||
provider_type: Provider type for specific auth handling
|
||||
|
||||
Returns:
|
||||
Dictionary of headers
|
||||
@@ -161,7 +181,15 @@ class ModelFetcher:
|
||||
|
||||
# Add authentication header if API key is provided
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
if provider_type == "anthropic":
|
||||
# Anthropic uses x-api-key header
|
||||
headers["x-api-key"] = api_key
|
||||
elif provider_type == "google":
|
||||
# Google uses x-goog-api-key header
|
||||
headers["x-goog-api-key"] = api_key
|
||||
else:
|
||||
# Standard Bearer token for OpenAI, OpenRouter, and custom providers
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
# Add any additional headers
|
||||
if additional_headers:
|
||||
@@ -183,26 +211,38 @@ def fetch_provider_models(
|
||||
List of model dictionaries
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not OpenAI-compatible or missing required fields
|
||||
ValueError: If provider is not supported or missing required fields
|
||||
requests.RequestException: If the request fails
|
||||
"""
|
||||
import os
|
||||
from .config import PROVIDER_DEFAULT_URLS
|
||||
|
||||
provider_type = provider_config.get("type", "")
|
||||
base_url = provider_config.get("base_url")
|
||||
api_key = provider_config.get("api_key", "")
|
||||
|
||||
# Resolve environment variables in API key
|
||||
if api_key.startswith("${") and api_key.endswith("}"):
|
||||
env_var_name = api_key[2:-1]
|
||||
api_key = os.environ.get(env_var_name, "")
|
||||
|
||||
if provider_type != "openai" and not base_url:
|
||||
# Determine base URL - use custom base_url or default for standard providers
|
||||
if base_url:
|
||||
# Custom provider with explicit base_url
|
||||
effective_base_url = base_url
|
||||
elif provider_type in PROVIDER_DEFAULT_URLS:
|
||||
# Standard provider - use default URL
|
||||
effective_base_url = PROVIDER_DEFAULT_URLS[provider_type]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Provider is not OpenAI-compatible (must have type='openai' or base_url)"
|
||||
f"Unsupported provider type '{provider_type}'. Must be one of: {list(PROVIDER_DEFAULT_URLS.keys())} or have a custom base_url"
|
||||
)
|
||||
|
||||
if not base_url:
|
||||
raise ValueError("No base_url specified for OpenAI-compatible provider")
|
||||
# Prepare additional headers for specific providers
|
||||
headers = {}
|
||||
if provider_type == "anthropic":
|
||||
# Anthropic uses a different API version header
|
||||
headers["anthropic-version"] = "2023-06-01"
|
||||
|
||||
fetcher = ModelFetcher(timeout=timeout)
|
||||
return fetcher.fetch_models(base_url, api_key)
|
||||
return fetcher.fetch_models(effective_base_url, api_key, headers, provider_type)
|
||||
|
||||
@@ -736,6 +736,22 @@ class UserConfigManager:
|
||||
provider_type = provider_config.get("type", "")
|
||||
return provider_type == "openai" and provider_config.get("base_url") is not None
|
||||
|
||||
def supports_model_fetching(self, provider_name: str) -> bool:
|
||||
"""Check if a provider supports model fetching via API."""
|
||||
from .config import PROVIDER_DEFAULT_URLS
|
||||
|
||||
provider = self.get_provider(provider_name)
|
||||
if not provider:
|
||||
return False
|
||||
|
||||
provider_type = provider.get("type")
|
||||
base_url = provider.get("base_url")
|
||||
|
||||
# Provider supports model fetching if:
|
||||
# 1. It has a custom base_url (OpenAI-compatible), OR
|
||||
# 2. It's a standard provider type that we support
|
||||
return base_url is not None or provider_type in PROVIDER_DEFAULT_URLS
|
||||
|
||||
def list_openai_compatible_providers(self) -> List[str]:
|
||||
providers = self.list_providers()
|
||||
compatible_providers = []
|
||||
@@ -745,3 +761,14 @@ class UserConfigManager:
|
||||
compatible_providers.append(provider_name)
|
||||
|
||||
return compatible_providers
|
||||
|
||||
def list_model_fetchable_providers(self) -> List[str]:
|
||||
"""List all providers that support model fetching."""
|
||||
providers = self.list_providers()
|
||||
fetchable_providers = []
|
||||
|
||||
for provider_name in providers.keys():
|
||||
if self.supports_model_fetching(provider_name):
|
||||
fetchable_providers.append(provider_name)
|
||||
|
||||
return fetchable_providers
|
||||
|
||||
Reference in New Issue
Block a user