From fc819a386185330e60946ee4712f268cfed2b66a Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 8 Aug 2025 15:12:04 -0600 Subject: [PATCH] feat: universal model management for all standard providers (#34) * fix: add crush plugin support too * feat: comprehensive model management for all standard providers - Add universal provider support for model fetching (OpenAI, Anthropic, Google, OpenRouter) - Add default API URLs for standard providers in config.py - Enhance model fetcher with provider-specific authentication: * Anthropic: x-api-key header + anthropic-version header * Google: x-goog-api-key header + custom response format handling * OpenAI/OpenRouter: Bearer token (unchanged) - Support Google's unique API response format (models vs data key, name vs id field) - Update CLI commands to work with all supported provider types - Enhance configure interface to include all providers (even those without API keys) - Update both OpenCode and Crush plugins to populate models for all provider types - Add comprehensive provider support detection methods --- cubbi/cli.py | 24 ++++---- cubbi/config.py | 8 +++ cubbi/configure.py | 16 +++--- cubbi/images/crush/crush_plugin.py | 22 +++++++- cubbi/images/opencode/opencode_plugin.py | 43 ++++++++++----- cubbi/model_fetcher.py | 70 +++++++++++++++++++----- cubbi/user_config.py | 27 +++++++++ 7 files changed, 161 insertions(+), 49 deletions(-) diff --git a/cubbi/cli.py b/cubbi/cli.py index a99d103..d170e91 100644 --- a/cubbi/cli.py +++ b/cubbi/cli.py @@ -2297,12 +2297,12 @@ def refresh_models( console.print(f"[red]Provider '{provider}' not found[/red]") return - if not user_config.is_provider_openai_compatible(provider): + if not user_config.supports_model_fetching(provider): console.print( - f"[red]Provider '{provider}' is not a custom OpenAI provider[/red]" + f"[red]Provider '{provider}' does not support model fetching[/red]" ) console.print( - "Only providers with type='openai' and custom base_url are supported" + "Only providers of supported types (openai, anthropic, google, openrouter) can refresh models" ) return @@ -2328,24 +2328,24 @@ def refresh_models( except Exception as e: console.print(f"[red]Failed to refresh models for '{provider}': {e}[/red]") else: - # Refresh models for all OpenAI-compatible providers - compatible_providers = user_config.list_openai_compatible_providers() + # Refresh models for all model-fetchable providers + fetchable_providers = user_config.list_model_fetchable_providers() - if not compatible_providers: - console.print("[yellow]No custom OpenAI providers found[/yellow]") + if not fetchable_providers: console.print( - "Add providers with type='openai' and custom base_url to refresh models" + "[yellow]No providers with model fetching support found[/yellow]" + ) + console.print( + "Add providers of supported types (openai, anthropic, google, openrouter) to refresh models" ) return - console.print( - f"Refreshing models for {len(compatible_providers)} custom OpenAI providers..." - ) + console.print(f"Refreshing models for {len(fetchable_providers)} providers...") success_count = 0 failed_providers = [] - for provider_name in compatible_providers: + for provider_name in fetchable_providers: try: provider_config = user_config.get_provider(provider_name) with console.status(f"Fetching models from {provider_name}..."): diff --git a/cubbi/config.py b/cubbi/config.py index a9bfa0d..42fe44e 100644 --- a/cubbi/config.py +++ b/cubbi/config.py @@ -14,6 +14,14 @@ BUILTIN_IMAGES_DIR = Path(__file__).parent / "images" # Dynamically loaded from images directory at runtime DEFAULT_IMAGES = {} +# Default API URLs for standard providers +PROVIDER_DEFAULT_URLS = { + "openai": "https://api.openai.com", + "anthropic": "https://api.anthropic.com", + "google": "https://generativelanguage.googleapis.com", + "openrouter": "https://openrouter.ai/api", +} + class ConfigManager: def __init__(self, config_path: Optional[Path] = None): diff --git a/cubbi/configure.py b/cubbi/configure.py index 52ab090..c806724 100644 --- a/cubbi/configure.py +++ b/cubbi/configure.py @@ -222,7 +222,7 @@ class ProviderConfigurator: console.print(f"[green]Added provider '{provider_name}'[/green]") - if self.user_config.is_provider_openai_compatible(provider_name): + if self.user_config.supports_model_fetching(provider_name): console.print("Refreshing models...") try: self._refresh_provider_models(provider_name) @@ -252,10 +252,10 @@ class ProviderConfigurator: model_id = model.get("id", str(model)) else: model_id = str(model) - console.print(f" {i+1}. {model_id}") + console.print(f" {i + 1}. {model_id}") if len(value) > 10: console.print( - f" ... and {len(value)-10} more ({len(value)} total)" + f" ... and {len(value) - 10} more ({len(value)} total)" ) continue else: @@ -268,7 +268,7 @@ class ProviderConfigurator: while True: choices = ["Remove provider"] - if self.user_config.is_provider_openai_compatible(provider_name): + if self.user_config.supports_model_fetching(provider_name): choices.append("Refresh models") choices.extend(["---", "Back"]) @@ -375,7 +375,9 @@ class ProviderConfigurator: for provider_name, provider_config in providers.items(): provider_type = provider_config.get("type", "unknown") has_key = bool(provider_config.get("api_key")) - if has_key: + + # Include provider if it has an API key OR supports model fetching (might not need key) + if has_key or self.user_config.supports_model_fetching(provider_name): base_url = provider_config.get("base_url") if base_url: choices.append(f"{provider_name} ({provider_type}) - {base_url}") @@ -383,7 +385,7 @@ class ProviderConfigurator: choices.append(f"{provider_name} ({provider_type})") if not choices: - console.print("[yellow]No providers with API keys configured.[/yellow]") + console.print("[yellow]No usable providers configured.[/yellow]") return # Add separator and cancel option @@ -401,7 +403,7 @@ class ProviderConfigurator: # Extract provider name provider_name = choice.split(" (")[0] - if self.user_config.is_provider_openai_compatible(provider_name): + if self.user_config.supports_model_fetching(provider_name): model_name = self._select_model_from_list(provider_name) else: model_name = questionary.text( diff --git a/cubbi/images/crush/crush_plugin.py b/cubbi/images/crush/crush_plugin.py index f1e39f6..f38488e 100644 --- a/cubbi/images/crush/crush_plugin.py +++ b/cubbi/images/crush/crush_plugin.py @@ -27,17 +27,37 @@ class CrushPlugin(ToolPlugin): def _map_provider_to_crush_format( self, provider_name: str, provider_config, is_default_provider: bool = False ) -> dict[str, Any] | None: + # Handle standard providers without base_url if not provider_config.base_url: if provider_config.type in STANDARD_PROVIDERS: + # Populate models for any standard provider that has models + models_list = [] + if provider_config.models: + for model in provider_config.models: + model_id = model.get("id", "") + if model_id: + models_list.append({"id": model_id, "name": model_id}) + provider_entry = { "api_key": provider_config.api_key, + "models": models_list, } return provider_entry + # Handle custom providers with base_url + models_list = [] + + # Add all models for any provider type that has models + if provider_config.models: + for model in provider_config.models: + model_id = model.get("id", "") + if model_id: + models_list.append({"id": model_id, "name": model_id}) + provider_entry = { "api_key": provider_config.api_key, "base_url": provider_config.base_url, - "models": [], + "models": models_list, } if provider_config.type in STANDARD_PROVIDERS: diff --git a/cubbi/images/opencode/opencode_plugin.py b/cubbi/images/opencode/opencode_plugin.py index e8fe527..0066edb 100644 --- a/cubbi/images/opencode/opencode_plugin.py +++ b/cubbi/images/opencode/opencode_plugin.py @@ -53,8 +53,8 @@ class OpencodePlugin(ToolPlugin): # Custom provider - include baseURL and name models_dict = {} - # Add all models for OpenAI-compatible providers - if provider_config.type == "openai" and provider_config.models: + # Add all models for any provider type that has models + if provider_config.models: for model in provider_config.models: model_id = model.get("id", "") if model_id: @@ -77,10 +77,6 @@ class OpencodePlugin(ToolPlugin): elif provider_config.type == "openai": provider_entry["npm"] = "@ai-sdk/openai-compatible" provider_entry["name"] = f"OpenAI Compatible ({provider_name})" - if models_dict: - self.status.log( - f"Added {len(models_dict)} models to {provider_name}" - ) elif provider_config.type == "google": provider_entry["npm"] = "@ai-sdk/google" provider_entry["name"] = f"Google ({provider_name})" @@ -93,19 +89,38 @@ class OpencodePlugin(ToolPlugin): provider_entry["name"] = provider_name.title() config_data["provider"][provider_name] = provider_entry - self.status.log( - f"Added {provider_name} custom provider to OpenCode configuration" - ) + if models_dict: + self.status.log( + f"Added {provider_name} custom provider with {len(models_dict)} models to OpenCode configuration" + ) + else: + self.status.log( + f"Added {provider_name} custom provider to OpenCode configuration" + ) else: - # Standard provider without custom URL - minimal config + # Standard provider without custom URL if provider_config.type in STANDARD_PROVIDERS: + # Populate models for any provider that has models + models_dict = {} + if provider_config.models: + for model in provider_config.models: + model_id = model.get("id", "") + if model_id: + models_dict[model_id] = {"name": model_id} + config_data["provider"][provider_name] = { "options": {"apiKey": provider_config.api_key}, - "models": {}, + "models": models_dict, } - self.status.log( - f"Added {provider_name} standard provider to OpenCode configuration" - ) + + if models_dict: + self.status.log( + f"Added {provider_name} standard provider with {len(models_dict)} models to OpenCode configuration" + ) + else: + self.status.log( + f"Added {provider_name} standard provider to OpenCode configuration" + ) # Set default model if cubbi_config.defaults.model: diff --git a/cubbi/model_fetcher.py b/cubbi/model_fetcher.py index c2da484..f8521e7 100644 --- a/cubbi/model_fetcher.py +++ b/cubbi/model_fetcher.py @@ -27,6 +27,7 @@ class ModelFetcher: base_url: str, api_key: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + provider_type: Optional[str] = None, ) -> List[Dict[str, str]]: """Fetch models from an OpenAI-compatible /v1/models endpoint. @@ -34,6 +35,7 @@ class ModelFetcher: base_url: Base URL of the provider (e.g., "https://api.openai.com" or "https://api.litellm.com") api_key: Optional API key for authentication headers: Optional additional headers + provider_type: Optional provider type for authentication handling Returns: List of model dictionaries with 'id' and 'name' keys @@ -46,7 +48,7 @@ class ModelFetcher: models_url = self._build_models_url(base_url) # Prepare headers - request_headers = self._build_headers(api_key, headers) + request_headers = self._build_headers(api_key, headers, provider_type) logger.info(f"Fetching models from {models_url}") @@ -59,13 +61,22 @@ class ModelFetcher: # Parse JSON response data = response.json() - # Validate response structure - if not isinstance(data, dict) or "data" not in data: - raise ValueError( - f"Invalid response format: expected dict with 'data' key, got {type(data)}" - ) + # Handle provider-specific response formats + if provider_type == "google": + # Google uses {"models": [...]} format + if not isinstance(data, dict) or "models" not in data: + raise ValueError( + f"Invalid Google response format: expected dict with 'models' key, got {type(data)}" + ) + models_data = data["models"] + else: + # OpenAI-compatible format uses {"data": [...]} + if not isinstance(data, dict) or "data" not in data: + raise ValueError( + f"Invalid response format: expected dict with 'data' key, got {type(data)}" + ) + models_data = data["data"] - models_data = data["data"] if not isinstance(models_data, list): raise ValueError( f"Invalid models data: expected list, got {type(models_data)}" @@ -77,7 +88,14 @@ class ModelFetcher: if not isinstance(model_item, dict): continue - model_id = model_item.get("id", "") + # Handle provider-specific model ID fields + if provider_type == "google": + # Google uses "name" field (e.g., "models/gemini-1.5-pro") + model_id = model_item.get("name", "") + else: + # OpenAI-compatible uses "id" field + model_id = model_item.get("id", "") + if not model_id: continue @@ -144,12 +162,14 @@ class ModelFetcher: self, api_key: Optional[str] = None, additional_headers: Optional[Dict[str, str]] = None, + provider_type: Optional[str] = None, ) -> Dict[str, str]: """Build request headers. Args: api_key: Optional API key for authentication additional_headers: Optional additional headers + provider_type: Provider type for specific auth handling Returns: Dictionary of headers @@ -161,7 +181,15 @@ class ModelFetcher: # Add authentication header if API key is provided if api_key: - headers["Authorization"] = f"Bearer {api_key}" + if provider_type == "anthropic": + # Anthropic uses x-api-key header + headers["x-api-key"] = api_key + elif provider_type == "google": + # Google uses x-goog-api-key header + headers["x-goog-api-key"] = api_key + else: + # Standard Bearer token for OpenAI, OpenRouter, and custom providers + headers["Authorization"] = f"Bearer {api_key}" # Add any additional headers if additional_headers: @@ -183,26 +211,38 @@ def fetch_provider_models( List of model dictionaries Raises: - ValueError: If provider is not OpenAI-compatible or missing required fields + ValueError: If provider is not supported or missing required fields requests.RequestException: If the request fails """ import os + from .config import PROVIDER_DEFAULT_URLS provider_type = provider_config.get("type", "") base_url = provider_config.get("base_url") api_key = provider_config.get("api_key", "") + # Resolve environment variables in API key if api_key.startswith("${") and api_key.endswith("}"): env_var_name = api_key[2:-1] api_key = os.environ.get(env_var_name, "") - if provider_type != "openai" and not base_url: + # Determine base URL - use custom base_url or default for standard providers + if base_url: + # Custom provider with explicit base_url + effective_base_url = base_url + elif provider_type in PROVIDER_DEFAULT_URLS: + # Standard provider - use default URL + effective_base_url = PROVIDER_DEFAULT_URLS[provider_type] + else: raise ValueError( - "Provider is not OpenAI-compatible (must have type='openai' or base_url)" + f"Unsupported provider type '{provider_type}'. Must be one of: {list(PROVIDER_DEFAULT_URLS.keys())} or have a custom base_url" ) - if not base_url: - raise ValueError("No base_url specified for OpenAI-compatible provider") + # Prepare additional headers for specific providers + headers = {} + if provider_type == "anthropic": + # Anthropic uses a different API version header + headers["anthropic-version"] = "2023-06-01" fetcher = ModelFetcher(timeout=timeout) - return fetcher.fetch_models(base_url, api_key) + return fetcher.fetch_models(effective_base_url, api_key, headers, provider_type) diff --git a/cubbi/user_config.py b/cubbi/user_config.py index ce75a4b..6ae0aed 100644 --- a/cubbi/user_config.py +++ b/cubbi/user_config.py @@ -736,6 +736,22 @@ class UserConfigManager: provider_type = provider_config.get("type", "") return provider_type == "openai" and provider_config.get("base_url") is not None + def supports_model_fetching(self, provider_name: str) -> bool: + """Check if a provider supports model fetching via API.""" + from .config import PROVIDER_DEFAULT_URLS + + provider = self.get_provider(provider_name) + if not provider: + return False + + provider_type = provider.get("type") + base_url = provider.get("base_url") + + # Provider supports model fetching if: + # 1. It has a custom base_url (OpenAI-compatible), OR + # 2. It's a standard provider type that we support + return base_url is not None or provider_type in PROVIDER_DEFAULT_URLS + def list_openai_compatible_providers(self) -> List[str]: providers = self.list_providers() compatible_providers = [] @@ -745,3 +761,14 @@ class UserConfigManager: compatible_providers.append(provider_name) return compatible_providers + + def list_model_fetchable_providers(self) -> List[str]: + """List all providers that support model fetching.""" + providers = self.list_providers() + fetchable_providers = [] + + for provider_name in providers.keys(): + if self.supports_model_fetching(provider_name): + fetchable_providers.append(provider_name) + + return fetchable_providers