mirror of
https://github.com/Monadical-SAS/cubbi.git
synced 2025-12-20 12:19:07 +00:00
feat: dynamic model management for OpenAI-compatible providers (#33)
feat: add models fetch for openai-compatible endpoint
This commit is contained in:
137
cubbi/cli.py
137
cubbi/cli.py
@@ -762,6 +762,10 @@ config_app.add_typer(port_app, name="port", no_args_is_help=True)
|
|||||||
config_mcp_app = typer.Typer(help="Manage default MCP servers")
|
config_mcp_app = typer.Typer(help="Manage default MCP servers")
|
||||||
config_app.add_typer(config_mcp_app, name="mcp", no_args_is_help=True)
|
config_app.add_typer(config_mcp_app, name="mcp", no_args_is_help=True)
|
||||||
|
|
||||||
|
# Create a models subcommand for config
|
||||||
|
models_app = typer.Typer(help="Manage provider models")
|
||||||
|
config_app.add_typer(models_app, name="models", no_args_is_help=True)
|
||||||
|
|
||||||
|
|
||||||
# MCP configuration commands
|
# MCP configuration commands
|
||||||
@config_mcp_app.command("list")
|
@config_mcp_app.command("list")
|
||||||
@@ -2231,6 +2235,139 @@ exec npm start
|
|||||||
console.print("[green]MCP Inspector stopped[/green]")
|
console.print("[green]MCP Inspector stopped[/green]")
|
||||||
|
|
||||||
|
|
||||||
|
# Model management commands
|
||||||
|
@models_app.command("list")
|
||||||
|
def list_models(
|
||||||
|
provider: Optional[str] = typer.Argument(None, help="Provider name (optional)"),
|
||||||
|
) -> None:
|
||||||
|
if provider:
|
||||||
|
# List models for specific provider
|
||||||
|
models = user_config.list_provider_models(provider)
|
||||||
|
|
||||||
|
if not models:
|
||||||
|
if not user_config.get_provider(provider):
|
||||||
|
console.print(f"[red]Provider '{provider}' not found[/red]")
|
||||||
|
else:
|
||||||
|
console.print(f"No models configured for provider '{provider}'")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(show_header=True, header_style="bold")
|
||||||
|
table.add_column("Model ID")
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
table.add_row(model["id"])
|
||||||
|
|
||||||
|
console.print(f"\n[bold]Models for provider '{provider}'[/bold]")
|
||||||
|
console.print(table)
|
||||||
|
else:
|
||||||
|
# List models for all providers
|
||||||
|
providers = user_config.list_providers()
|
||||||
|
|
||||||
|
if not providers:
|
||||||
|
console.print("No providers configured")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(show_header=True, header_style="bold")
|
||||||
|
table.add_column("Provider")
|
||||||
|
table.add_column("Model ID")
|
||||||
|
|
||||||
|
found_models = False
|
||||||
|
for provider_name in providers.keys():
|
||||||
|
models = user_config.list_provider_models(provider_name)
|
||||||
|
for model in models:
|
||||||
|
table.add_row(provider_name, model["id"])
|
||||||
|
found_models = True
|
||||||
|
|
||||||
|
if found_models:
|
||||||
|
console.print(table)
|
||||||
|
else:
|
||||||
|
console.print("No models configured for any provider")
|
||||||
|
|
||||||
|
|
||||||
|
@models_app.command("refresh")
|
||||||
|
def refresh_models(
|
||||||
|
provider: Optional[str] = typer.Argument(None, help="Provider name (optional)"),
|
||||||
|
) -> None:
|
||||||
|
from .model_fetcher import fetch_provider_models
|
||||||
|
|
||||||
|
if provider:
|
||||||
|
# Refresh models for specific provider
|
||||||
|
provider_config = user_config.get_provider(provider)
|
||||||
|
if not provider_config:
|
||||||
|
console.print(f"[red]Provider '{provider}' not found[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not user_config.is_provider_openai_compatible(provider):
|
||||||
|
console.print(
|
||||||
|
f"[red]Provider '{provider}' is not a custom OpenAI provider[/red]"
|
||||||
|
)
|
||||||
|
console.print(
|
||||||
|
"Only providers with type='openai' and custom base_url are supported"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print(f"Refreshing models for provider '{provider}'...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with console.status(f"Fetching models from {provider}..."):
|
||||||
|
models = fetch_provider_models(provider_config)
|
||||||
|
|
||||||
|
user_config.set_provider_models(provider, models)
|
||||||
|
console.print(
|
||||||
|
f"[green]Successfully refreshed {len(models)} models for '{provider}'[/green]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Show some examples
|
||||||
|
if models:
|
||||||
|
console.print("\nSample models:")
|
||||||
|
for model in models[:5]: # Show first 5
|
||||||
|
console.print(f" - {model['id']}")
|
||||||
|
if len(models) > 5:
|
||||||
|
console.print(f" ... and {len(models) - 5} more")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]Failed to refresh models for '{provider}': {e}[/red]")
|
||||||
|
else:
|
||||||
|
# Refresh models for all OpenAI-compatible providers
|
||||||
|
compatible_providers = user_config.list_openai_compatible_providers()
|
||||||
|
|
||||||
|
if not compatible_providers:
|
||||||
|
console.print("[yellow]No custom OpenAI providers found[/yellow]")
|
||||||
|
console.print(
|
||||||
|
"Add providers with type='openai' and custom base_url to refresh models"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print(
|
||||||
|
f"Refreshing models for {len(compatible_providers)} custom OpenAI providers..."
|
||||||
|
)
|
||||||
|
|
||||||
|
success_count = 0
|
||||||
|
failed_providers = []
|
||||||
|
|
||||||
|
for provider_name in compatible_providers:
|
||||||
|
try:
|
||||||
|
provider_config = user_config.get_provider(provider_name)
|
||||||
|
with console.status(f"Fetching models from {provider_name}..."):
|
||||||
|
models = fetch_provider_models(provider_config)
|
||||||
|
|
||||||
|
user_config.set_provider_models(provider_name, models)
|
||||||
|
console.print(f"[green]✓ {provider_name}: {len(models)} models[/green]")
|
||||||
|
success_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ {provider_name}: {e}[/red]")
|
||||||
|
failed_providers.append(provider_name)
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
console.print("\n[bold]Summary[/bold]")
|
||||||
|
console.print(f"Successfully refreshed: {success_count} providers")
|
||||||
|
if failed_providers:
|
||||||
|
console.print(
|
||||||
|
f"Failed: {len(failed_providers)} providers ({', '.join(failed_providers)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def session_create_entry_point():
|
def session_create_entry_point():
|
||||||
"""Entry point that directly invokes 'cubbi session create'.
|
"""Entry point that directly invokes 'cubbi session create'.
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ Interactive configuration tool for Cubbi providers and models.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import docker
|
import docker
|
||||||
import questionary
|
import questionary
|
||||||
@@ -164,7 +165,8 @@ class ProviderConfigurator:
|
|||||||
"How would you like to provide the API key?",
|
"How would you like to provide the API key?",
|
||||||
choices=[
|
choices=[
|
||||||
"Enter API key directly (saved in config)",
|
"Enter API key directly (saved in config)",
|
||||||
"Reference environment variable (recommended)",
|
"Use environment variable (recommended)",
|
||||||
|
"No API key needed",
|
||||||
],
|
],
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
@@ -184,11 +186,12 @@ class ProviderConfigurator:
|
|||||||
|
|
||||||
api_key = f"${{{env_var.strip()}}}"
|
api_key = f"${{{env_var.strip()}}}"
|
||||||
|
|
||||||
# Check if the environment variable exists
|
|
||||||
if not os.environ.get(env_var.strip()):
|
if not os.environ.get(env_var.strip()):
|
||||||
console.print(
|
console.print(
|
||||||
f"[yellow]Warning: Environment variable '{env_var}' is not currently set[/yellow]"
|
f"[yellow]Warning: Environment variable '{env_var}' is not currently set[/yellow]"
|
||||||
)
|
)
|
||||||
|
elif "No API key" in api_key_choice:
|
||||||
|
api_key = ""
|
||||||
else:
|
else:
|
||||||
api_key = questionary.password(
|
api_key = questionary.password(
|
||||||
"Enter API key:",
|
"Enter API key:",
|
||||||
@@ -219,6 +222,13 @@ class ProviderConfigurator:
|
|||||||
|
|
||||||
console.print(f"[green]Added provider '{provider_name}'[/green]")
|
console.print(f"[green]Added provider '{provider_name}'[/green]")
|
||||||
|
|
||||||
|
if self.user_config.is_provider_openai_compatible(provider_name):
|
||||||
|
console.print("Refreshing models...")
|
||||||
|
try:
|
||||||
|
self._refresh_provider_models(provider_name)
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[yellow]Could not refresh models: {e}[/yellow]")
|
||||||
|
|
||||||
def _edit_provider(self, provider_name: str) -> None:
|
def _edit_provider(self, provider_name: str) -> None:
|
||||||
"""Edit an existing provider."""
|
"""Edit an existing provider."""
|
||||||
provider_config = self.user_config.get_provider(provider_name)
|
provider_config = self.user_config.get_provider(provider_name)
|
||||||
@@ -226,36 +236,129 @@ class ProviderConfigurator:
|
|||||||
console.print(f"[red]Provider '{provider_name}' not found![/red]")
|
console.print(f"[red]Provider '{provider_name}' not found![/red]")
|
||||||
return
|
return
|
||||||
|
|
||||||
choices = ["View configuration", "Remove provider", "---", "Back"]
|
console.print(f"\n[bold]Configuration for '{provider_name}':[/bold]")
|
||||||
|
for key, value in provider_config.items():
|
||||||
choice = questionary.select(
|
if key == "api_key" and not value.startswith("${"):
|
||||||
f"What would you like to do with '{provider_name}'?",
|
display_value = (
|
||||||
choices=choices,
|
f"{'*' * (len(value) - 4)}{value[-4:]}"
|
||||||
).ask()
|
if len(value) > 4
|
||||||
|
else "****"
|
||||||
if choice == "View configuration":
|
)
|
||||||
console.print(f"\n[bold]Configuration for '{provider_name}':[/bold]")
|
elif key == "models" and isinstance(value, list):
|
||||||
for key, value in provider_config.items():
|
if value:
|
||||||
if key == "api_key" and not value.startswith("${"):
|
console.print(f" {key}:")
|
||||||
# Mask direct API keys
|
for i, model in enumerate(value[:10]):
|
||||||
display_value = (
|
if isinstance(model, dict):
|
||||||
f"{'*' * (len(value) - 4)}{value[-4:]}"
|
model_id = model.get("id", str(model))
|
||||||
if len(value) > 4
|
else:
|
||||||
else "****"
|
model_id = str(model)
|
||||||
)
|
console.print(f" {i+1}. {model_id}")
|
||||||
|
if len(value) > 10:
|
||||||
|
console.print(
|
||||||
|
f" ... and {len(value)-10} more ({len(value)} total)"
|
||||||
|
)
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
display_value = value
|
display_value = "(no models configured)"
|
||||||
console.print(f" {key}: {display_value}")
|
else:
|
||||||
console.print()
|
display_value = value
|
||||||
|
console.print(f" {key}: {display_value}")
|
||||||
|
console.print()
|
||||||
|
|
||||||
elif choice == "Remove provider":
|
while True:
|
||||||
confirm = questionary.confirm(
|
choices = ["Remove provider"]
|
||||||
f"Are you sure you want to remove provider '{provider_name}'?"
|
|
||||||
|
if self.user_config.is_provider_openai_compatible(provider_name):
|
||||||
|
choices.append("Refresh models")
|
||||||
|
|
||||||
|
choices.extend(["---", "Back"])
|
||||||
|
|
||||||
|
choice = questionary.select(
|
||||||
|
f"What would you like to do with '{provider_name}'?",
|
||||||
|
choices=choices,
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if confirm:
|
if choice == "Remove provider":
|
||||||
self.user_config.remove_provider(provider_name)
|
confirm = questionary.confirm(
|
||||||
console.print(f"[green]Removed provider '{provider_name}'[/green]")
|
f"Are you sure you want to remove provider '{provider_name}'?",
|
||||||
|
default=False,
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if confirm:
|
||||||
|
self.user_config.remove_provider(provider_name)
|
||||||
|
console.print(f"[green]Removed provider '{provider_name}'[/green]")
|
||||||
|
break
|
||||||
|
|
||||||
|
elif choice == "Refresh models":
|
||||||
|
self._refresh_provider_models(provider_name)
|
||||||
|
|
||||||
|
elif choice == "Back" or choice is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
def _refresh_provider_models(self, provider_name: str) -> None:
|
||||||
|
from .model_fetcher import fetch_provider_models
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider_config = self.user_config.get_provider(provider_name)
|
||||||
|
console.print(f"Refreshing models for {provider_name}...")
|
||||||
|
|
||||||
|
models = fetch_provider_models(provider_config)
|
||||||
|
self.user_config.set_provider_models(provider_name, models)
|
||||||
|
|
||||||
|
console.print(
|
||||||
|
f"[green]Successfully refreshed {len(models)} models for '{provider_name}'[/green]"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]Failed to refresh models: {e}[/red]")
|
||||||
|
|
||||||
|
def _select_model_from_list(self, provider_name: str) -> Optional[str]:
|
||||||
|
from .model_fetcher import fetch_provider_models
|
||||||
|
|
||||||
|
models = self.user_config.list_provider_models(provider_name)
|
||||||
|
|
||||||
|
if not models:
|
||||||
|
console.print(f"No models found for {provider_name}. Refreshing...")
|
||||||
|
try:
|
||||||
|
provider_config = self.user_config.get_provider(provider_name)
|
||||||
|
models = fetch_provider_models(provider_config)
|
||||||
|
self.user_config.set_provider_models(provider_name, models)
|
||||||
|
console.print(f"[green]Refreshed {len(models)} models[/green]")
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]Failed to refresh models: {e}[/red]")
|
||||||
|
return questionary.text(
|
||||||
|
f"Enter model name for {provider_name}:",
|
||||||
|
validate=lambda name: len(name.strip()) > 0
|
||||||
|
or "Please enter a model name",
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if not models:
|
||||||
|
console.print(f"[yellow]No models available for {provider_name}[/yellow]")
|
||||||
|
return questionary.text(
|
||||||
|
f"Enter model name for {provider_name}:",
|
||||||
|
validate=lambda name: len(name.strip()) > 0
|
||||||
|
or "Please enter a model name",
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
model_choices = [model["id"] for model in models]
|
||||||
|
model_choices.append("---")
|
||||||
|
model_choices.append("Enter manually")
|
||||||
|
|
||||||
|
choice = questionary.select(
|
||||||
|
f"Select a model for {provider_name}:",
|
||||||
|
choices=model_choices,
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if choice is None or choice == "---":
|
||||||
|
return None
|
||||||
|
elif choice == "Enter manually":
|
||||||
|
return questionary.text(
|
||||||
|
f"Enter model name for {provider_name}:",
|
||||||
|
validate=lambda name: len(name.strip()) > 0
|
||||||
|
or "Please enter a model name",
|
||||||
|
).ask()
|
||||||
|
else:
|
||||||
|
return choice
|
||||||
|
|
||||||
def _set_default_model(self) -> None:
|
def _set_default_model(self) -> None:
|
||||||
"""Set the default model."""
|
"""Set the default model."""
|
||||||
@@ -298,16 +401,18 @@ class ProviderConfigurator:
|
|||||||
# Extract provider name
|
# Extract provider name
|
||||||
provider_name = choice.split(" (")[0]
|
provider_name = choice.split(" (")[0]
|
||||||
|
|
||||||
# Ask for model name
|
if self.user_config.is_provider_openai_compatible(provider_name):
|
||||||
model_name = questionary.text(
|
model_name = self._select_model_from_list(provider_name)
|
||||||
f"Enter model name for {provider_name} (e.g., 'claude-3-5-sonnet', 'gpt-4', 'llama3:70b'):",
|
else:
|
||||||
validate=lambda name: len(name.strip()) > 0 or "Please enter a model name",
|
model_name = questionary.text(
|
||||||
).ask()
|
f"Enter model name for {provider_name} (e.g., 'claude-3-5-sonnet', 'gpt-4', 'llama3:70b'):",
|
||||||
|
validate=lambda name: len(name.strip()) > 0
|
||||||
|
or "Please enter a model name",
|
||||||
|
).ask()
|
||||||
|
|
||||||
if model_name is None:
|
if model_name is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Set the default model in provider/model format
|
|
||||||
default_model = f"{provider_name}/{model_name.strip()}"
|
default_model = f"{provider_name}/{model_name.strip()}"
|
||||||
self.user_config.set("defaults.model", default_model)
|
self.user_config.set("defaults.model", default_model)
|
||||||
|
|
||||||
@@ -659,17 +764,22 @@ class ProviderConfigurator:
|
|||||||
|
|
||||||
elif "defaults" in choice:
|
elif "defaults" in choice:
|
||||||
if is_default:
|
if is_default:
|
||||||
self.user_config.remove_mcp(server_name)
|
confirm = questionary.confirm(
|
||||||
console.print(
|
f"Remove '{server_name}' from default MCPs?", default=False
|
||||||
f"[green]Removed '{server_name}' from default MCPs[/green]"
|
).ask()
|
||||||
)
|
if confirm:
|
||||||
|
self.user_config.remove_mcp(server_name)
|
||||||
|
console.print(
|
||||||
|
f"[green]Removed '{server_name}' from default MCPs[/green]"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.user_config.add_mcp(server_name)
|
self.user_config.add_mcp(server_name)
|
||||||
console.print(f"[green]Added '{server_name}' to default MCPs[/green]")
|
console.print(f"[green]Added '{server_name}' to default MCPs[/green]")
|
||||||
|
|
||||||
elif choice == "Remove server":
|
elif choice == "Remove server":
|
||||||
confirm = questionary.confirm(
|
confirm = questionary.confirm(
|
||||||
f"Are you sure you want to remove MCP server '{server_name}'?"
|
f"Are you sure you want to remove MCP server '{server_name}'?",
|
||||||
|
default=False,
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if confirm:
|
if confirm:
|
||||||
@@ -749,7 +859,8 @@ class ProviderConfigurator:
|
|||||||
|
|
||||||
elif choice == "Remove network":
|
elif choice == "Remove network":
|
||||||
confirm = questionary.confirm(
|
confirm = questionary.confirm(
|
||||||
f"Are you sure you want to remove network '{network_name}'?"
|
f"Are you sure you want to remove network '{network_name}'?",
|
||||||
|
default=False,
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if confirm:
|
if confirm:
|
||||||
@@ -829,7 +940,8 @@ class ProviderConfigurator:
|
|||||||
|
|
||||||
elif choice == "Remove volume":
|
elif choice == "Remove volume":
|
||||||
confirm = questionary.confirm(
|
confirm = questionary.confirm(
|
||||||
f"Are you sure you want to remove volume mapping '{volume_mapping}'?"
|
f"Are you sure you want to remove volume mapping '{volume_mapping}'?",
|
||||||
|
default=False,
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if confirm:
|
if confirm:
|
||||||
@@ -902,7 +1014,7 @@ class ProviderConfigurator:
|
|||||||
|
|
||||||
if choice == "Remove port":
|
if choice == "Remove port":
|
||||||
confirm = questionary.confirm(
|
confirm = questionary.confirm(
|
||||||
f"Are you sure you want to remove port {port_num}?"
|
f"Are you sure you want to remove port {port_num}?", default=False
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if confirm:
|
if confirm:
|
||||||
|
|||||||
@@ -116,6 +116,8 @@ class ContainerManager:
|
|||||||
}
|
}
|
||||||
if provider.get("base_url"):
|
if provider.get("base_url"):
|
||||||
provider_config["base_url"] = provider.get("base_url")
|
provider_config["base_url"] = provider.get("base_url")
|
||||||
|
if provider.get("models"):
|
||||||
|
provider_config["models"] = provider.get("models")
|
||||||
|
|
||||||
providers[name] = provider_config
|
providers[name] = provider_config
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class ProviderConfig(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
api_key: str
|
api_key: str
|
||||||
base_url: str | None = None
|
base_url: str | None = None
|
||||||
|
models: list[dict[str, str]] = []
|
||||||
|
|
||||||
|
|
||||||
class MCPConfig(BaseModel):
|
class MCPConfig(BaseModel):
|
||||||
|
|||||||
@@ -51,12 +51,21 @@ class OpencodePlugin(ToolPlugin):
|
|||||||
# Check if this is a custom provider (has baseURL)
|
# Check if this is a custom provider (has baseURL)
|
||||||
if provider_config.base_url:
|
if provider_config.base_url:
|
||||||
# Custom provider - include baseURL and name
|
# Custom provider - include baseURL and name
|
||||||
|
models_dict = {}
|
||||||
|
|
||||||
|
# Add all models for OpenAI-compatible providers
|
||||||
|
if provider_config.type == "openai" and provider_config.models:
|
||||||
|
for model in provider_config.models:
|
||||||
|
model_id = model.get("id", "")
|
||||||
|
if model_id:
|
||||||
|
models_dict[model_id] = {"name": model_id}
|
||||||
|
|
||||||
provider_entry: dict[str, str | dict[str, str]] = {
|
provider_entry: dict[str, str | dict[str, str]] = {
|
||||||
"options": {
|
"options": {
|
||||||
"apiKey": provider_config.api_key,
|
"apiKey": provider_config.api_key,
|
||||||
"baseURL": provider_config.base_url,
|
"baseURL": provider_config.base_url,
|
||||||
},
|
},
|
||||||
"models": {},
|
"models": models_dict,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add npm package and name for custom providers
|
# Add npm package and name for custom providers
|
||||||
@@ -68,6 +77,10 @@ class OpencodePlugin(ToolPlugin):
|
|||||||
elif provider_config.type == "openai":
|
elif provider_config.type == "openai":
|
||||||
provider_entry["npm"] = "@ai-sdk/openai-compatible"
|
provider_entry["npm"] = "@ai-sdk/openai-compatible"
|
||||||
provider_entry["name"] = f"OpenAI Compatible ({provider_name})"
|
provider_entry["name"] = f"OpenAI Compatible ({provider_name})"
|
||||||
|
if models_dict:
|
||||||
|
self.status.log(
|
||||||
|
f"Added {len(models_dict)} models to {provider_name}"
|
||||||
|
)
|
||||||
elif provider_config.type == "google":
|
elif provider_config.type == "google":
|
||||||
provider_entry["npm"] = "@ai-sdk/google"
|
provider_entry["npm"] = "@ai-sdk/google"
|
||||||
provider_entry["name"] = f"Google ({provider_name})"
|
provider_entry["name"] = f"Google ({provider_name})"
|
||||||
@@ -94,22 +107,25 @@ class OpencodePlugin(ToolPlugin):
|
|||||||
f"Added {provider_name} standard provider to OpenCode configuration"
|
f"Added {provider_name} standard provider to OpenCode configuration"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set default model and add it only to the default provider
|
# Set default model
|
||||||
if cubbi_config.defaults.model:
|
if cubbi_config.defaults.model:
|
||||||
config_data["model"] = cubbi_config.defaults.model
|
config_data["model"] = cubbi_config.defaults.model
|
||||||
self.status.log(f"Set default model to {config_data['model']}")
|
self.status.log(f"Set default model to {config_data['model']}")
|
||||||
|
|
||||||
# Add the specific model only to the provider that matches the default model
|
# Add the default model to provider if it doesn't already have models
|
||||||
provider_name: str
|
provider_name: str
|
||||||
model_name: str
|
model_name: str
|
||||||
provider_name, model_name = cubbi_config.defaults.model.split("/", 1)
|
provider_name, model_name = cubbi_config.defaults.model.split("/", 1)
|
||||||
if provider_name in config_data["provider"]:
|
if provider_name in config_data["provider"]:
|
||||||
config_data["provider"][provider_name]["models"] = {
|
provider_config = cubbi_config.providers.get(provider_name)
|
||||||
model_name: {"name": model_name}
|
# Only add default model if provider doesn't already have models populated
|
||||||
}
|
if not (provider_config and provider_config.models):
|
||||||
self.status.log(
|
config_data["provider"][provider_name]["models"] = {
|
||||||
f"Added default model {model_name} to {provider_name} provider"
|
model_name: {"name": model_name}
|
||||||
)
|
}
|
||||||
|
self.status.log(
|
||||||
|
f"Added default model {model_name} to {provider_name} provider"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Fallback to legacy environment variables
|
# Fallback to legacy environment variables
|
||||||
opencode_model: str | None = os.environ.get("CUBBI_MODEL")
|
opencode_model: str | None = os.environ.get("CUBBI_MODEL")
|
||||||
|
|||||||
208
cubbi/model_fetcher.py
Normal file
208
cubbi/model_fetcher.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
"""
|
||||||
|
Model fetching utilities for OpenAI-compatible providers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFetcher:
|
||||||
|
"""Fetches model lists from OpenAI-compatible API endpoints."""
|
||||||
|
|
||||||
|
def __init__(self, timeout: int = 30):
|
||||||
|
"""Initialize the model fetcher.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Request timeout in seconds
|
||||||
|
"""
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def fetch_models(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""Fetch models from an OpenAI-compatible /v1/models endpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: Base URL of the provider (e.g., "https://api.openai.com" or "https://api.litellm.com")
|
||||||
|
api_key: Optional API key for authentication
|
||||||
|
headers: Optional additional headers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model dictionaries with 'id' and 'name' keys
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
requests.RequestException: If the request fails
|
||||||
|
ValueError: If the response format is invalid
|
||||||
|
"""
|
||||||
|
# Construct the models endpoint URL
|
||||||
|
models_url = self._build_models_url(base_url)
|
||||||
|
|
||||||
|
# Prepare headers
|
||||||
|
request_headers = self._build_headers(api_key, headers)
|
||||||
|
|
||||||
|
logger.info(f"Fetching models from {models_url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
models_url, headers=request_headers, timeout=self.timeout
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Parse JSON response
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Validate response structure
|
||||||
|
if not isinstance(data, dict) or "data" not in data:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid response format: expected dict with 'data' key, got {type(data)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
models_data = data["data"]
|
||||||
|
if not isinstance(models_data, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid models data: expected list, got {type(models_data)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process models
|
||||||
|
models = []
|
||||||
|
for model_item in models_data:
|
||||||
|
if not isinstance(model_item, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_id = model_item.get("id", "")
|
||||||
|
if not model_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip models with * in their ID as requested
|
||||||
|
if "*" in model_id:
|
||||||
|
logger.debug(f"Skipping model with wildcard: {model_id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create model entry
|
||||||
|
model = {
|
||||||
|
"id": model_id,
|
||||||
|
}
|
||||||
|
models.append(model)
|
||||||
|
|
||||||
|
logger.info(f"Successfully fetched {len(models)} models from {base_url}")
|
||||||
|
return models
|
||||||
|
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
logger.error(f"Request timed out after {self.timeout} seconds")
|
||||||
|
raise requests.RequestException(f"Request to {models_url} timed out")
|
||||||
|
except requests.exceptions.ConnectionError as e:
|
||||||
|
logger.error(f"Connection error: {e}")
|
||||||
|
raise requests.RequestException(f"Failed to connect to {models_url}")
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
logger.error(f"HTTP error {e.response.status_code}: {e}")
|
||||||
|
if e.response.status_code == 401:
|
||||||
|
raise requests.RequestException(
|
||||||
|
"Authentication failed: invalid API key"
|
||||||
|
)
|
||||||
|
elif e.response.status_code == 403:
|
||||||
|
raise requests.RequestException(
|
||||||
|
"Access forbidden: check API key permissions"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise requests.RequestException(
|
||||||
|
f"HTTP {e.response.status_code} error from {models_url}"
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Failed to parse JSON response: {e}")
|
||||||
|
raise ValueError(f"Invalid JSON response from {models_url}")
|
||||||
|
|
||||||
|
def _build_models_url(self, base_url: str) -> str:
|
||||||
|
"""Build the models endpoint URL from a base URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: Base URL of the provider
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Complete URL for the /v1/models endpoint
|
||||||
|
"""
|
||||||
|
# Remove trailing slash if present
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
|
||||||
|
# Add /v1/models if not already present
|
||||||
|
if not base_url.endswith("/v1/models"):
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
base_url += "/models"
|
||||||
|
else:
|
||||||
|
base_url += "/v1/models"
|
||||||
|
|
||||||
|
return base_url
|
||||||
|
|
||||||
|
def _build_headers(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
additional_headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""Build request headers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Optional API key for authentication
|
||||||
|
additional_headers: Optional additional headers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of headers
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add authentication header if API key is provided
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
# Add any additional headers
|
||||||
|
if additional_headers:
|
||||||
|
headers.update(additional_headers)
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_provider_models(
|
||||||
|
provider_config: Dict, timeout: int = 30
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""Convenience function to fetch models for a provider configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_config: Provider configuration dictionary
|
||||||
|
timeout: Request timeout in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model dictionaries
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider is not OpenAI-compatible or missing required fields
|
||||||
|
requests.RequestException: If the request fails
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
provider_type = provider_config.get("type", "")
|
||||||
|
base_url = provider_config.get("base_url")
|
||||||
|
api_key = provider_config.get("api_key", "")
|
||||||
|
|
||||||
|
if api_key.startswith("${") and api_key.endswith("}"):
|
||||||
|
env_var_name = api_key[2:-1]
|
||||||
|
api_key = os.environ.get(env_var_name, "")
|
||||||
|
|
||||||
|
if provider_type != "openai" and not base_url:
|
||||||
|
raise ValueError(
|
||||||
|
"Provider is not OpenAI-compatible (must have type='openai' or base_url)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not base_url:
|
||||||
|
raise ValueError("No base_url specified for OpenAI-compatible provider")
|
||||||
|
|
||||||
|
fetcher = ModelFetcher(timeout=timeout)
|
||||||
|
return fetcher.fetch_models(base_url, api_key)
|
||||||
@@ -635,3 +635,113 @@ class UserConfigManager:
|
|||||||
self.set("defaults.ports", ports)
|
self.set("defaults.ports", ports)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Model management methods
|
||||||
|
def list_provider_models(self, provider_name: str) -> List[Dict[str, str]]:
|
||||||
|
"""Get all models for a specific provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_name: Name of the provider
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model dictionaries with 'id' and 'name' keys
|
||||||
|
"""
|
||||||
|
provider_config = self.get_provider(provider_name)
|
||||||
|
if not provider_config:
|
||||||
|
return []
|
||||||
|
|
||||||
|
models = provider_config.get("models", [])
|
||||||
|
normalized_models = []
|
||||||
|
for model in models:
|
||||||
|
if isinstance(model, str):
|
||||||
|
normalized_models.append({"id": model})
|
||||||
|
elif isinstance(model, dict):
|
||||||
|
model_id = model.get("id", "")
|
||||||
|
if model_id:
|
||||||
|
normalized_models.append({"id": model_id})
|
||||||
|
|
||||||
|
return normalized_models
|
||||||
|
|
||||||
|
def set_provider_models(
|
||||||
|
self, provider_name: str, models: List[Dict[str, str]]
|
||||||
|
) -> None:
|
||||||
|
"""Set the models for a specific provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_name: Name of the provider
|
||||||
|
models: List of model dictionaries with 'id' and optional 'name' keys
|
||||||
|
"""
|
||||||
|
provider_config = self.get_provider(provider_name)
|
||||||
|
if not provider_config:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Normalize models - ensure each has id, name defaults to id
|
||||||
|
normalized_models = []
|
||||||
|
for model in models:
|
||||||
|
if isinstance(model, dict) and "id" in model:
|
||||||
|
normalized_model = {
|
||||||
|
"id": model["id"],
|
||||||
|
}
|
||||||
|
normalized_models.append(normalized_model)
|
||||||
|
|
||||||
|
provider_config["models"] = normalized_models
|
||||||
|
self.set(f"providers.{provider_name}", provider_config)
|
||||||
|
|
||||||
|
def add_provider_model(
|
||||||
|
self, provider_name: str, model_id: str, model_name: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Add a model to a provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_name: Name of the provider
|
||||||
|
model_id: ID of the model
|
||||||
|
model_name: Optional display name for the model (defaults to model_id)
|
||||||
|
"""
|
||||||
|
models = self.list_provider_models(provider_name)
|
||||||
|
|
||||||
|
for existing_model in models:
|
||||||
|
if existing_model["id"] == model_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_model = {"id": model_id}
|
||||||
|
models.append(new_model)
|
||||||
|
self.set_provider_models(provider_name, models)
|
||||||
|
|
||||||
|
def remove_provider_model(self, provider_name: str, model_id: str) -> bool:
|
||||||
|
"""Remove a model from a provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_name: Name of the provider
|
||||||
|
model_id: ID of the model to remove
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if model was removed, False if it didn't exist
|
||||||
|
"""
|
||||||
|
models = self.list_provider_models(provider_name)
|
||||||
|
original_length = len(models)
|
||||||
|
|
||||||
|
# Filter out the model with the specified ID
|
||||||
|
models = [model for model in models if model["id"] != model_id]
|
||||||
|
|
||||||
|
if len(models) < original_length:
|
||||||
|
self.set_provider_models(provider_name, models)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def is_provider_openai_compatible(self, provider_name: str) -> bool:
|
||||||
|
provider_config = self.get_provider(provider_name)
|
||||||
|
if not provider_config:
|
||||||
|
return False
|
||||||
|
|
||||||
|
provider_type = provider_config.get("type", "")
|
||||||
|
return provider_type == "openai" and provider_config.get("base_url") is not None
|
||||||
|
|
||||||
|
def list_openai_compatible_providers(self) -> List[str]:
|
||||||
|
providers = self.list_providers()
|
||||||
|
compatible_providers = []
|
||||||
|
|
||||||
|
for provider_name in providers.keys():
|
||||||
|
if self.is_provider_openai_compatible(provider_name):
|
||||||
|
compatible_providers.append(provider_name)
|
||||||
|
|
||||||
|
return compatible_providers
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ dependencies = [
|
|||||||
"rich>=13.6.0",
|
"rich>=13.6.0",
|
||||||
"pydantic>=2.5.0",
|
"pydantic>=2.5.0",
|
||||||
"questionary>=2.0.0",
|
"questionary>=2.0.0",
|
||||||
|
"requests>=2.32.3",
|
||||||
]
|
]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
|
|||||||
2
uv.lock
generated
2
uv.lock
generated
@@ -85,6 +85,7 @@ dependencies = [
|
|||||||
{ name = "pydantic" },
|
{ name = "pydantic" },
|
||||||
{ name = "pyyaml" },
|
{ name = "pyyaml" },
|
||||||
{ name = "questionary" },
|
{ name = "questionary" },
|
||||||
|
{ name = "requests" },
|
||||||
{ name = "rich" },
|
{ name = "rich" },
|
||||||
{ name = "typer" },
|
{ name = "typer" },
|
||||||
]
|
]
|
||||||
@@ -109,6 +110,7 @@ requires-dist = [
|
|||||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" },
|
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" },
|
||||||
{ name = "pyyaml", specifier = ">=6.0.1" },
|
{ name = "pyyaml", specifier = ">=6.0.1" },
|
||||||
{ name = "questionary", specifier = ">=2.0.0" },
|
{ name = "questionary", specifier = ">=2.0.0" },
|
||||||
|
{ name = "requests", specifier = ">=2.32.3" },
|
||||||
{ name = "rich", specifier = ">=13.6.0" },
|
{ name = "rich", specifier = ">=13.6.0" },
|
||||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.9" },
|
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.9" },
|
||||||
{ name = "typer", specifier = ">=0.9.0" },
|
{ name = "typer", specifier = ">=0.9.0" },
|
||||||
|
|||||||
Reference in New Issue
Block a user