feat: support ssh commands (#10)
This commit is contained in:
@@ -454,3 +454,373 @@ func TestMatchesPrefix(t *testing.T) {
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
func TestParseSSHCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
command string
|
||||
wantHost string
|
||||
wantCmd string
|
||||
wantIsSSH bool
|
||||
desc string
|
||||
}{
|
||||
// Basic SSH commands
|
||||
{`ssh server1.example.com`, "server1.example.com", "", true, "simple host"},
|
||||
{`ssh user@server1.example.com`, "server1.example.com", "", true, "user@host"},
|
||||
{`ssh server1.example.com ls -la`, "server1.example.com", "ls -la", true, "host with command"},
|
||||
{`ssh user@server1.example.com "cat /var/log/app.log"`, "server1.example.com", `cat /var/log/app.log`, true, "user@host with quoted command"},
|
||||
|
||||
// SSH with options
|
||||
{`ssh -p 2222 server1.example.com`, "server1.example.com", "", true, "with port option"},
|
||||
{`ssh -i ~/.ssh/key server1.example.com ls`, "server1.example.com", "ls", true, "with identity file"},
|
||||
{`ssh -v -t server1.example.com`, "server1.example.com", "", true, "with flags"},
|
||||
{`ssh -o StrictHostKeyChecking=no server1.example.com`, "server1.example.com", "", true, "with -o option"},
|
||||
|
||||
// Full path to ssh
|
||||
{`/usr/bin/ssh server1.example.com ls`, "server1.example.com", "ls", true, "full path ssh"},
|
||||
|
||||
// Not SSH commands
|
||||
{`ls -la`, "", "", false, "not ssh"},
|
||||
{`sshpass -p password ssh server`, "", "", false, "sshpass wrapper"},
|
||||
{`echo ssh server`, "", "", false, "ssh as argument"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
host, cmd, isSSH := parseSSHCommand(tt.command)
|
||||
if isSSH != tt.wantIsSSH {
|
||||
t.Errorf("parseSSHCommand(%q) isSSH = %v, want %v", tt.command, isSSH, tt.wantIsSSH)
|
||||
}
|
||||
if host != tt.wantHost {
|
||||
t.Errorf("parseSSHCommand(%q) host = %q, want %q", tt.command, host, tt.wantHost)
|
||||
}
|
||||
if cmd != tt.wantCmd {
|
||||
t.Errorf("parseSSHCommand(%q) cmd = %q, want %q", tt.command, cmd, tt.wantCmd)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSSHCommand_HostPolicy(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
SSH: config.SSHConfig{
|
||||
AllowedHosts: []string{"*.example.com", "prod-*"},
|
||||
DeniedHosts: []string{"prod-db.example.com"},
|
||||
},
|
||||
Command: config.CommandConfig{
|
||||
UseDefaults: boolPtr(false),
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
command string
|
||||
shouldBlock bool
|
||||
desc string
|
||||
}{
|
||||
// Allowed hosts
|
||||
{`ssh server1.example.com`, false, "allowed by wildcard"},
|
||||
{`ssh api.example.com`, false, "allowed subdomain"},
|
||||
{`ssh prod-web-01`, false, "allowed by prod-* pattern"},
|
||||
|
||||
// Denied hosts
|
||||
{`ssh prod-db.example.com`, true, "explicitly denied"},
|
||||
|
||||
// Not in allowlist
|
||||
{`ssh other.domain.com`, true, "not in allowedHosts"},
|
||||
{`ssh dev-server`, true, "not matching any pattern"},
|
||||
|
||||
// Non-SSH commands (should pass through)
|
||||
{`ls -la`, false, "not an SSH command"},
|
||||
{`curl https://example.com`, false, "not an SSH command"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
err := CheckSSHCommand(tt.command, cfg)
|
||||
if tt.shouldBlock && err == nil {
|
||||
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||
}
|
||||
if !tt.shouldBlock && err != nil {
|
||||
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSSHCommand_AllowlistMode(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
SSH: config.SSHConfig{
|
||||
AllowedHosts: []string{"*.example.com"},
|
||||
AllowedCommands: []string{"ls", "cat", "grep", "tail -f"},
|
||||
},
|
||||
Command: config.CommandConfig{
|
||||
UseDefaults: boolPtr(false),
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
command string
|
||||
shouldBlock bool
|
||||
desc string
|
||||
}{
|
||||
// Allowed commands
|
||||
{`ssh server.example.com ls`, false, "ls allowed"},
|
||||
{`ssh server.example.com ls -la /var/log`, false, "ls with args"},
|
||||
{`ssh server.example.com cat /etc/hosts`, false, "cat allowed"},
|
||||
{`ssh server.example.com grep error /var/log/app.log`, false, "grep allowed"},
|
||||
{`ssh server.example.com tail -f /var/log/app.log`, false, "tail -f allowed"},
|
||||
|
||||
// Not in allowlist
|
||||
{`ssh server.example.com rm -rf /tmp/cache`, true, "rm not in allowlist"},
|
||||
{`ssh server.example.com chmod 777 /tmp`, true, "chmod not in allowlist"},
|
||||
{`ssh server.example.com shutdown now`, true, "shutdown not in allowlist"},
|
||||
{`ssh server.example.com tail /var/log/app.log`, true, "tail without -f not allowed"},
|
||||
|
||||
// Interactive session (no command) - should be allowed
|
||||
{`ssh server.example.com`, false, "interactive session allowed"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
err := CheckSSHCommand(tt.command, cfg)
|
||||
if tt.shouldBlock && err == nil {
|
||||
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||
}
|
||||
if !tt.shouldBlock && err != nil {
|
||||
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSSHCommand_DenylistMode(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
SSH: config.SSHConfig{
|
||||
AllowedHosts: []string{"*.example.com"},
|
||||
AllowAllCommands: true, // denylist mode
|
||||
DeniedCommands: []string{"rm -rf", "shutdown", "chmod"},
|
||||
},
|
||||
Command: config.CommandConfig{
|
||||
UseDefaults: boolPtr(false),
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
command string
|
||||
shouldBlock bool
|
||||
desc string
|
||||
}{
|
||||
// Allowed (not in denylist)
|
||||
{`ssh server.example.com ls -la`, false, "ls allowed"},
|
||||
{`ssh server.example.com cat /etc/hosts`, false, "cat allowed"},
|
||||
{`ssh server.example.com rm file.txt`, false, "rm single file allowed"},
|
||||
{`ssh server.example.com apt-get update`, false, "apt-get allowed"},
|
||||
|
||||
// Denied
|
||||
{`ssh server.example.com rm -rf /tmp/cache`, true, "rm -rf denied"},
|
||||
{`ssh server.example.com shutdown now`, true, "shutdown denied"},
|
||||
{`ssh server.example.com chmod 777 /tmp`, true, "chmod denied"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
err := CheckSSHCommand(tt.command, cfg)
|
||||
if tt.shouldBlock && err == nil {
|
||||
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||
}
|
||||
if !tt.shouldBlock && err != nil {
|
||||
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSSHCommand_InheritDeny(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
SSH: config.SSHConfig{
|
||||
AllowedHosts: []string{"*.example.com"},
|
||||
AllowAllCommands: true, // denylist mode
|
||||
InheritDeny: true, // inherit global denies
|
||||
},
|
||||
Command: config.CommandConfig{
|
||||
Deny: []string{"git push", "npm publish"},
|
||||
UseDefaults: boolPtr(true), // include default denies like shutdown
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
command string
|
||||
shouldBlock bool
|
||||
desc string
|
||||
}{
|
||||
// Inherited from global deny
|
||||
{`ssh server.example.com git push origin main`, true, "git push from global deny"},
|
||||
{`ssh server.example.com npm publish`, true, "npm publish from global deny"},
|
||||
|
||||
// Inherited from default deny list
|
||||
{`ssh server.example.com shutdown now`, true, "shutdown from default deny"},
|
||||
{`ssh server.example.com reboot`, true, "reboot from default deny"},
|
||||
|
||||
// Allowed (not in any deny list)
|
||||
{`ssh server.example.com ls -la`, false, "ls allowed"},
|
||||
{`ssh server.example.com git status`, false, "git status allowed"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
err := CheckSSHCommand(tt.command, cfg)
|
||||
if tt.shouldBlock && err == nil {
|
||||
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||
}
|
||||
if !tt.shouldBlock && err != nil {
|
||||
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSSHCommand_NoSSHConfig(t *testing.T) {
|
||||
// No SSH policy configured - all SSH commands should pass through
|
||||
cfg := &config.Config{
|
||||
Command: config.CommandConfig{
|
||||
UseDefaults: boolPtr(false),
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
command string
|
||||
desc string
|
||||
}{
|
||||
{`ssh server.example.com rm -rf /`, "dangerous command allowed when no SSH policy"},
|
||||
{`ssh any-host.com shutdown`, "any host allowed"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
err := CheckSSHCommand(tt.command, cfg)
|
||||
if err != nil {
|
||||
t.Errorf("expected SSH command %q to be allowed (no SSH policy), got: %v", tt.command, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckCommand_IntegratesSSH(t *testing.T) {
|
||||
// Test that CheckCommand also checks SSH policies
|
||||
cfg := &config.Config{
|
||||
SSH: config.SSHConfig{
|
||||
AllowedHosts: []string{"*.example.com"},
|
||||
AllowedCommands: []string{"ls", "cat"},
|
||||
},
|
||||
Command: config.CommandConfig{
|
||||
UseDefaults: boolPtr(false),
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
command string
|
||||
shouldBlock bool
|
||||
desc string
|
||||
}{
|
||||
// Via CheckCommand, SSH policy should be enforced
|
||||
{`ssh server.example.com ls`, false, "allowed SSH command"},
|
||||
{`ssh server.example.com rm -rf /`, true, "blocked SSH command"},
|
||||
{`ssh other.com ls`, true, "blocked host"},
|
||||
|
||||
// Non-SSH commands unaffected
|
||||
{`ls -la`, false, "local ls allowed"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
err := CheckCommand(tt.command, cfg)
|
||||
if tt.shouldBlock && err == nil {
|
||||
t.Errorf("expected command %q to be blocked", tt.command)
|
||||
}
|
||||
if !tt.shouldBlock && err != nil {
|
||||
t.Errorf("expected command %q to be allowed, got: %v", tt.command, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSSHCommand_CommandChaining(t *testing.T) {
|
||||
// Test that command chaining doesn't bypass allow/deny rules
|
||||
cfg := &config.Config{
|
||||
SSH: config.SSHConfig{
|
||||
AllowedHosts: []string{"*.example.com"},
|
||||
AllowedCommands: []string{"ls", "cat", "git status"},
|
||||
},
|
||||
Command: config.CommandConfig{
|
||||
UseDefaults: boolPtr(false),
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
command string
|
||||
shouldBlock bool
|
||||
desc string
|
||||
}{
|
||||
// Chaining should NOT bypass allowlist
|
||||
{`ssh server.example.com "ls && rm -rf /"`, true, "ls allowed but rm -rf not"},
|
||||
{`ssh server.example.com "git status && rm -rf /"`, true, "git status allowed but rm -rf not"},
|
||||
{`ssh server.example.com "cat file; shutdown"`, true, "cat allowed but shutdown not"},
|
||||
{`ssh server.example.com "ls | xargs rm"`, true, "ls allowed but rm not"},
|
||||
{`ssh server.example.com "ls || rm -rf /"`, true, "ls allowed but rm -rf not"},
|
||||
|
||||
// All subcommands allowed should work
|
||||
{`ssh server.example.com "ls && cat file"`, false, "both ls and cat allowed"},
|
||||
{`ssh server.example.com "ls; cat file"`, false, "semicolon chain with allowed commands"},
|
||||
{`ssh server.example.com "ls | cat"`, false, "pipe with allowed commands"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
err := CheckSSHCommand(tt.command, cfg)
|
||||
if tt.shouldBlock && err == nil {
|
||||
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||
}
|
||||
if !tt.shouldBlock && err != nil {
|
||||
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSSHCommand_CommandChainingDenylist(t *testing.T) {
|
||||
// Test command chaining in denylist mode
|
||||
cfg := &config.Config{
|
||||
SSH: config.SSHConfig{
|
||||
AllowedHosts: []string{"*.example.com"},
|
||||
AllowAllCommands: true,
|
||||
DeniedCommands: []string{"rm -rf", "shutdown"},
|
||||
},
|
||||
Command: config.CommandConfig{
|
||||
UseDefaults: boolPtr(false),
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
command string
|
||||
shouldBlock bool
|
||||
desc string
|
||||
}{
|
||||
// Chaining should still catch denied commands
|
||||
{`ssh server.example.com "ls && rm -rf /"`, true, "rm -rf in chain blocked"},
|
||||
{`ssh server.example.com "cat file; shutdown"`, true, "shutdown in chain blocked"},
|
||||
|
||||
// Chains without denied commands should work
|
||||
{`ssh server.example.com "ls && cat && grep foo"`, false, "chain without denied commands"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
err := CheckSSHCommand(tt.command, cfg)
|
||||
if tt.shouldBlock && err == nil {
|
||||
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||
}
|
||||
if !tt.shouldBlock && err != nil {
|
||||
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user