mirror of
https://github.com/Monadical-SAS/cubbi.git
synced 2025-12-20 12:19:07 +00:00
feat: dynamic model management for OpenAI-compatible providers (#33)
feat: add models fetch for openai-compatible endpoint
This commit is contained in:
137
cubbi/cli.py
137
cubbi/cli.py
@@ -762,6 +762,10 @@ config_app.add_typer(port_app, name="port", no_args_is_help=True)
|
||||
config_mcp_app = typer.Typer(help="Manage default MCP servers")
|
||||
config_app.add_typer(config_mcp_app, name="mcp", no_args_is_help=True)
|
||||
|
||||
# Create a models subcommand for config
|
||||
models_app = typer.Typer(help="Manage provider models")
|
||||
config_app.add_typer(models_app, name="models", no_args_is_help=True)
|
||||
|
||||
|
||||
# MCP configuration commands
|
||||
@config_mcp_app.command("list")
|
||||
@@ -2231,6 +2235,139 @@ exec npm start
|
||||
console.print("[green]MCP Inspector stopped[/green]")
|
||||
|
||||
|
||||
# Model management commands
|
||||
@models_app.command("list")
def list_models(
    provider: Optional[str] = typer.Argument(None, help="Provider name (optional)"),
) -> None:
    """Display configured models — for one provider, or across all providers."""
    if not provider:
        # No provider given: render one table with a row per (provider, model).
        providers = user_config.list_providers()
        if not providers:
            console.print("No providers configured")
            return

        table = Table(show_header=True, header_style="bold")
        table.add_column("Provider")
        table.add_column("Model ID")

        rows_added = False
        for name in providers.keys():
            for entry in user_config.list_provider_models(name):
                table.add_row(name, entry["id"])
                rows_added = True

        if not rows_added:
            console.print("No models configured for any provider")
        else:
            console.print(table)
        return

    # A provider was named: list only its models.
    models = user_config.list_provider_models(provider)
    if not models:
        # Distinguish "unknown provider" from "known provider, no models".
        message = (
            f"[red]Provider '{provider}' not found[/red]"
            if not user_config.get_provider(provider)
            else f"No models configured for provider '{provider}'"
        )
        console.print(message)
        return

    table = Table(show_header=True, header_style="bold")
    table.add_column("Model ID")
    for entry in models:
        table.add_row(entry["id"])

    console.print(f"\n[bold]Models for provider '{provider}'[/bold]")
    console.print(table)
|
||||
|
||||
|
||||
@models_app.command("refresh")
def refresh_models(
    provider: Optional[str] = typer.Argument(None, help="Provider name (optional)"),
) -> None:
    """Refresh model lists from provider /v1/models endpoints.

    With a provider name, refreshes only that provider; without one,
    refreshes every custom OpenAI-compatible provider.
    """
    if provider:
        _refresh_single_provider(provider)
    else:
        _refresh_all_providers()


def _refresh_single_provider(provider: str) -> None:
    """Validate one provider, fetch its models, and persist them."""
    from .model_fetcher import fetch_provider_models

    provider_config = user_config.get_provider(provider)
    if not provider_config:
        console.print(f"[red]Provider '{provider}' not found[/red]")
        return

    # Only type='openai' providers with a custom base_url expose /v1/models.
    if not user_config.is_provider_openai_compatible(provider):
        console.print(
            f"[red]Provider '{provider}' is not a custom OpenAI provider[/red]"
        )
        console.print(
            "Only providers with type='openai' and custom base_url are supported"
        )
        return

    console.print(f"Refreshing models for provider '{provider}'...")

    try:
        with console.status(f"Fetching models from {provider}..."):
            models = fetch_provider_models(provider_config)

        user_config.set_provider_models(provider, models)
        console.print(
            f"[green]Successfully refreshed {len(models)} models for '{provider}'[/green]"
        )

        # Show some examples
        if models:
            console.print("\nSample models:")
            for model in models[:5]:  # Show first 5
                console.print(f"  - {model['id']}")
            if len(models) > 5:
                console.print(f"  ... and {len(models) - 5} more")

    except Exception as e:
        # Best-effort CLI command: report the failure, never crash the shell.
        console.print(f"[red]Failed to refresh models for '{provider}': {e}[/red]")


def _refresh_all_providers() -> None:
    """Fetch and persist models for every custom OpenAI-compatible provider."""
    from .model_fetcher import fetch_provider_models

    compatible_providers = user_config.list_openai_compatible_providers()

    if not compatible_providers:
        console.print("[yellow]No custom OpenAI providers found[/yellow]")
        console.print(
            "Add providers with type='openai' and custom base_url to refresh models"
        )
        return

    console.print(
        f"Refreshing models for {len(compatible_providers)} custom OpenAI providers..."
    )

    success_count = 0
    failed_providers = []

    for provider_name in compatible_providers:
        try:
            provider_config = user_config.get_provider(provider_name)
            with console.status(f"Fetching models from {provider_name}..."):
                models = fetch_provider_models(provider_config)

            user_config.set_provider_models(provider_name, models)
            console.print(f"[green]✓ {provider_name}: {len(models)} models[/green]")
            success_count += 1

        except Exception as e:
            # One failing provider must not abort the whole refresh loop.
            console.print(f"[red]✗ {provider_name}: {e}[/red]")
            failed_providers.append(provider_name)

    # Summary
    console.print("\n[bold]Summary[/bold]")
    console.print(f"Successfully refreshed: {success_count} providers")
    if failed_providers:
        console.print(
            f"Failed: {len(failed_providers)} providers ({', '.join(failed_providers)})"
        )
|
||||
|
||||
|
||||
def session_create_entry_point():
|
||||
"""Entry point that directly invokes 'cubbi session create'.
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ Interactive configuration tool for Cubbi providers and models.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import docker
|
||||
import questionary
|
||||
@@ -164,7 +165,8 @@ class ProviderConfigurator:
|
||||
"How would you like to provide the API key?",
|
||||
choices=[
|
||||
"Enter API key directly (saved in config)",
|
||||
"Reference environment variable (recommended)",
|
||||
"Use environment variable (recommended)",
|
||||
"No API key needed",
|
||||
],
|
||||
).ask()
|
||||
|
||||
@@ -184,11 +186,12 @@ class ProviderConfigurator:
|
||||
|
||||
api_key = f"${{{env_var.strip()}}}"
|
||||
|
||||
# Check if the environment variable exists
|
||||
if not os.environ.get(env_var.strip()):
|
||||
console.print(
|
||||
f"[yellow]Warning: Environment variable '{env_var}' is not currently set[/yellow]"
|
||||
)
|
||||
elif "No API key" in api_key_choice:
|
||||
api_key = ""
|
||||
else:
|
||||
api_key = questionary.password(
|
||||
"Enter API key:",
|
||||
@@ -219,6 +222,13 @@ class ProviderConfigurator:
|
||||
|
||||
console.print(f"[green]Added provider '{provider_name}'[/green]")
|
||||
|
||||
if self.user_config.is_provider_openai_compatible(provider_name):
|
||||
console.print("Refreshing models...")
|
||||
try:
|
||||
self._refresh_provider_models(provider_name)
|
||||
except Exception as e:
|
||||
console.print(f"[yellow]Could not refresh models: {e}[/yellow]")
|
||||
|
||||
def _edit_provider(self, provider_name: str) -> None:
|
||||
"""Edit an existing provider."""
|
||||
provider_config = self.user_config.get_provider(provider_name)
|
||||
@@ -226,36 +236,129 @@ class ProviderConfigurator:
|
||||
console.print(f"[red]Provider '{provider_name}' not found![/red]")
|
||||
return
|
||||
|
||||
choices = ["View configuration", "Remove provider", "---", "Back"]
|
||||
console.print(f"\n[bold]Configuration for '{provider_name}':[/bold]")
|
||||
for key, value in provider_config.items():
|
||||
if key == "api_key" and not value.startswith("${"):
|
||||
display_value = (
|
||||
f"{'*' * (len(value) - 4)}{value[-4:]}"
|
||||
if len(value) > 4
|
||||
else "****"
|
||||
)
|
||||
elif key == "models" and isinstance(value, list):
|
||||
if value:
|
||||
console.print(f" {key}:")
|
||||
for i, model in enumerate(value[:10]):
|
||||
if isinstance(model, dict):
|
||||
model_id = model.get("id", str(model))
|
||||
else:
|
||||
model_id = str(model)
|
||||
console.print(f" {i+1}. {model_id}")
|
||||
if len(value) > 10:
|
||||
console.print(
|
||||
f" ... and {len(value)-10} more ({len(value)} total)"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
display_value = "(no models configured)"
|
||||
else:
|
||||
display_value = value
|
||||
console.print(f" {key}: {display_value}")
|
||||
console.print()
|
||||
|
||||
while True:
|
||||
choices = ["Remove provider"]
|
||||
|
||||
if self.user_config.is_provider_openai_compatible(provider_name):
|
||||
choices.append("Refresh models")
|
||||
|
||||
choices.extend(["---", "Back"])
|
||||
|
||||
choice = questionary.select(
|
||||
f"What would you like to do with '{provider_name}'?",
|
||||
choices=choices,
|
||||
).ask()
|
||||
|
||||
if choice == "View configuration":
|
||||
console.print(f"\n[bold]Configuration for '{provider_name}':[/bold]")
|
||||
for key, value in provider_config.items():
|
||||
if key == "api_key" and not value.startswith("${"):
|
||||
# Mask direct API keys
|
||||
display_value = (
|
||||
f"{'*' * (len(value) - 4)}{value[-4:]}"
|
||||
if len(value) > 4
|
||||
else "****"
|
||||
)
|
||||
else:
|
||||
display_value = value
|
||||
console.print(f" {key}: {display_value}")
|
||||
console.print()
|
||||
|
||||
elif choice == "Remove provider":
|
||||
if choice == "Remove provider":
|
||||
confirm = questionary.confirm(
|
||||
f"Are you sure you want to remove provider '{provider_name}'?"
|
||||
f"Are you sure you want to remove provider '{provider_name}'?",
|
||||
default=False,
|
||||
).ask()
|
||||
|
||||
if confirm:
|
||||
self.user_config.remove_provider(provider_name)
|
||||
console.print(f"[green]Removed provider '{provider_name}'[/green]")
|
||||
break
|
||||
|
||||
elif choice == "Refresh models":
|
||||
self._refresh_provider_models(provider_name)
|
||||
|
||||
elif choice == "Back" or choice is None:
|
||||
break
|
||||
|
||||
def _refresh_provider_models(self, provider_name: str) -> None:
|
||||
from .model_fetcher import fetch_provider_models
|
||||
|
||||
try:
|
||||
provider_config = self.user_config.get_provider(provider_name)
|
||||
console.print(f"Refreshing models for {provider_name}...")
|
||||
|
||||
models = fetch_provider_models(provider_config)
|
||||
self.user_config.set_provider_models(provider_name, models)
|
||||
|
||||
console.print(
|
||||
f"[green]Successfully refreshed {len(models)} models for '{provider_name}'[/green]"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]Failed to refresh models: {e}[/red]")
|
||||
|
||||
def _select_model_from_list(self, provider_name: str) -> Optional[str]:
|
||||
from .model_fetcher import fetch_provider_models
|
||||
|
||||
models = self.user_config.list_provider_models(provider_name)
|
||||
|
||||
if not models:
|
||||
console.print(f"No models found for {provider_name}. Refreshing...")
|
||||
try:
|
||||
provider_config = self.user_config.get_provider(provider_name)
|
||||
models = fetch_provider_models(provider_config)
|
||||
self.user_config.set_provider_models(provider_name, models)
|
||||
console.print(f"[green]Refreshed {len(models)} models[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Failed to refresh models: {e}[/red]")
|
||||
return questionary.text(
|
||||
f"Enter model name for {provider_name}:",
|
||||
validate=lambda name: len(name.strip()) > 0
|
||||
or "Please enter a model name",
|
||||
).ask()
|
||||
|
||||
if not models:
|
||||
console.print(f"[yellow]No models available for {provider_name}[/yellow]")
|
||||
return questionary.text(
|
||||
f"Enter model name for {provider_name}:",
|
||||
validate=lambda name: len(name.strip()) > 0
|
||||
or "Please enter a model name",
|
||||
).ask()
|
||||
|
||||
model_choices = [model["id"] for model in models]
|
||||
model_choices.append("---")
|
||||
model_choices.append("Enter manually")
|
||||
|
||||
choice = questionary.select(
|
||||
f"Select a model for {provider_name}:",
|
||||
choices=model_choices,
|
||||
).ask()
|
||||
|
||||
if choice is None or choice == "---":
|
||||
return None
|
||||
elif choice == "Enter manually":
|
||||
return questionary.text(
|
||||
f"Enter model name for {provider_name}:",
|
||||
validate=lambda name: len(name.strip()) > 0
|
||||
or "Please enter a model name",
|
||||
).ask()
|
||||
else:
|
||||
return choice
|
||||
|
||||
def _set_default_model(self) -> None:
|
||||
"""Set the default model."""
|
||||
@@ -298,16 +401,18 @@ class ProviderConfigurator:
|
||||
# Extract provider name
|
||||
provider_name = choice.split(" (")[0]
|
||||
|
||||
# Ask for model name
|
||||
if self.user_config.is_provider_openai_compatible(provider_name):
|
||||
model_name = self._select_model_from_list(provider_name)
|
||||
else:
|
||||
model_name = questionary.text(
|
||||
f"Enter model name for {provider_name} (e.g., 'claude-3-5-sonnet', 'gpt-4', 'llama3:70b'):",
|
||||
validate=lambda name: len(name.strip()) > 0 or "Please enter a model name",
|
||||
validate=lambda name: len(name.strip()) > 0
|
||||
or "Please enter a model name",
|
||||
).ask()
|
||||
|
||||
if model_name is None:
|
||||
return
|
||||
|
||||
# Set the default model in provider/model format
|
||||
default_model = f"{provider_name}/{model_name.strip()}"
|
||||
self.user_config.set("defaults.model", default_model)
|
||||
|
||||
@@ -659,6 +764,10 @@ class ProviderConfigurator:
|
||||
|
||||
elif "defaults" in choice:
|
||||
if is_default:
|
||||
confirm = questionary.confirm(
|
||||
f"Remove '{server_name}' from default MCPs?", default=False
|
||||
).ask()
|
||||
if confirm:
|
||||
self.user_config.remove_mcp(server_name)
|
||||
console.print(
|
||||
f"[green]Removed '{server_name}' from default MCPs[/green]"
|
||||
@@ -669,7 +778,8 @@ class ProviderConfigurator:
|
||||
|
||||
elif choice == "Remove server":
|
||||
confirm = questionary.confirm(
|
||||
f"Are you sure you want to remove MCP server '{server_name}'?"
|
||||
f"Are you sure you want to remove MCP server '{server_name}'?",
|
||||
default=False,
|
||||
).ask()
|
||||
|
||||
if confirm:
|
||||
@@ -749,7 +859,8 @@ class ProviderConfigurator:
|
||||
|
||||
elif choice == "Remove network":
|
||||
confirm = questionary.confirm(
|
||||
f"Are you sure you want to remove network '{network_name}'?"
|
||||
f"Are you sure you want to remove network '{network_name}'?",
|
||||
default=False,
|
||||
).ask()
|
||||
|
||||
if confirm:
|
||||
@@ -829,7 +940,8 @@ class ProviderConfigurator:
|
||||
|
||||
elif choice == "Remove volume":
|
||||
confirm = questionary.confirm(
|
||||
f"Are you sure you want to remove volume mapping '{volume_mapping}'?"
|
||||
f"Are you sure you want to remove volume mapping '{volume_mapping}'?",
|
||||
default=False,
|
||||
).ask()
|
||||
|
||||
if confirm:
|
||||
@@ -902,7 +1014,7 @@ class ProviderConfigurator:
|
||||
|
||||
if choice == "Remove port":
|
||||
confirm = questionary.confirm(
|
||||
f"Are you sure you want to remove port {port_num}?"
|
||||
f"Are you sure you want to remove port {port_num}?", default=False
|
||||
).ask()
|
||||
|
||||
if confirm:
|
||||
|
||||
@@ -116,6 +116,8 @@ class ContainerManager:
|
||||
}
|
||||
if provider.get("base_url"):
|
||||
provider_config["base_url"] = provider.get("base_url")
|
||||
if provider.get("models"):
|
||||
provider_config["models"] = provider.get("models")
|
||||
|
||||
providers[name] = provider_config
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ class ProviderConfig(BaseModel):
|
||||
type: str
|
||||
api_key: str
|
||||
base_url: str | None = None
|
||||
models: list[dict[str, str]] = []
|
||||
|
||||
|
||||
class MCPConfig(BaseModel):
|
||||
|
||||
@@ -51,12 +51,21 @@ class OpencodePlugin(ToolPlugin):
|
||||
# Check if this is a custom provider (has baseURL)
|
||||
if provider_config.base_url:
|
||||
# Custom provider - include baseURL and name
|
||||
models_dict = {}
|
||||
|
||||
# Add all models for OpenAI-compatible providers
|
||||
if provider_config.type == "openai" and provider_config.models:
|
||||
for model in provider_config.models:
|
||||
model_id = model.get("id", "")
|
||||
if model_id:
|
||||
models_dict[model_id] = {"name": model_id}
|
||||
|
||||
provider_entry: dict[str, str | dict[str, str]] = {
|
||||
"options": {
|
||||
"apiKey": provider_config.api_key,
|
||||
"baseURL": provider_config.base_url,
|
||||
},
|
||||
"models": {},
|
||||
"models": models_dict,
|
||||
}
|
||||
|
||||
# Add npm package and name for custom providers
|
||||
@@ -68,6 +77,10 @@ class OpencodePlugin(ToolPlugin):
|
||||
elif provider_config.type == "openai":
|
||||
provider_entry["npm"] = "@ai-sdk/openai-compatible"
|
||||
provider_entry["name"] = f"OpenAI Compatible ({provider_name})"
|
||||
if models_dict:
|
||||
self.status.log(
|
||||
f"Added {len(models_dict)} models to {provider_name}"
|
||||
)
|
||||
elif provider_config.type == "google":
|
||||
provider_entry["npm"] = "@ai-sdk/google"
|
||||
provider_entry["name"] = f"Google ({provider_name})"
|
||||
@@ -94,16 +107,19 @@ class OpencodePlugin(ToolPlugin):
|
||||
f"Added {provider_name} standard provider to OpenCode configuration"
|
||||
)
|
||||
|
||||
# Set default model and add it only to the default provider
|
||||
# Set default model
|
||||
if cubbi_config.defaults.model:
|
||||
config_data["model"] = cubbi_config.defaults.model
|
||||
self.status.log(f"Set default model to {config_data['model']}")
|
||||
|
||||
# Add the specific model only to the provider that matches the default model
|
||||
# Add the default model to provider if it doesn't already have models
|
||||
provider_name: str
|
||||
model_name: str
|
||||
provider_name, model_name = cubbi_config.defaults.model.split("/", 1)
|
||||
if provider_name in config_data["provider"]:
|
||||
provider_config = cubbi_config.providers.get(provider_name)
|
||||
# Only add default model if provider doesn't already have models populated
|
||||
if not (provider_config and provider_config.models):
|
||||
config_data["provider"][provider_name]["models"] = {
|
||||
model_name: {"name": model_name}
|
||||
}
|
||||
|
||||
208
cubbi/model_fetcher.py
Normal file
208
cubbi/model_fetcher.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
Model fetching utilities for OpenAI-compatible providers.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelFetcher:
    """Fetches model lists from OpenAI-compatible API endpoints."""

    def __init__(self, timeout: int = 30):
        """Initialize the model fetcher.

        Args:
            timeout: Request timeout in seconds
        """
        self.timeout = timeout

    def fetch_models(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> List[Dict[str, str]]:
        """Fetch models from an OpenAI-compatible /v1/models endpoint.

        Args:
            base_url: Base URL of the provider (e.g., "https://api.openai.com" or "https://api.litellm.com")
            api_key: Optional API key for authentication
            headers: Optional additional headers

        Returns:
            List of model dictionaries, each holding an 'id' key

        Raises:
            requests.RequestException: If the request fails
            ValueError: If the response format is invalid
        """
        # Construct the models endpoint URL
        models_url = self._build_models_url(base_url)

        # Prepare headers
        request_headers = self._build_headers(api_key, headers)

        # Lazy %-style args so the message is only formatted when emitted.
        logger.info("Fetching models from %s", models_url)

        try:
            response = requests.get(
                models_url, headers=request_headers, timeout=self.timeout
            )
            response.raise_for_status()

            # Parse JSON response
            data = response.json()

            # Validate response structure: OpenAI-style payload is
            # {"data": [{"id": ...}, ...]}.
            if not isinstance(data, dict) or "data" not in data:
                raise ValueError(
                    f"Invalid response format: expected dict with 'data' key, got {type(data)}"
                )

            models_data = data["data"]
            if not isinstance(models_data, list):
                raise ValueError(
                    f"Invalid models data: expected list, got {type(models_data)}"
                )

            # Process models, keeping only well-formed entries with a non-empty id.
            models = []
            for model_item in models_data:
                if not isinstance(model_item, dict):
                    continue

                model_id = model_item.get("id", "")
                if not model_id:
                    continue

                # Skip models with * in their ID (e.g. gateway routing wildcards)
                if "*" in model_id:
                    logger.debug("Skipping model with wildcard: %s", model_id)
                    continue

                models.append({"id": model_id})

            logger.info(
                "Successfully fetched %d models from %s", len(models), base_url
            )
            return models

        # Re-raise as requests.RequestException with a user-facing message,
        # chaining the original exception (`from e`) for debuggability.
        except requests.exceptions.Timeout as e:
            logger.error("Request timed out after %d seconds", self.timeout)
            raise requests.RequestException(
                f"Request to {models_url} timed out"
            ) from e
        except requests.exceptions.ConnectionError as e:
            logger.error("Connection error: %s", e)
            raise requests.RequestException(
                f"Failed to connect to {models_url}"
            ) from e
        except requests.exceptions.HTTPError as e:
            logger.error("HTTP error %s: %s", e.response.status_code, e)
            if e.response.status_code == 401:
                raise requests.RequestException(
                    "Authentication failed: invalid API key"
                ) from e
            elif e.response.status_code == 403:
                raise requests.RequestException(
                    "Access forbidden: check API key permissions"
                ) from e
            else:
                raise requests.RequestException(
                    f"HTTP {e.response.status_code} error from {models_url}"
                ) from e
        except json.JSONDecodeError as e:
            # requests' JSONDecodeError subclasses json.JSONDecodeError,
            # so this also catches response.json() failures.
            logger.error("Failed to parse JSON response: %s", e)
            raise ValueError(f"Invalid JSON response from {models_url}") from e

    def _build_models_url(self, base_url: str) -> str:
        """Build the models endpoint URL from a base URL.

        Args:
            base_url: Base URL of the provider

        Returns:
            Complete URL for the /v1/models endpoint
        """
        # Remove trailing slash if present
        base_url = base_url.rstrip("/")

        # Add /v1/models if not already present
        if not base_url.endswith("/v1/models"):
            if base_url.endswith("/v1"):
                base_url += "/models"
            else:
                base_url += "/v1/models"

        return base_url

    def _build_headers(
        self,
        api_key: Optional[str] = None,
        additional_headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """Build request headers.

        Args:
            api_key: Optional API key for authentication
            additional_headers: Optional additional headers

        Returns:
            Dictionary of headers
        """
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        # Add authentication header if API key is provided
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        # Add any additional headers (may override the defaults above)
        if additional_headers:
            headers.update(additional_headers)

        return headers
||||
|
||||
|
||||
def fetch_provider_models(
    provider_config: Dict, timeout: int = 30
) -> List[Dict[str, str]]:
    """Convenience function to fetch models for a provider configuration.

    Args:
        provider_config: Provider configuration dictionary
        timeout: Request timeout in seconds

    Returns:
        List of model dictionaries

    Raises:
        ValueError: If provider is not OpenAI-compatible or missing required fields
        requests.RequestException: If the request fails
    """
    import os

    provider_type = provider_config.get("type", "")
    base_url = provider_config.get("base_url")
    # `or ""` also guards an explicit `api_key: null` in the config, which
    # would otherwise crash on .startswith below.
    api_key = provider_config.get("api_key") or ""

    # Resolve "${ENV_VAR}" references against the process environment.
    if api_key.startswith("${") and api_key.endswith("}"):
        env_var_name = api_key[2:-1]
        api_key = os.environ.get(env_var_name, "")

    # Mirror UserConfigManager.is_provider_openai_compatible: a provider is
    # OpenAI-compatible only when type == "openai" AND base_url is set.
    # (Using `and` here let e.g. type='anthropic' with a base_url slip through.)
    if provider_type != "openai" or not base_url:
        raise ValueError(
            "Provider is not OpenAI-compatible (must have type='openai' and base_url)"
        )

    fetcher = ModelFetcher(timeout=timeout)
    return fetcher.fetch_models(base_url, api_key)
|
||||
@@ -635,3 +635,113 @@ class UserConfigManager:
|
||||
self.set("defaults.ports", ports)
|
||||
return True
|
||||
return False
|
||||
|
||||
# Model management methods
|
||||
def list_provider_models(self, provider_name: str) -> List[Dict[str, str]]:
    """Get all models for a specific provider.

    Accepts both legacy plain-string entries and dict entries, normalizing
    everything to ``{"id": ...}`` dictionaries.

    Args:
        provider_name: Name of the provider

    Returns:
        List of model dictionaries, each with an 'id' key
    """
    config = self.get_provider(provider_name)
    if not config:
        return []

    normalized = []
    for entry in config.get("models", []):
        if isinstance(entry, str):
            normalized.append({"id": entry})
        elif isinstance(entry, dict) and entry.get("id", ""):
            # Keep only the id; drop any extra keys and empty ids.
            normalized.append({"id": entry["id"]})

    return normalized
|
||||
|
||||
def set_provider_models(
    self, provider_name: str, models: List[Dict[str, str]]
) -> None:
    """Set the models for a specific provider.

    Entries without an 'id' key are dropped; the rest are stored as minimal
    ``{"id": ...}`` records.

    Args:
        provider_name: Name of the provider
        models: List of model dictionaries with 'id' and optional 'name' keys
    """
    config = self.get_provider(provider_name)
    if not config:
        # Unknown provider: silently ignore, consistent with other mutators.
        return

    config["models"] = [
        {"id": entry["id"]}
        for entry in models
        if isinstance(entry, dict) and "id" in entry
    ]
    self.set(f"providers.{provider_name}", config)
|
||||
|
||||
def add_provider_model(
    self, provider_name: str, model_id: str, model_name: Optional[str] = None
) -> None:
    """Add a model to a provider.

    No-op when a model with the same ID is already present.

    Args:
        provider_name: Name of the provider
        model_id: ID of the model
        model_name: Optional display name for the model (defaults to model_id)
    """
    models = self.list_provider_models(provider_name)

    # Duplicate IDs are silently ignored.
    if any(entry["id"] == model_id for entry in models):
        return

    self.set_provider_models(provider_name, models + [{"id": model_id}])
|
||||
|
||||
def remove_provider_model(self, provider_name: str, model_id: str) -> bool:
    """Remove a model from a provider.

    Args:
        provider_name: Name of the provider
        model_id: ID of the model to remove

    Returns:
        True if model was removed, False if it didn't exist
    """
    current = self.list_provider_models(provider_name)
    remaining = [entry for entry in current if entry["id"] != model_id]

    if len(remaining) == len(current):
        # Nothing matched the requested ID; don't rewrite the config.
        return False

    self.set_provider_models(provider_name, remaining)
    return True
|
||||
|
||||
def is_provider_openai_compatible(self, provider_name: str) -> bool:
    """Return True when the provider is a custom OpenAI-compatible endpoint.

    A provider qualifies only when its type is "openai" and it defines a
    base_url.
    """
    config = self.get_provider(provider_name) or {}
    is_openai_type = config.get("type", "") == "openai"
    return is_openai_type and config.get("base_url") is not None
|
||||
|
||||
def list_openai_compatible_providers(self) -> List[str]:
    """Return the names of all OpenAI-compatible providers."""
    return [
        name
        for name in self.list_providers()
        if self.is_provider_openai_compatible(name)
    ]
|
||||
|
||||
@@ -15,6 +15,7 @@ dependencies = [
|
||||
"rich>=13.6.0",
|
||||
"pydantic>=2.5.0",
|
||||
"questionary>=2.0.0",
|
||||
"requests>=2.32.3",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -85,6 +85,7 @@ dependencies = [
|
||||
{ name = "pydantic" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "questionary" },
|
||||
{ name = "requests" },
|
||||
{ name = "rich" },
|
||||
{ name = "typer" },
|
||||
]
|
||||
@@ -109,6 +110,7 @@ requires-dist = [
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.1" },
|
||||
{ name = "questionary", specifier = ">=2.0.0" },
|
||||
{ name = "requests", specifier = ">=2.32.3" },
|
||||
{ name = "rich", specifier = ">=13.6.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.9" },
|
||||
{ name = "typer", specifier = ">=0.9.0" },
|
||||
|
||||
Reference in New Issue
Block a user