feat: add configuration override in session create with --config/-c

This commit is contained in:
2025-07-10 14:51:33 -06:00
parent da5937e708
commit 672b8a8e31

View File

@@ -173,6 +173,12 @@ def create_session(
None, "--provider", "-p", help="Provider to use" None, "--provider", "-p", help="Provider to use"
), ),
ssh: bool = typer.Option(False, "--ssh", help="Start SSH server in the container"), ssh: bool = typer.Option(False, "--ssh", help="Start SSH server in the container"),
config: List[str] = typer.Option(
[],
"--config",
"-c",
help="Override configuration values (KEY=VALUE) for this session only",
),
verbose: bool = typer.Option(False, "--verbose", help="Enable verbose logging"), verbose: bool = typer.Option(False, "--verbose", help="Enable verbose logging"),
) -> None: ) -> None:
"""Create a new Cubbi session """Create a new Cubbi session
@@ -189,16 +195,67 @@ def create_session(
target_gid = gid if gid is not None else os.getgid() target_gid = gid if gid is not None else os.getgid()
console.print(f"Using UID: {target_uid}, GID: {target_gid}") console.print(f"Using UID: {target_uid}, GID: {target_gid}")
# Use default image from user configuration # Create a temporary user config manager with overrides
temp_user_config = UserConfigManager()
# Parse and apply config overrides
config_overrides = {}
for config_item in config:
if "=" in config_item:
key, value = config_item.split("=", 1)
# Convert string value to appropriate type
if value.lower() == "true":
typed_value = True
elif value.lower() == "false":
typed_value = False
elif value.isdigit():
typed_value = int(value)
else:
typed_value = value
config_overrides[key] = typed_value
console.print(f"[blue]Config override: {key} = {typed_value}[/blue]")
else:
console.print(
f"[yellow]Warning: Ignoring invalid config format: {config_item}. Use KEY=VALUE.[/yellow]"
)
# Apply overrides to temp config (without saving)
for key, value in config_overrides.items():
# Handle shorthand service paths (e.g., "langfuse.url")
if (
"." in key
and not key.startswith("services.")
and not any(
key.startswith(section + ".")
for section in ["defaults", "docker", "remote", "ui"]
)
):
service, setting = key.split(".", 1)
key = f"services.{service}.{setting}"
# Split the key path and navigate to set the value
parts = key.split(".")
config_dict = temp_user_config.config
# Navigate to the containing dictionary
for part in parts[:-1]:
if part not in config_dict:
config_dict[part] = {}
config_dict = config_dict[part]
# Set the value without saving
config_dict[parts[-1]] = value
# Use default image from user configuration (with overrides applied)
if not image: if not image:
image_name = user_config.get( image_name = temp_user_config.get(
"defaults.image", config_manager.config.defaults.get("image", "goose") "defaults.image", config_manager.config.defaults.get("image", "goose")
) )
else: else:
image_name = image image_name = image
# Start with environment variables from user configuration # Start with environment variables from user configuration (with overrides applied)
environment = user_config.get_environment_variables() environment = temp_user_config.get_environment_variables()
# Override with environment variables from command line # Override with environment variables from command line
for var in env: for var in env:
@@ -214,7 +271,7 @@ def create_session(
volume_mounts = {} volume_mounts = {}
# Get default volumes from user config # Get default volumes from user config
default_volumes = user_config.get("defaults.volumes", []) default_volumes = temp_user_config.get("defaults.volumes", [])
# Combine default volumes with user-specified volumes # Combine default volumes with user-specified volumes
all_volumes = default_volumes + list(volume) all_volumes = default_volumes + list(volume)
@@ -241,7 +298,7 @@ def create_session(
) )
# Get default networks from user config # Get default networks from user config
default_networks = user_config.get("defaults.networks", []) default_networks = temp_user_config.get("defaults.networks", [])
# Combine default networks with user-specified networks, removing duplicates # Combine default networks with user-specified networks, removing duplicates
all_networks = list(set(default_networks + network)) all_networks = list(set(default_networks + network))
@@ -249,7 +306,7 @@ def create_session(
# Get default MCPs from user config if none specified # Get default MCPs from user config if none specified
all_mcps = mcp if isinstance(mcp, list) else [] all_mcps = mcp if isinstance(mcp, list) else []
if not all_mcps: if not all_mcps:
default_mcps = user_config.get("defaults.mcps", []) default_mcps = temp_user_config.get("defaults.mcps", [])
all_mcps = default_mcps all_mcps = default_mcps
if default_mcps: if default_mcps:
@@ -277,6 +334,16 @@ def create_session(
"[yellow]Warning: --no-shell is ignored without --run[/yellow]" "[yellow]Warning: --no-shell is ignored without --run[/yellow]"
) )
# Use model and provider from config overrides if not explicitly provided
final_model = (
model if model is not None else temp_user_config.get("defaults.model")
)
final_provider = (
provider
if provider is not None
else temp_user_config.get("defaults.provider")
)
session = container_manager.create_session( session = container_manager.create_session(
image_name=image_name, image_name=image_name,
project=path_or_url, project=path_or_url,
@@ -292,8 +359,8 @@ def create_session(
uid=target_uid, uid=target_uid,
gid=target_gid, gid=target_gid,
ssh=ssh, ssh=ssh,
model=model, model=final_model,
provider=provider, provider=final_provider,
) )
if session: if session:
@@ -307,7 +374,7 @@ def create_session(
console.print(f" {container_port} -> {host_port}") console.print(f" {container_port} -> {host_port}")
# Auto-connect based on user config, unless overridden by --no-connect flag or --no-shell # Auto-connect based on user config, unless overridden by --no-connect flag or --no-shell
auto_connect = user_config.get("defaults.connect", True) auto_connect = temp_user_config.get("defaults.connect", True)
# When --no-shell is used with --run, show logs instead of connecting # When --no-shell is used with --run, show logs instead of connecting
if no_shell and run_command: if no_shell and run_command: