feat: dynamic model management for OpenAI-compatible providers (#33)

feat: add models fetch for openai-compatible endpoint
This commit is contained in:
2025-08-08 12:08:08 -06:00
committed by GitHub
parent 310149dc34
commit 7d6bc5dbfa
9 changed files with 640 additions and 51 deletions

View File

@@ -762,6 +762,10 @@ config_app.add_typer(port_app, name="port", no_args_is_help=True)
config_mcp_app = typer.Typer(help="Manage default MCP servers")
config_app.add_typer(config_mcp_app, name="mcp", no_args_is_help=True)
# Create a models subcommand for config
models_app = typer.Typer(help="Manage provider models")
config_app.add_typer(models_app, name="models", no_args_is_help=True)
# MCP configuration commands
@config_mcp_app.command("list")
@@ -2231,6 +2235,139 @@ exec npm start
console.print("[green]MCP Inspector stopped[/green]")
# Model management commands
@models_app.command("list")
def list_models(
provider: Optional[str] = typer.Argument(None, help="Provider name (optional)"),
) -> None:
if provider:
# List models for specific provider
models = user_config.list_provider_models(provider)
if not models:
if not user_config.get_provider(provider):
console.print(f"[red]Provider '{provider}' not found[/red]")
else:
console.print(f"No models configured for provider '{provider}'")
return
table = Table(show_header=True, header_style="bold")
table.add_column("Model ID")
for model in models:
table.add_row(model["id"])
console.print(f"\n[bold]Models for provider '{provider}'[/bold]")
console.print(table)
else:
# List models for all providers
providers = user_config.list_providers()
if not providers:
console.print("No providers configured")
return
table = Table(show_header=True, header_style="bold")
table.add_column("Provider")
table.add_column("Model ID")
found_models = False
for provider_name in providers.keys():
models = user_config.list_provider_models(provider_name)
for model in models:
table.add_row(provider_name, model["id"])
found_models = True
if found_models:
console.print(table)
else:
console.print("No models configured for any provider")
@models_app.command("refresh")
def refresh_models(
provider: Optional[str] = typer.Argument(None, help="Provider name (optional)"),
) -> None:
from .model_fetcher import fetch_provider_models
if provider:
# Refresh models for specific provider
provider_config = user_config.get_provider(provider)
if not provider_config:
console.print(f"[red]Provider '{provider}' not found[/red]")
return
if not user_config.is_provider_openai_compatible(provider):
console.print(
f"[red]Provider '{provider}' is not a custom OpenAI provider[/red]"
)
console.print(
"Only providers with type='openai' and custom base_url are supported"
)
return
console.print(f"Refreshing models for provider '{provider}'...")
try:
with console.status(f"Fetching models from {provider}..."):
models = fetch_provider_models(provider_config)
user_config.set_provider_models(provider, models)
console.print(
f"[green]Successfully refreshed {len(models)} models for '{provider}'[/green]"
)
# Show some examples
if models:
console.print("\nSample models:")
for model in models[:5]: # Show first 5
console.print(f" - {model['id']}")
if len(models) > 5:
console.print(f" ... and {len(models) - 5} more")
except Exception as e:
console.print(f"[red]Failed to refresh models for '{provider}': {e}[/red]")
else:
# Refresh models for all OpenAI-compatible providers
compatible_providers = user_config.list_openai_compatible_providers()
if not compatible_providers:
console.print("[yellow]No custom OpenAI providers found[/yellow]")
console.print(
"Add providers with type='openai' and custom base_url to refresh models"
)
return
console.print(
f"Refreshing models for {len(compatible_providers)} custom OpenAI providers..."
)
success_count = 0
failed_providers = []
for provider_name in compatible_providers:
try:
provider_config = user_config.get_provider(provider_name)
with console.status(f"Fetching models from {provider_name}..."):
models = fetch_provider_models(provider_config)
user_config.set_provider_models(provider_name, models)
console.print(f"[green]✓ {provider_name}: {len(models)} models[/green]")
success_count += 1
except Exception as e:
console.print(f"[red]✗ {provider_name}: {e}[/red]")
failed_providers.append(provider_name)
# Summary
console.print("\n[bold]Summary[/bold]")
console.print(f"Successfully refreshed: {success_count} providers")
if failed_providers:
console.print(
f"Failed: {len(failed_providers)} providers ({', '.join(failed_providers)})"
)
def session_create_entry_point():
"""Entry point that directly invokes 'cubbi session create'.

View File

@@ -3,6 +3,7 @@ Interactive configuration tool for Cubbi providers and models.
"""
import os
from typing import Optional
import docker
import questionary
@@ -164,7 +165,8 @@ class ProviderConfigurator:
"How would you like to provide the API key?",
choices=[
"Enter API key directly (saved in config)",
"Reference environment variable (recommended)",
"Use environment variable (recommended)",
"No API key needed",
],
).ask()
@@ -184,11 +186,12 @@ class ProviderConfigurator:
api_key = f"${{{env_var.strip()}}}"
# Check if the environment variable exists
if not os.environ.get(env_var.strip()):
console.print(
f"[yellow]Warning: Environment variable '{env_var}' is not currently set[/yellow]"
)
elif "No API key" in api_key_choice:
api_key = ""
else:
api_key = questionary.password(
"Enter API key:",
@@ -219,6 +222,13 @@ class ProviderConfigurator:
console.print(f"[green]Added provider '{provider_name}'[/green]")
if self.user_config.is_provider_openai_compatible(provider_name):
console.print("Refreshing models...")
try:
self._refresh_provider_models(provider_name)
except Exception as e:
console.print(f"[yellow]Could not refresh models: {e}[/yellow]")
def _edit_provider(self, provider_name: str) -> None:
"""Edit an existing provider."""
provider_config = self.user_config.get_provider(provider_name)
@@ -226,36 +236,129 @@ class ProviderConfigurator:
console.print(f"[red]Provider '{provider_name}' not found![/red]")
return
choices = ["View configuration", "Remove provider", "---", "Back"]
choice = questionary.select(
f"What would you like to do with '{provider_name}'?",
choices=choices,
).ask()
if choice == "View configuration":
console.print(f"\n[bold]Configuration for '{provider_name}':[/bold]")
for key, value in provider_config.items():
if key == "api_key" and not value.startswith("${"):
# Mask direct API keys
display_value = (
f"{'*' * (len(value) - 4)}{value[-4:]}"
if len(value) > 4
else "****"
)
console.print(f"\n[bold]Configuration for '{provider_name}':[/bold]")
for key, value in provider_config.items():
if key == "api_key" and not value.startswith("${"):
display_value = (
f"{'*' * (len(value) - 4)}{value[-4:]}"
if len(value) > 4
else "****"
)
elif key == "models" and isinstance(value, list):
if value:
console.print(f" {key}:")
for i, model in enumerate(value[:10]):
if isinstance(model, dict):
model_id = model.get("id", str(model))
else:
model_id = str(model)
console.print(f" {i+1}. {model_id}")
if len(value) > 10:
console.print(
f" ... and {len(value)-10} more ({len(value)} total)"
)
continue
else:
display_value = value
console.print(f" {key}: {display_value}")
console.print()
display_value = "(no models configured)"
else:
display_value = value
console.print(f" {key}: {display_value}")
console.print()
elif choice == "Remove provider":
confirm = questionary.confirm(
f"Are you sure you want to remove provider '{provider_name}'?"
while True:
choices = ["Remove provider"]
if self.user_config.is_provider_openai_compatible(provider_name):
choices.append("Refresh models")
choices.extend(["---", "Back"])
choice = questionary.select(
f"What would you like to do with '{provider_name}'?",
choices=choices,
).ask()
if confirm:
self.user_config.remove_provider(provider_name)
console.print(f"[green]Removed provider '{provider_name}'[/green]")
if choice == "Remove provider":
confirm = questionary.confirm(
f"Are you sure you want to remove provider '{provider_name}'?",
default=False,
).ask()
if confirm:
self.user_config.remove_provider(provider_name)
console.print(f"[green]Removed provider '{provider_name}'[/green]")
break
elif choice == "Refresh models":
self._refresh_provider_models(provider_name)
elif choice == "Back" or choice is None:
break
def _refresh_provider_models(self, provider_name: str) -> None:
    """Fetch the provider's model list and persist it to the user config.

    Failures are reported to the console rather than raised.
    """
    from .model_fetcher import fetch_provider_models

    try:
        config = self.user_config.get_provider(provider_name)
        console.print(f"Refreshing models for {provider_name}...")
        refreshed = fetch_provider_models(config)
        self.user_config.set_provider_models(provider_name, refreshed)
        console.print(
            f"[green]Successfully refreshed {len(refreshed)} models for '{provider_name}'[/green]"
        )
    except Exception as exc:
        console.print(f"[red]Failed to refresh models: {exc}[/red]")
def _prompt_manual_model(self, provider_name: str) -> Optional[str]:
    """Ask the user to type a model name by hand; rejects empty input."""
    return questionary.text(
        f"Enter model name for {provider_name}:",
        validate=lambda name: len(name.strip()) > 0
        or "Please enter a model name",
    ).ask()

def _select_model_from_list(self, provider_name: str) -> Optional[str]:
    """Let the user pick a model for the given provider.

    Uses the cached model list when present, otherwise tries to refresh it
    from the provider's models endpoint.  Falls back to manual text entry
    when no models are available or the user selects "Enter manually".

    Args:
        provider_name: Name of the provider to pick a model for

    Returns:
        The selected model id, or None if the user cancelled.
    """
    from .model_fetcher import fetch_provider_models

    models = self.user_config.list_provider_models(provider_name)
    if not models:
        console.print(f"No models found for {provider_name}. Refreshing...")
        try:
            provider_config = self.user_config.get_provider(provider_name)
            models = fetch_provider_models(provider_config)
            self.user_config.set_provider_models(provider_name, models)
            console.print(f"[green]Refreshed {len(models)} models[/green]")
        except Exception as e:
            console.print(f"[red]Failed to refresh models: {e}[/red]")
            # Refresh failed entirely — fall back to manual entry.
            return self._prompt_manual_model(provider_name)

    if not models:
        # Refresh succeeded but the provider advertises no models.
        console.print(f"[yellow]No models available for {provider_name}[/yellow]")
        return self._prompt_manual_model(provider_name)

    # Present fetched models plus an escape hatch for manual entry.
    model_choices = [model["id"] for model in models]
    model_choices.append("---")
    model_choices.append("Enter manually")
    choice = questionary.select(
        f"Select a model for {provider_name}:",
        choices=model_choices,
    ).ask()

    if choice is None or choice == "---":
        # Cancelled, or the separator row was selected — treat as cancel.
        return None
    elif choice == "Enter manually":
        return self._prompt_manual_model(provider_name)
    else:
        return choice
def _set_default_model(self) -> None:
"""Set the default model."""
@@ -298,16 +401,18 @@ class ProviderConfigurator:
# Extract provider name
provider_name = choice.split(" (")[0]
# Ask for model name
model_name = questionary.text(
f"Enter model name for {provider_name} (e.g., 'claude-3-5-sonnet', 'gpt-4', 'llama3:70b'):",
validate=lambda name: len(name.strip()) > 0 or "Please enter a model name",
).ask()
if self.user_config.is_provider_openai_compatible(provider_name):
model_name = self._select_model_from_list(provider_name)
else:
model_name = questionary.text(
f"Enter model name for {provider_name} (e.g., 'claude-3-5-sonnet', 'gpt-4', 'llama3:70b'):",
validate=lambda name: len(name.strip()) > 0
or "Please enter a model name",
).ask()
if model_name is None:
return
# Set the default model in provider/model format
default_model = f"{provider_name}/{model_name.strip()}"
self.user_config.set("defaults.model", default_model)
@@ -659,17 +764,22 @@ class ProviderConfigurator:
elif "defaults" in choice:
if is_default:
self.user_config.remove_mcp(server_name)
console.print(
f"[green]Removed '{server_name}' from default MCPs[/green]"
)
confirm = questionary.confirm(
f"Remove '{server_name}' from default MCPs?", default=False
).ask()
if confirm:
self.user_config.remove_mcp(server_name)
console.print(
f"[green]Removed '{server_name}' from default MCPs[/green]"
)
else:
self.user_config.add_mcp(server_name)
console.print(f"[green]Added '{server_name}' to default MCPs[/green]")
elif choice == "Remove server":
confirm = questionary.confirm(
f"Are you sure you want to remove MCP server '{server_name}'?"
f"Are you sure you want to remove MCP server '{server_name}'?",
default=False,
).ask()
if confirm:
@@ -749,7 +859,8 @@ class ProviderConfigurator:
elif choice == "Remove network":
confirm = questionary.confirm(
f"Are you sure you want to remove network '{network_name}'?"
f"Are you sure you want to remove network '{network_name}'?",
default=False,
).ask()
if confirm:
@@ -829,7 +940,8 @@ class ProviderConfigurator:
elif choice == "Remove volume":
confirm = questionary.confirm(
f"Are you sure you want to remove volume mapping '{volume_mapping}'?"
f"Are you sure you want to remove volume mapping '{volume_mapping}'?",
default=False,
).ask()
if confirm:
@@ -902,7 +1014,7 @@ class ProviderConfigurator:
if choice == "Remove port":
confirm = questionary.confirm(
f"Are you sure you want to remove port {port_num}?"
f"Are you sure you want to remove port {port_num}?", default=False
).ask()
if confirm:

View File

@@ -116,6 +116,8 @@ class ContainerManager:
}
if provider.get("base_url"):
provider_config["base_url"] = provider.get("base_url")
if provider.get("models"):
provider_config["models"] = provider.get("models")
providers[name] = provider_config

View File

@@ -46,6 +46,7 @@ class ProviderConfig(BaseModel):
type: str
api_key: str
base_url: str | None = None
models: list[dict[str, str]] = []
class MCPConfig(BaseModel):

View File

@@ -51,12 +51,21 @@ class OpencodePlugin(ToolPlugin):
# Check if this is a custom provider (has baseURL)
if provider_config.base_url:
# Custom provider - include baseURL and name
models_dict = {}
# Add all models for OpenAI-compatible providers
if provider_config.type == "openai" and provider_config.models:
for model in provider_config.models:
model_id = model.get("id", "")
if model_id:
models_dict[model_id] = {"name": model_id}
provider_entry: dict[str, str | dict[str, str]] = {
"options": {
"apiKey": provider_config.api_key,
"baseURL": provider_config.base_url,
},
"models": {},
"models": models_dict,
}
# Add npm package and name for custom providers
@@ -68,6 +77,10 @@ class OpencodePlugin(ToolPlugin):
elif provider_config.type == "openai":
provider_entry["npm"] = "@ai-sdk/openai-compatible"
provider_entry["name"] = f"OpenAI Compatible ({provider_name})"
if models_dict:
self.status.log(
f"Added {len(models_dict)} models to {provider_name}"
)
elif provider_config.type == "google":
provider_entry["npm"] = "@ai-sdk/google"
provider_entry["name"] = f"Google ({provider_name})"
@@ -94,22 +107,25 @@ class OpencodePlugin(ToolPlugin):
f"Added {provider_name} standard provider to OpenCode configuration"
)
# Set default model and add it only to the default provider
# Set default model
if cubbi_config.defaults.model:
config_data["model"] = cubbi_config.defaults.model
self.status.log(f"Set default model to {config_data['model']}")
# Add the specific model only to the provider that matches the default model
# Add the default model to provider if it doesn't already have models
provider_name: str
model_name: str
provider_name, model_name = cubbi_config.defaults.model.split("/", 1)
if provider_name in config_data["provider"]:
config_data["provider"][provider_name]["models"] = {
model_name: {"name": model_name}
}
self.status.log(
f"Added default model {model_name} to {provider_name} provider"
)
provider_config = cubbi_config.providers.get(provider_name)
# Only add default model if provider doesn't already have models populated
if not (provider_config and provider_config.models):
config_data["provider"][provider_name]["models"] = {
model_name: {"name": model_name}
}
self.status.log(
f"Added default model {model_name} to {provider_name} provider"
)
else:
# Fallback to legacy environment variables
opencode_model: str | None = os.environ.get("CUBBI_MODEL")

208
cubbi/model_fetcher.py Normal file
View File

@@ -0,0 +1,208 @@
"""
Model fetching utilities for OpenAI-compatible providers.
"""
import json
import logging
from typing import Dict, List, Optional
import requests
logger = logging.getLogger(__name__)
class ModelFetcher:
    """Fetches model lists from OpenAI-compatible API endpoints."""

    def __init__(self, timeout: int = 30):
        """Initialize the model fetcher.

        Args:
            timeout: Request timeout in seconds
        """
        self.timeout = timeout

    def fetch_models(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> List[Dict[str, str]]:
        """Fetch models from an OpenAI-compatible /v1/models endpoint.

        Args:
            base_url: Base URL of the provider (e.g., "https://api.openai.com" or "https://api.litellm.com")
            api_key: Optional API key for authentication
            headers: Optional additional headers

        Returns:
            List of model dictionaries, each with an 'id' key.  Entries whose
            id contains '*' are skipped.

        Raises:
            requests.RequestException: If the request fails
            ValueError: If the response format is invalid
        """
        # Construct the models endpoint URL
        models_url = self._build_models_url(base_url)
        # Prepare headers
        request_headers = self._build_headers(api_key, headers)
        logger.info(f"Fetching models from {models_url}")
        try:
            response = requests.get(
                models_url, headers=request_headers, timeout=self.timeout
            )
            response.raise_for_status()
            # Parse JSON response
            data = response.json()
            # Validate response structure: OpenAI-style {"data": [...]}
            if not isinstance(data, dict) or "data" not in data:
                raise ValueError(
                    f"Invalid response format: expected dict with 'data' key, got {type(data)}"
                )
            models_data = data["data"]
            if not isinstance(models_data, list):
                raise ValueError(
                    f"Invalid models data: expected list, got {type(models_data)}"
                )
            # Process models, keeping only well-formed entries with a usable id
            models = []
            for model_item in models_data:
                if not isinstance(model_item, dict):
                    continue
                model_id = model_item.get("id", "")
                if not model_id:
                    continue
                # Skip models with * in their ID as requested
                if "*" in model_id:
                    logger.debug(f"Skipping model with wildcard: {model_id}")
                    continue
                models.append({"id": model_id})
            logger.info(f"Successfully fetched {len(models)} models from {base_url}")
            return models
        except requests.exceptions.Timeout as e:
            logger.error(f"Request timed out after {self.timeout} seconds")
            # Chain the original exception so the root cause is preserved
            raise requests.RequestException(
                f"Request to {models_url} timed out"
            ) from e
        except requests.exceptions.ConnectionError as e:
            logger.error(f"Connection error: {e}")
            raise requests.RequestException(
                f"Failed to connect to {models_url}"
            ) from e
        except requests.exceptions.HTTPError as e:
            logger.error(f"HTTP error {e.response.status_code}: {e}")
            if e.response.status_code == 401:
                raise requests.RequestException(
                    "Authentication failed: invalid API key"
                ) from e
            elif e.response.status_code == 403:
                raise requests.RequestException(
                    "Access forbidden: check API key permissions"
                ) from e
            else:
                raise requests.RequestException(
                    f"HTTP {e.response.status_code} error from {models_url}"
                ) from e
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON response: {e}")
            raise ValueError(f"Invalid JSON response from {models_url}") from e

    def _build_models_url(self, base_url: str) -> str:
        """Build the models endpoint URL from a base URL.

        Args:
            base_url: Base URL of the provider

        Returns:
            Complete URL for the /v1/models endpoint
        """
        # Remove trailing slash if present
        base_url = base_url.rstrip("/")
        # Add /v1/models if not already present
        if not base_url.endswith("/v1/models"):
            if base_url.endswith("/v1"):
                base_url += "/models"
            else:
                base_url += "/v1/models"
        return base_url

    def _build_headers(
        self,
        api_key: Optional[str] = None,
        additional_headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """Build request headers.

        Args:
            api_key: Optional API key for authentication
            additional_headers: Optional additional headers

        Returns:
            Dictionary of headers
        """
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        # Add authentication header if API key is provided
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        # Add any additional headers (may override the defaults above)
        if additional_headers:
            headers.update(additional_headers)
        return headers
def fetch_provider_models(
    provider_config: Dict, timeout: int = 30
) -> List[Dict[str, str]]:
    """Convenience function to fetch models for a provider configuration.

    Args:
        provider_config: Provider configuration dictionary
        timeout: Request timeout in seconds

    Returns:
        List of model dictionaries

    Raises:
        ValueError: If provider is not OpenAI-compatible or missing required fields
        requests.RequestException: If the request fails
    """
    import os

    provider_type = provider_config.get("type", "")
    base_url = provider_config.get("base_url")
    # Guard against api_key being explicitly null in the config
    api_key = provider_config.get("api_key") or ""

    # Resolve "${ENV_VAR}" references against the current environment
    if api_key.startswith("${") and api_key.endswith("}"):
        env_var_name = api_key[2:-1]
        api_key = os.environ.get(env_var_name, "")

    # Mirror UserConfigManager.is_provider_openai_compatible: the provider
    # must be type 'openai' AND have a custom base_url.  (The previous check
    # used `!= "openai" and not base_url`, which let non-openai providers
    # with a base_url slip through.)
    if provider_type != "openai":
        raise ValueError(
            "Provider is not OpenAI-compatible (must have type='openai' and base_url)"
        )
    if not base_url:
        raise ValueError("No base_url specified for OpenAI-compatible provider")

    fetcher = ModelFetcher(timeout=timeout)
    return fetcher.fetch_models(base_url, api_key)

View File

@@ -635,3 +635,113 @@ class UserConfigManager:
self.set("defaults.ports", ports)
return True
return False
# Model management methods
def list_provider_models(self, provider_name: str) -> List[Dict[str, str]]:
    """Get all models for a specific provider.

    Stored entries may be plain strings or dicts; each is normalized to
    the form {"id": <model id>}.  Entries without a usable id are dropped.

    Args:
        provider_name: Name of the provider

    Returns:
        List of model dictionaries, each with an 'id' key
    """
    provider_config = self.get_provider(provider_name)
    if not provider_config:
        return []
    normalized = []
    for entry in provider_config.get("models", []):
        if isinstance(entry, str):
            normalized.append({"id": entry})
        elif isinstance(entry, dict) and entry.get("id", ""):
            normalized.append({"id": entry["id"]})
    return normalized
def set_provider_models(
    self, provider_name: str, models: List[Dict[str, str]]
) -> None:
    """Set the models for a specific provider.

    Silently does nothing when the provider is unknown.  Input entries are
    normalized to {"id": ...} dicts; anything that is not a dict with an
    'id' key is discarded.

    Args:
        provider_name: Name of the provider
        models: List of model dictionaries with an 'id' key
    """
    provider_config = self.get_provider(provider_name)
    if not provider_config:
        return
    provider_config["models"] = [
        {"id": entry["id"]}
        for entry in models
        if isinstance(entry, dict) and "id" in entry
    ]
    self.set(f"providers.{provider_name}", provider_config)
def add_provider_model(
    self, provider_name: str, model_id: str, model_name: Optional[str] = None
) -> None:
    """Add a model to a provider unless it is already present.

    Args:
        provider_name: Name of the provider
        model_id: ID of the model
        model_name: Accepted for API compatibility; currently unused —
            only the model id is stored.
    """
    models = self.list_provider_models(provider_name)
    if any(existing["id"] == model_id for existing in models):
        return
    models.append({"id": model_id})
    self.set_provider_models(provider_name, models)
def remove_provider_model(self, provider_name: str, model_id: str) -> bool:
    """Remove a model from a provider.

    Args:
        provider_name: Name of the provider
        model_id: ID of the model to remove

    Returns:
        True if model was removed, False if it didn't exist
    """
    models = self.list_provider_models(provider_name)
    remaining = [entry for entry in models if entry["id"] != model_id]
    if len(remaining) == len(models):
        # Nothing matched — avoid a no-op config write.
        return False
    self.set_provider_models(provider_name, remaining)
    return True
def is_provider_openai_compatible(self, provider_name: str) -> bool:
    """Return True when the provider has type 'openai' and a custom base_url."""
    provider_config = self.get_provider(provider_name)
    if not provider_config:
        return False
    has_custom_url = provider_config.get("base_url") is not None
    return provider_config.get("type", "") == "openai" and has_custom_url
def list_openai_compatible_providers(self) -> List[str]:
    """Return names of all providers eligible for model fetching.

    A provider qualifies when is_provider_openai_compatible() is true,
    i.e. it has type='openai' and a custom base_url.
    """
    # Comprehension instead of the manual append loop (same order, same result).
    return [
        provider_name
        for provider_name in self.list_providers()
        if self.is_provider_openai_compatible(provider_name)
    ]

View File

@@ -15,6 +15,7 @@ dependencies = [
"rich>=13.6.0",
"pydantic>=2.5.0",
"questionary>=2.0.0",
"requests>=2.32.3",
]
classifiers = [
"Development Status :: 3 - Alpha",

2
uv.lock generated
View File

@@ -85,6 +85,7 @@ dependencies = [
{ name = "pydantic" },
{ name = "pyyaml" },
{ name = "questionary" },
{ name = "requests" },
{ name = "rich" },
{ name = "typer" },
]
@@ -109,6 +110,7 @@ requires-dist = [
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" },
{ name = "pyyaml", specifier = ">=6.0.1" },
{ name = "questionary", specifier = ">=2.0.0" },
{ name = "requests", specifier = ">=2.32.3" },
{ name = "rich", specifier = ">=13.6.0" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.9" },
{ name = "typer", specifier = ">=0.9.0" },