feat: universal model management for all standard providers (#34)

* fix: add crush plugin support too

* feat: comprehensive model management for all standard providers

- Add universal provider support for model fetching (OpenAI, Anthropic, Google, OpenRouter)
- Add default API URLs for standard providers in config.py
- Enhance model fetcher with provider-specific authentication:
  * Anthropic: x-api-key header + anthropic-version header
  * Google: x-goog-api-key header + custom response format handling
  * OpenAI/OpenRouter: Bearer token (unchanged)
- Support Google's unique API response format (models vs data key, name vs id field)
- Update CLI commands to work with all supported provider types
- Enhance configure interface to include all providers (even those without API keys)
- Update both OpenCode and Crush plugins to populate models for all provider types
- Add comprehensive provider support detection methods
This commit was authored on 2025-08-08 15:12:04 -06:00 and committed by GitHub.
parent 7d6bc5dbfa
commit fc819a3861
7 changed files with 161 additions and 49 deletions

View File

@@ -2297,12 +2297,12 @@ def refresh_models(
console.print(f"[red]Provider '{provider}' not found[/red]")
return
if not user_config.is_provider_openai_compatible(provider):
if not user_config.supports_model_fetching(provider):
console.print(
f"[red]Provider '{provider}' is not a custom OpenAI provider[/red]"
f"[red]Provider '{provider}' does not support model fetching[/red]"
)
console.print(
"Only providers with type='openai' and custom base_url are supported"
"Only providers of supported types (openai, anthropic, google, openrouter) can refresh models"
)
return
@@ -2328,24 +2328,24 @@ def refresh_models(
except Exception as e:
console.print(f"[red]Failed to refresh models for '{provider}': {e}[/red]")
else:
# Refresh models for all OpenAI-compatible providers
compatible_providers = user_config.list_openai_compatible_providers()
# Refresh models for all model-fetchable providers
fetchable_providers = user_config.list_model_fetchable_providers()
if not compatible_providers:
console.print("[yellow]No custom OpenAI providers found[/yellow]")
if not fetchable_providers:
console.print(
"Add providers with type='openai' and custom base_url to refresh models"
"[yellow]No providers with model fetching support found[/yellow]"
)
console.print(
"Add providers of supported types (openai, anthropic, google, openrouter) to refresh models"
)
return
console.print(
f"Refreshing models for {len(compatible_providers)} custom OpenAI providers..."
)
console.print(f"Refreshing models for {len(fetchable_providers)} providers...")
success_count = 0
failed_providers = []
for provider_name in compatible_providers:
for provider_name in fetchable_providers:
try:
provider_config = user_config.get_provider(provider_name)
with console.status(f"Fetching models from {provider_name}..."):

View File

@@ -14,6 +14,14 @@ BUILTIN_IMAGES_DIR = Path(__file__).parent / "images"
# Dynamically loaded from images directory at runtime
DEFAULT_IMAGES = {}
# Default API base URLs for the standard providers. The model fetcher falls
# back to these when a provider config does not supply its own base_url, so
# models can be fetched for standard providers without extra configuration.
PROVIDER_DEFAULT_URLS = {
    "openai": "https://api.openai.com",
    "anthropic": "https://api.anthropic.com",
    "google": "https://generativelanguage.googleapis.com",
    "openrouter": "https://openrouter.ai/api",
}
class ConfigManager:
def __init__(self, config_path: Optional[Path] = None):

View File

@@ -222,7 +222,7 @@ class ProviderConfigurator:
console.print(f"[green]Added provider '{provider_name}'[/green]")
if self.user_config.is_provider_openai_compatible(provider_name):
if self.user_config.supports_model_fetching(provider_name):
console.print("Refreshing models...")
try:
self._refresh_provider_models(provider_name)
@@ -268,7 +268,7 @@ class ProviderConfigurator:
while True:
choices = ["Remove provider"]
if self.user_config.is_provider_openai_compatible(provider_name):
if self.user_config.supports_model_fetching(provider_name):
choices.append("Refresh models")
choices.extend(["---", "Back"])
@@ -375,7 +375,9 @@ class ProviderConfigurator:
for provider_name, provider_config in providers.items():
provider_type = provider_config.get("type", "unknown")
has_key = bool(provider_config.get("api_key"))
if has_key:
# Include provider if it has an API key OR supports model fetching (might not need key)
if has_key or self.user_config.supports_model_fetching(provider_name):
base_url = provider_config.get("base_url")
if base_url:
choices.append(f"{provider_name} ({provider_type}) - {base_url}")
@@ -383,7 +385,7 @@ class ProviderConfigurator:
choices.append(f"{provider_name} ({provider_type})")
if not choices:
console.print("[yellow]No providers with API keys configured.[/yellow]")
console.print("[yellow]No usable providers configured.[/yellow]")
return
# Add separator and cancel option
@@ -401,7 +403,7 @@ class ProviderConfigurator:
# Extract provider name
provider_name = choice.split(" (")[0]
if self.user_config.is_provider_openai_compatible(provider_name):
if self.user_config.supports_model_fetching(provider_name):
model_name = self._select_model_from_list(provider_name)
else:
model_name = questionary.text(

View File

@@ -27,17 +27,37 @@ class CrushPlugin(ToolPlugin):
def _map_provider_to_crush_format(
self, provider_name: str, provider_config, is_default_provider: bool = False
) -> dict[str, Any] | None:
# Handle standard providers without base_url
if not provider_config.base_url:
if provider_config.type in STANDARD_PROVIDERS:
# Populate models for any standard provider that has models
models_list = []
if provider_config.models:
for model in provider_config.models:
model_id = model.get("id", "")
if model_id:
models_list.append({"id": model_id, "name": model_id})
provider_entry = {
"api_key": provider_config.api_key,
"models": models_list,
}
return provider_entry
# Handle custom providers with base_url
models_list = []
# Add all models for any provider type that has models
if provider_config.models:
for model in provider_config.models:
model_id = model.get("id", "")
if model_id:
models_list.append({"id": model_id, "name": model_id})
provider_entry = {
"api_key": provider_config.api_key,
"base_url": provider_config.base_url,
"models": [],
"models": models_list,
}
if provider_config.type in STANDARD_PROVIDERS:

View File

@@ -53,8 +53,8 @@ class OpencodePlugin(ToolPlugin):
# Custom provider - include baseURL and name
models_dict = {}
# Add all models for OpenAI-compatible providers
if provider_config.type == "openai" and provider_config.models:
# Add all models for any provider type that has models
if provider_config.models:
for model in provider_config.models:
model_id = model.get("id", "")
if model_id:
@@ -77,10 +77,6 @@ class OpencodePlugin(ToolPlugin):
elif provider_config.type == "openai":
provider_entry["npm"] = "@ai-sdk/openai-compatible"
provider_entry["name"] = f"OpenAI Compatible ({provider_name})"
if models_dict:
self.status.log(
f"Added {len(models_dict)} models to {provider_name}"
)
elif provider_config.type == "google":
provider_entry["npm"] = "@ai-sdk/google"
provider_entry["name"] = f"Google ({provider_name})"
@@ -93,16 +89,35 @@ class OpencodePlugin(ToolPlugin):
provider_entry["name"] = provider_name.title()
config_data["provider"][provider_name] = provider_entry
if models_dict:
self.status.log(
f"Added {provider_name} custom provider with {len(models_dict)} models to OpenCode configuration"
)
else:
self.status.log(
f"Added {provider_name} custom provider to OpenCode configuration"
)
else:
# Standard provider without custom URL - minimal config
# Standard provider without custom URL
if provider_config.type in STANDARD_PROVIDERS:
# Populate models for any provider that has models
models_dict = {}
if provider_config.models:
for model in provider_config.models:
model_id = model.get("id", "")
if model_id:
models_dict[model_id] = {"name": model_id}
config_data["provider"][provider_name] = {
"options": {"apiKey": provider_config.api_key},
"models": {},
"models": models_dict,
}
if models_dict:
self.status.log(
f"Added {provider_name} standard provider with {len(models_dict)} models to OpenCode configuration"
)
else:
self.status.log(
f"Added {provider_name} standard provider to OpenCode configuration"
)

View File

@@ -27,6 +27,7 @@ class ModelFetcher:
base_url: str,
api_key: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
provider_type: Optional[str] = None,
) -> List[Dict[str, str]]:
"""Fetch models from an OpenAI-compatible /v1/models endpoint.
@@ -34,6 +35,7 @@ class ModelFetcher:
base_url: Base URL of the provider (e.g., "https://api.openai.com" or "https://api.litellm.com")
api_key: Optional API key for authentication
headers: Optional additional headers
provider_type: Optional provider type for authentication handling
Returns:
List of model dictionaries with 'id' and 'name' keys
@@ -46,7 +48,7 @@ class ModelFetcher:
models_url = self._build_models_url(base_url)
# Prepare headers
request_headers = self._build_headers(api_key, headers)
request_headers = self._build_headers(api_key, headers, provider_type)
logger.info(f"Fetching models from {models_url}")
@@ -59,13 +61,22 @@ class ModelFetcher:
# Parse JSON response
data = response.json()
# Validate response structure
# Handle provider-specific response formats
if provider_type == "google":
# Google uses {"models": [...]} format
if not isinstance(data, dict) or "models" not in data:
raise ValueError(
f"Invalid Google response format: expected dict with 'models' key, got {type(data)}"
)
models_data = data["models"]
else:
# OpenAI-compatible format uses {"data": [...]}
if not isinstance(data, dict) or "data" not in data:
raise ValueError(
f"Invalid response format: expected dict with 'data' key, got {type(data)}"
)
models_data = data["data"]
if not isinstance(models_data, list):
raise ValueError(
f"Invalid models data: expected list, got {type(models_data)}"
@@ -77,7 +88,14 @@ class ModelFetcher:
if not isinstance(model_item, dict):
continue
# Handle provider-specific model ID fields
if provider_type == "google":
# Google uses "name" field (e.g., "models/gemini-1.5-pro")
model_id = model_item.get("name", "")
else:
# OpenAI-compatible uses "id" field
model_id = model_item.get("id", "")
if not model_id:
continue
@@ -144,12 +162,14 @@ class ModelFetcher:
self,
api_key: Optional[str] = None,
additional_headers: Optional[Dict[str, str]] = None,
provider_type: Optional[str] = None,
) -> Dict[str, str]:
"""Build request headers.
Args:
api_key: Optional API key for authentication
additional_headers: Optional additional headers
provider_type: Provider type for specific auth handling
Returns:
Dictionary of headers
@@ -161,6 +181,14 @@ class ModelFetcher:
# Add authentication header if API key is provided
if api_key:
if provider_type == "anthropic":
# Anthropic uses x-api-key header
headers["x-api-key"] = api_key
elif provider_type == "google":
# Google uses x-goog-api-key header
headers["x-goog-api-key"] = api_key
else:
# Standard Bearer token for OpenAI, OpenRouter, and custom providers
headers["Authorization"] = f"Bearer {api_key}"
# Add any additional headers
@@ -183,26 +211,38 @@ def fetch_provider_models(
List of model dictionaries
Raises:
ValueError: If provider is not OpenAI-compatible or missing required fields
ValueError: If provider is not supported or missing required fields
requests.RequestException: If the request fails
"""
import os
from .config import PROVIDER_DEFAULT_URLS
provider_type = provider_config.get("type", "")
base_url = provider_config.get("base_url")
api_key = provider_config.get("api_key", "")
# Resolve environment variables in API key
if api_key.startswith("${") and api_key.endswith("}"):
env_var_name = api_key[2:-1]
api_key = os.environ.get(env_var_name, "")
if provider_type != "openai" and not base_url:
# Determine base URL - use custom base_url or default for standard providers
if base_url:
# Custom provider with explicit base_url
effective_base_url = base_url
elif provider_type in PROVIDER_DEFAULT_URLS:
# Standard provider - use default URL
effective_base_url = PROVIDER_DEFAULT_URLS[provider_type]
else:
raise ValueError(
"Provider is not OpenAI-compatible (must have type='openai' or base_url)"
f"Unsupported provider type '{provider_type}'. Must be one of: {list(PROVIDER_DEFAULT_URLS.keys())} or have a custom base_url"
)
if not base_url:
raise ValueError("No base_url specified for OpenAI-compatible provider")
# Prepare additional headers for specific providers
headers = {}
if provider_type == "anthropic":
# Anthropic uses a different API version header
headers["anthropic-version"] = "2023-06-01"
fetcher = ModelFetcher(timeout=timeout)
return fetcher.fetch_models(base_url, api_key)
return fetcher.fetch_models(effective_base_url, api_key, headers, provider_type)

View File

@@ -736,6 +736,22 @@ class UserConfigManager:
provider_type = provider_config.get("type", "")
return provider_type == "openai" and provider_config.get("base_url") is not None
def supports_model_fetching(self, provider_name: str) -> bool:
    """Return True if *provider_name* can have its models fetched via API.

    A provider qualifies when it either exposes an explicit ``base_url``
    (treated as OpenAI-compatible) or its ``type`` is one of the standard
    providers with a known default API URL.
    """
    from .config import PROVIDER_DEFAULT_URLS

    config = self.get_provider(provider_name)
    if not config:
        # Unknown (or empty) provider: nothing to fetch from.
        return False
    if config.get("base_url") is not None:
        # Any custom base_url is assumed to speak the OpenAI-compatible API.
        return True
    # Otherwise only the standard provider types we know URLs for qualify.
    return config.get("type") in PROVIDER_DEFAULT_URLS
def list_openai_compatible_providers(self) -> List[str]:
providers = self.list_providers()
compatible_providers = []
@@ -745,3 +761,14 @@ class UserConfigManager:
compatible_providers.append(provider_name)
return compatible_providers
def list_model_fetchable_providers(self) -> List[str]:
    """Return the names of all configured providers that support model fetching."""
    return [
        name
        for name in self.list_providers()
        if self.supports_model_fetching(name)
    ]