feat: support ssh commands (#10)

This commit is contained in:
JY Tan
2026-01-17 15:36:51 -08:00
committed by GitHub
parent 3c3f28b32c
commit 20fa647ccc
6 changed files with 1045 additions and 0 deletions

View File

@@ -480,3 +480,210 @@ func TestMerge(t *testing.T) {
func boolPtr(b bool) *bool {
return &b
}
func TestValidateHostPattern(t *testing.T) {
tests := []struct {
name string
pattern string
wantErr bool
}{
// Valid patterns
{"simple hostname", "server1", false},
{"domain", "example.com", false},
{"subdomain", "prod.example.com", false},
{"wildcard prefix", "*.example.com", false},
{"wildcard middle", "prod-*.example.com", false},
{"ip address", "192.168.1.1", false},
{"ipv6 address", "::1", false},
{"ipv6 full", "2001:db8::1", false},
{"localhost", "localhost", false},
// Invalid patterns
{"empty", "", true},
{"with protocol", "ssh://example.com", true},
{"with path", "example.com/path", true},
{"with port", "example.com:22", true},
{"with username", "user@example.com", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateHostPattern(tt.pattern)
if (err != nil) != tt.wantErr {
t.Errorf("validateHostPattern(%q) error = %v, wantErr %v", tt.pattern, err, tt.wantErr)
}
})
}
}
func TestMatchesHost(t *testing.T) {
tests := []struct {
name string
hostname string
pattern string
want bool
}{
// Exact matches
{"exact match", "server1.example.com", "server1.example.com", true},
{"exact match case insensitive", "Server1.Example.COM", "server1.example.com", true},
{"exact no match", "server2.example.com", "server1.example.com", false},
// Wildcard matches
{"wildcard prefix", "api.example.com", "*.example.com", true},
{"wildcard prefix deep", "deep.api.example.com", "*.example.com", true},
{"wildcard no match base", "example.com", "*.example.com", false},
{"wildcard middle", "prod-web-01.example.com", "prod-*.example.com", true},
{"wildcard middle no match", "dev-web-01.example.com", "prod-*.example.com", false},
{"wildcard suffix", "server1.prod", "server1.*", true},
{"multiple wildcards", "prod-web-01.us-east.example.com", "prod-*-*.example.com", true},
// Star matches all
{"star matches all", "anything.example.com", "*", true},
// IP addresses
{"ip exact match", "192.168.1.1", "192.168.1.1", true},
{"ip no match", "192.168.1.2", "192.168.1.1", false},
{"ip wildcard", "192.168.1.100", "192.168.1.*", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MatchesHost(tt.hostname, tt.pattern)
if got != tt.want {
t.Errorf("MatchesHost(%q, %q) = %v, want %v", tt.hostname, tt.pattern, got, tt.want)
}
})
}
}
func TestSSHConfigValidation(t *testing.T) {
tests := []struct {
name string
config Config
wantErr bool
}{
{
name: "valid SSH config",
config: Config{
SSH: SSHConfig{
AllowedHosts: []string{"*.example.com", "prod-*.internal"},
AllowedCommands: []string{"ls", "cat", "grep"},
},
},
wantErr: false,
},
{
name: "invalid allowed host with protocol",
config: Config{
SSH: SSHConfig{
AllowedHosts: []string{"ssh://example.com"},
},
},
wantErr: true,
},
{
name: "invalid denied host with username",
config: Config{
SSH: SSHConfig{
DeniedHosts: []string{"user@example.com"},
},
},
wantErr: true,
},
{
name: "empty allowed command",
config: Config{
SSH: SSHConfig{
AllowedHosts: []string{"example.com"},
AllowedCommands: []string{"ls", ""},
},
},
wantErr: true,
},
{
name: "empty denied command",
config: Config{
SSH: SSHConfig{
AllowedHosts: []string{"example.com"},
DeniedCommands: []string{""},
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestMergeSSHConfig(t *testing.T) {
t.Run("merge SSH allowed hosts", func(t *testing.T) {
base := &Config{
SSH: SSHConfig{
AllowedHosts: []string{"prod-*.example.com"},
},
}
override := &Config{
SSH: SSHConfig{
AllowedHosts: []string{"dev-*.example.com"},
},
}
result := Merge(base, override)
if len(result.SSH.AllowedHosts) != 2 {
t.Errorf("expected 2 allowed hosts, got %d: %v", len(result.SSH.AllowedHosts), result.SSH.AllowedHosts)
}
})
t.Run("merge SSH commands", func(t *testing.T) {
base := &Config{
SSH: SSHConfig{
AllowedCommands: []string{"ls", "cat"},
DeniedCommands: []string{"rm -rf"},
},
}
override := &Config{
SSH: SSHConfig{
AllowedCommands: []string{"grep", "find"},
DeniedCommands: []string{"shutdown"},
},
}
result := Merge(base, override)
if len(result.SSH.AllowedCommands) != 4 {
t.Errorf("expected 4 allowed commands, got %d", len(result.SSH.AllowedCommands))
}
if len(result.SSH.DeniedCommands) != 2 {
t.Errorf("expected 2 denied commands, got %d", len(result.SSH.DeniedCommands))
}
})
t.Run("merge SSH boolean flags", func(t *testing.T) {
base := &Config{
SSH: SSHConfig{
AllowAllCommands: false,
InheritDeny: true,
},
}
override := &Config{
SSH: SSHConfig{
AllowAllCommands: true,
InheritDeny: false,
},
}
result := Merge(base, override)
if !result.SSH.AllowAllCommands {
t.Error("expected AllowAllCommands to be true (OR logic)")
}
if !result.SSH.InheritDeny {
t.Error("expected InheritDeny to be true (OR logic)")
}
})
}