feat: support ssh commands (#10)
This commit is contained in:
@@ -19,6 +19,7 @@ type Config struct {
|
||||
Network NetworkConfig `json:"network"`
|
||||
Filesystem FilesystemConfig `json:"filesystem"`
|
||||
Command CommandConfig `json:"command"`
|
||||
SSH SSHConfig `json:"ssh"`
|
||||
AllowPty bool `json:"allowPty,omitempty"`
|
||||
}
|
||||
|
||||
@@ -49,6 +50,17 @@ type CommandConfig struct {
|
||||
UseDefaults *bool `json:"useDefaults,omitempty"`
|
||||
}
|
||||
|
||||
// SSHConfig defines SSH command restrictions.
|
||||
// SSH commands are filtered using an allowlist by default for security.
|
||||
type SSHConfig struct {
|
||||
AllowedHosts []string `json:"allowedHosts"` // Host patterns to allow SSH to (supports wildcards like *.example.com)
|
||||
DeniedHosts []string `json:"deniedHosts"` // Host patterns to deny SSH to (checked before allowed)
|
||||
AllowedCommands []string `json:"allowedCommands"` // Commands allowed over SSH (allowlist mode)
|
||||
DeniedCommands []string `json:"deniedCommands"` // Commands denied over SSH (checked before allowed)
|
||||
AllowAllCommands bool `json:"allowAllCommands,omitempty"` // If true, use denylist mode instead of allowlist
|
||||
InheritDeny bool `json:"inheritDeny,omitempty"` // If true, also apply global command.deny rules
|
||||
}
|
||||
|
||||
// DefaultDeniedCommands returns commands that are blocked by default.
|
||||
// These are system-level dangerous commands that are rarely needed by AI agents.
|
||||
var DefaultDeniedCommands = []string{
|
||||
@@ -109,6 +121,12 @@ func Default() *Config {
|
||||
Allow: []string{},
|
||||
// UseDefaults defaults to true (nil = true)
|
||||
},
|
||||
SSH: SSHConfig{
|
||||
AllowedHosts: []string{},
|
||||
DeniedHosts: []string{},
|
||||
AllowedCommands: []string{},
|
||||
DeniedCommands: []string{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -178,6 +196,24 @@ func (c *Config) Validate() error {
|
||||
return errors.New("command.allow contains empty command")
|
||||
}
|
||||
|
||||
// SSH config
|
||||
for _, host := range c.SSH.AllowedHosts {
|
||||
if err := validateHostPattern(host); err != nil {
|
||||
return fmt.Errorf("invalid ssh.allowedHosts %q: %w", host, err)
|
||||
}
|
||||
}
|
||||
for _, host := range c.SSH.DeniedHosts {
|
||||
if err := validateHostPattern(host); err != nil {
|
||||
return fmt.Errorf("invalid ssh.deniedHosts %q: %w", host, err)
|
||||
}
|
||||
}
|
||||
if slices.Contains(c.SSH.AllowedCommands, "") {
|
||||
return errors.New("ssh.allowedCommands contains empty command")
|
||||
}
|
||||
if slices.Contains(c.SSH.DeniedCommands, "") {
|
||||
return errors.New("ssh.deniedCommands contains empty command")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -229,6 +265,42 @@ func validateDomainPattern(pattern string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateHostPattern validates an SSH host pattern.
|
||||
// Host patterns are more permissive than domain patterns:
|
||||
// - Can contain wildcards anywhere (e.g., prod-*.example.com, *.example.com)
|
||||
// - Can be IP addresses
|
||||
// - Can be simple hostnames without dots
|
||||
func validateHostPattern(pattern string) error {
|
||||
if pattern == "" {
|
||||
return errors.New("empty host pattern")
|
||||
}
|
||||
|
||||
// Reject patterns with protocol or path
|
||||
if strings.Contains(pattern, "://") || strings.Contains(pattern, "/") {
|
||||
return errors.New("host pattern cannot contain protocol or path")
|
||||
}
|
||||
|
||||
// Reject patterns with port (user@host:port style)
|
||||
// But allow colons for IPv6 addresses
|
||||
if strings.Contains(pattern, ":") && !strings.Contains(pattern, "::") && !isIPv6Pattern(pattern) {
|
||||
return errors.New("host pattern cannot contain port; specify port in SSH command instead")
|
||||
}
|
||||
|
||||
// Reject patterns with @ (should be just the host, not user@host)
|
||||
if strings.Contains(pattern, "@") {
|
||||
return errors.New("host pattern should not contain username; specify just the host")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isIPv6Pattern checks if a pattern looks like an IPv6 address.
|
||||
func isIPv6Pattern(pattern string) bool {
|
||||
// IPv6 addresses contain multiple colons
|
||||
colonCount := strings.Count(pattern, ":")
|
||||
return colonCount >= 2
|
||||
}
|
||||
|
||||
// MatchesDomain checks if a hostname matches a domain pattern.
|
||||
func MatchesDomain(hostname, pattern string) bool {
|
||||
hostname = strings.ToLower(hostname)
|
||||
@@ -249,6 +321,71 @@ func MatchesDomain(hostname, pattern string) bool {
|
||||
return hostname == pattern
|
||||
}
|
||||
|
||||
// MatchesHost checks if a hostname matches an SSH host pattern.
|
||||
// SSH host patterns support wildcards anywhere in the pattern.
|
||||
func MatchesHost(hostname, pattern string) bool {
|
||||
hostname = strings.ToLower(hostname)
|
||||
pattern = strings.ToLower(pattern)
|
||||
|
||||
// "*" matches all hosts
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// If pattern contains no wildcards, do exact match
|
||||
if !strings.Contains(pattern, "*") {
|
||||
return hostname == pattern
|
||||
}
|
||||
|
||||
// Convert glob pattern to a simple matcher
|
||||
// Split pattern by * and check each part
|
||||
return matchGlob(hostname, pattern)
|
||||
}
|
||||
|
||||
// matchGlob performs simple glob matching with * wildcards.
|
||||
func matchGlob(s, pattern string) bool {
|
||||
// Handle edge cases
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
if pattern == "" {
|
||||
return s == ""
|
||||
}
|
||||
|
||||
// Split pattern by * and match parts
|
||||
parts := strings.Split(pattern, "*")
|
||||
|
||||
// Check prefix (before first *)
|
||||
if !strings.HasPrefix(s, parts[0]) {
|
||||
return false
|
||||
}
|
||||
s = s[len(parts[0]):]
|
||||
|
||||
// Check suffix (after last *)
|
||||
if len(parts) > 1 {
|
||||
last := parts[len(parts)-1]
|
||||
if !strings.HasSuffix(s, last) {
|
||||
return false
|
||||
}
|
||||
s = s[:len(s)-len(last)]
|
||||
}
|
||||
|
||||
// Check middle parts (between *s)
|
||||
for i := 1; i < len(parts)-1; i++ {
|
||||
part := parts[i]
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
idx := strings.Index(s, part)
|
||||
if idx < 0 {
|
||||
return false
|
||||
}
|
||||
s = s[idx+len(part):]
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Merge combines a base config with an override config.
|
||||
// Values in override take precedence. Slice fields are appended (base + override).
|
||||
// The Extends field is cleared in the result since inheritance has been resolved.
|
||||
@@ -307,6 +444,18 @@ func Merge(base, override *Config) *Config {
|
||||
// Pointer field: override wins if set
|
||||
UseDefaults: mergeOptionalBool(base.Command.UseDefaults, override.Command.UseDefaults),
|
||||
},
|
||||
|
||||
SSH: SSHConfig{
|
||||
// Append slices
|
||||
AllowedHosts: mergeStrings(base.SSH.AllowedHosts, override.SSH.AllowedHosts),
|
||||
DeniedHosts: mergeStrings(base.SSH.DeniedHosts, override.SSH.DeniedHosts),
|
||||
AllowedCommands: mergeStrings(base.SSH.AllowedCommands, override.SSH.AllowedCommands),
|
||||
DeniedCommands: mergeStrings(base.SSH.DeniedCommands, override.SSH.DeniedCommands),
|
||||
|
||||
// Boolean fields: true if either enables it
|
||||
AllowAllCommands: base.SSH.AllowAllCommands || override.SSH.AllowAllCommands,
|
||||
InheritDeny: base.SSH.InheritDeny || override.SSH.InheritDeny,
|
||||
},
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -480,3 +480,210 @@ func TestMerge(t *testing.T) {
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
func TestValidateHostPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
wantErr bool
|
||||
}{
|
||||
// Valid patterns
|
||||
{"simple hostname", "server1", false},
|
||||
{"domain", "example.com", false},
|
||||
{"subdomain", "prod.example.com", false},
|
||||
{"wildcard prefix", "*.example.com", false},
|
||||
{"wildcard middle", "prod-*.example.com", false},
|
||||
{"ip address", "192.168.1.1", false},
|
||||
{"ipv6 address", "::1", false},
|
||||
{"ipv6 full", "2001:db8::1", false},
|
||||
{"localhost", "localhost", false},
|
||||
|
||||
// Invalid patterns
|
||||
{"empty", "", true},
|
||||
{"with protocol", "ssh://example.com", true},
|
||||
{"with path", "example.com/path", true},
|
||||
{"with port", "example.com:22", true},
|
||||
{"with username", "user@example.com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateHostPattern(tt.pattern)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateHostPattern(%q) error = %v, wantErr %v", tt.pattern, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostname string
|
||||
pattern string
|
||||
want bool
|
||||
}{
|
||||
// Exact matches
|
||||
{"exact match", "server1.example.com", "server1.example.com", true},
|
||||
{"exact match case insensitive", "Server1.Example.COM", "server1.example.com", true},
|
||||
{"exact no match", "server2.example.com", "server1.example.com", false},
|
||||
|
||||
// Wildcard matches
|
||||
{"wildcard prefix", "api.example.com", "*.example.com", true},
|
||||
{"wildcard prefix deep", "deep.api.example.com", "*.example.com", true},
|
||||
{"wildcard no match base", "example.com", "*.example.com", false},
|
||||
{"wildcard middle", "prod-web-01.example.com", "prod-*.example.com", true},
|
||||
{"wildcard middle no match", "dev-web-01.example.com", "prod-*.example.com", false},
|
||||
{"wildcard suffix", "server1.prod", "server1.*", true},
|
||||
{"multiple wildcards", "prod-web-01.us-east.example.com", "prod-*-*.example.com", true},
|
||||
|
||||
// Star matches all
|
||||
{"star matches all", "anything.example.com", "*", true},
|
||||
|
||||
// IP addresses
|
||||
{"ip exact match", "192.168.1.1", "192.168.1.1", true},
|
||||
{"ip no match", "192.168.1.2", "192.168.1.1", false},
|
||||
{"ip wildcard", "192.168.1.100", "192.168.1.*", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MatchesHost(tt.hostname, tt.pattern)
|
||||
if got != tt.want {
|
||||
t.Errorf("MatchesHost(%q, %q) = %v, want %v", tt.hostname, tt.pattern, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHConfigValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid SSH config",
|
||||
config: Config{
|
||||
SSH: SSHConfig{
|
||||
AllowedHosts: []string{"*.example.com", "prod-*.internal"},
|
||||
AllowedCommands: []string{"ls", "cat", "grep"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid allowed host with protocol",
|
||||
config: Config{
|
||||
SSH: SSHConfig{
|
||||
AllowedHosts: []string{"ssh://example.com"},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid denied host with username",
|
||||
config: Config{
|
||||
SSH: SSHConfig{
|
||||
DeniedHosts: []string{"user@example.com"},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty allowed command",
|
||||
config: Config{
|
||||
SSH: SSHConfig{
|
||||
AllowedHosts: []string{"example.com"},
|
||||
AllowedCommands: []string{"ls", ""},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty denied command",
|
||||
config: Config{
|
||||
SSH: SSHConfig{
|
||||
AllowedHosts: []string{"example.com"},
|
||||
DeniedCommands: []string{""},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSSHConfig(t *testing.T) {
|
||||
t.Run("merge SSH allowed hosts", func(t *testing.T) {
|
||||
base := &Config{
|
||||
SSH: SSHConfig{
|
||||
AllowedHosts: []string{"prod-*.example.com"},
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
SSH: SSHConfig{
|
||||
AllowedHosts: []string{"dev-*.example.com"},
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
if len(result.SSH.AllowedHosts) != 2 {
|
||||
t.Errorf("expected 2 allowed hosts, got %d: %v", len(result.SSH.AllowedHosts), result.SSH.AllowedHosts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("merge SSH commands", func(t *testing.T) {
|
||||
base := &Config{
|
||||
SSH: SSHConfig{
|
||||
AllowedCommands: []string{"ls", "cat"},
|
||||
DeniedCommands: []string{"rm -rf"},
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
SSH: SSHConfig{
|
||||
AllowedCommands: []string{"grep", "find"},
|
||||
DeniedCommands: []string{"shutdown"},
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
if len(result.SSH.AllowedCommands) != 4 {
|
||||
t.Errorf("expected 4 allowed commands, got %d", len(result.SSH.AllowedCommands))
|
||||
}
|
||||
if len(result.SSH.DeniedCommands) != 2 {
|
||||
t.Errorf("expected 2 denied commands, got %d", len(result.SSH.DeniedCommands))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("merge SSH boolean flags", func(t *testing.T) {
|
||||
base := &Config{
|
||||
SSH: SSHConfig{
|
||||
AllowAllCommands: false,
|
||||
InheritDeny: true,
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
SSH: SSHConfig{
|
||||
AllowAllCommands: true,
|
||||
InheritDeny: false,
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
if !result.SSH.AllowAllCommands {
|
||||
t.Error("expected AllowAllCommands to be true (OR logic)")
|
||||
}
|
||||
if !result.SSH.InheritDeny {
|
||||
t.Error("expected InheritDeny to be true (OR logic)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user