feat: support ssh commands (#10)

This commit is contained in:
JY Tan
2026-01-17 15:36:51 -08:00
committed by GitHub
parent 3c3f28b32c
commit 20fa647ccc
6 changed files with 1045 additions and 0 deletions

View File

@@ -82,6 +82,11 @@ func checkSingleCommand(command string, cfg *config.Config) error {
}
}
// Check SSH-specific policies if this is an SSH command
if err := CheckSSHCommand(command, cfg); err != nil {
return err
}
return nil
}
@@ -298,3 +303,222 @@ func matchesPrefix(command, prefix string) bool {
return false
}
// SSHBlockedError is returned when an SSH command is blocked by policy.
type SSHBlockedError struct {
Host string
RemoteCommand string
Reason string
}
func (e *SSHBlockedError) Error() string {
if e.RemoteCommand != "" {
return fmt.Sprintf("SSH command blocked: %s (host: %s, command: %s)", e.Reason, e.Host, e.RemoteCommand)
}
return fmt.Sprintf("SSH blocked: %s (host: %s)", e.Reason, e.Host)
}
// CheckSSHCommand checks if an SSH command is allowed by the configuration.
// Returns nil if allowed, or SSHBlockedError if blocked.
func CheckSSHCommand(command string, cfg *config.Config) error {
if cfg == nil {
cfg = config.Default()
}
// Check if SSH config is active (has any hosts configured)
// If no SSH policy is configured, allow by default
if len(cfg.SSH.AllowedHosts) == 0 && len(cfg.SSH.DeniedHosts) == 0 {
return nil
}
host, remoteCmd, isSSH := parseSSHCommand(command)
if !isSSH {
return nil
}
// Check host policy (denied then allowed)
for _, pattern := range cfg.SSH.DeniedHosts {
if config.MatchesHost(host, pattern) {
return &SSHBlockedError{
Host: host,
RemoteCommand: remoteCmd,
Reason: fmt.Sprintf("host matches denied pattern %q", pattern),
}
}
}
hostAllowed := false
for _, pattern := range cfg.SSH.AllowedHosts {
if config.MatchesHost(host, pattern) {
hostAllowed = true
break
}
}
if len(cfg.SSH.AllowedHosts) > 0 && !hostAllowed {
return &SSHBlockedError{
Host: host,
RemoteCommand: remoteCmd,
Reason: "host not in allowedHosts",
}
}
// If no remote command (interactive session), allow if host is allowed
if remoteCmd == "" {
return nil
}
return checkSSHRemoteCommand(remoteCmd, cfg)
}
// checkSSHRemoteCommand checks if a remote command is allowed by SSH policy.
// It parses the remote command into subcommands (handling &&, ||, ;, |) and validates each.
func checkSSHRemoteCommand(remoteCmd string, cfg *config.Config) error {
// Parse into subcommands just like local commands to prevent bypass via chaining
// e.g., "git status && rm -rf /" should check both "git status" and "rm -rf /"
subCommands := parseShellCommand(remoteCmd)
for _, subCmd := range subCommands {
if err := checkSSHSingleCommand(subCmd, remoteCmd, cfg); err != nil {
return err
}
}
return nil
}
// checkSSHSingleCommand checks a single SSH remote command against policy.
func checkSSHSingleCommand(subCmd, fullRemoteCmd string, cfg *config.Config) error {
normalized := normalizeCommand(subCmd)
if normalized == "" {
return nil
}
// Check inherited global deny list first (if enabled)
// User-defined global then default deny list
if cfg.SSH.InheritDeny {
for _, deny := range cfg.Command.Deny {
if matchesPrefix(normalized, deny) {
return &SSHBlockedError{
RemoteCommand: fullRemoteCmd,
Reason: fmt.Sprintf("command %q matches inherited global deny %q", subCmd, deny),
}
}
}
if cfg.Command.UseDefaultDeniedCommands() {
for _, deny := range config.DefaultDeniedCommands {
if matchesPrefix(normalized, deny) {
return &SSHBlockedError{
RemoteCommand: fullRemoteCmd,
Reason: fmt.Sprintf("command %q matches inherited default deny %q", subCmd, deny),
}
}
}
}
}
// Check SSH-specific denied commands
for _, deny := range cfg.SSH.DeniedCommands {
if matchesPrefix(normalized, deny) {
return &SSHBlockedError{
RemoteCommand: fullRemoteCmd,
Reason: fmt.Sprintf("command %q matches ssh.deniedCommands %q", subCmd, deny),
}
}
}
// If allowAllCommands is true, we're in denylist mode - allow anything not denied
if cfg.SSH.AllowAllCommands {
return nil
}
// Allowlist mode: check if command is in allowedCommands
if len(cfg.SSH.AllowedCommands) > 0 {
for _, allow := range cfg.SSH.AllowedCommands {
if matchesPrefix(normalized, allow) {
return nil
}
}
// Not in allowlist
return &SSHBlockedError{
RemoteCommand: fullRemoteCmd,
Reason: fmt.Sprintf("command %q not in ssh.allowedCommands", subCmd),
}
}
// No allowedCommands configured and not in denylist mode = deny all remote commands
return &SSHBlockedError{
RemoteCommand: fullRemoteCmd,
Reason: "no ssh.allowedCommands configured (allowlist mode requires explicit commands)",
}
}
// parseSSHCommand parses an SSH command and extracts the host and remote command.
// Returns (host, remoteCommand, isSSH).
func parseSSHCommand(command string) (string, string, bool) {
command = strings.TrimSpace(command)
if command == "" {
return "", "", false
}
tokens := tokenizeCommand(command)
if len(tokens) == 0 {
return "", "", false
}
cmdName := filepath.Base(tokens[0])
if cmdName != "ssh" {
return "", "", false
}
// Parse SSH arguments to find host and command
// SSH syntax: ssh [options] [user@]hostname [command]
var host string
var remoteCmd string
skipNext := false
for i := 1; i < len(tokens); i++ {
if skipNext {
skipNext = false
continue
}
arg := tokens[i]
// Skip options that take arguments
if arg == "-p" || arg == "-l" || arg == "-i" || arg == "-o" ||
arg == "-F" || arg == "-J" || arg == "-W" || arg == "-b" ||
arg == "-c" || arg == "-D" || arg == "-E" || arg == "-e" ||
arg == "-I" || arg == "-L" || arg == "-m" || arg == "-O" ||
arg == "-Q" || arg == "-R" || arg == "-S" || arg == "-w" {
skipNext = true
continue
}
// Skip single-char options (like -v, -t, -n, etc.)
if strings.HasPrefix(arg, "-") {
continue
}
// First non-option argument is the host
if host == "" {
host = arg
// Extract the hostname from user@host format
if atIdx := strings.LastIndex(host, "@"); atIdx >= 0 {
host = host[atIdx+1:]
}
continue
}
// Remaining arguments form the remote command
remoteCmd = strings.Join(tokens[i:], " ")
break
}
if host == "" {
return "", "", false
}
return host, remoteCmd, true
}

View File

@@ -454,3 +454,373 @@ func TestMatchesPrefix(t *testing.T) {
func boolPtr(b bool) *bool {
return &b
}
func TestParseSSHCommand(t *testing.T) {
tests := []struct {
command string
wantHost string
wantCmd string
wantIsSSH bool
desc string
}{
// Basic SSH commands
{`ssh server1.example.com`, "server1.example.com", "", true, "simple host"},
{`ssh user@server1.example.com`, "server1.example.com", "", true, "user@host"},
{`ssh server1.example.com ls -la`, "server1.example.com", "ls -la", true, "host with command"},
{`ssh user@server1.example.com "cat /var/log/app.log"`, "server1.example.com", `cat /var/log/app.log`, true, "user@host with quoted command"},
// SSH with options
{`ssh -p 2222 server1.example.com`, "server1.example.com", "", true, "with port option"},
{`ssh -i ~/.ssh/key server1.example.com ls`, "server1.example.com", "ls", true, "with identity file"},
{`ssh -v -t server1.example.com`, "server1.example.com", "", true, "with flags"},
{`ssh -o StrictHostKeyChecking=no server1.example.com`, "server1.example.com", "", true, "with -o option"},
// Full path to ssh
{`/usr/bin/ssh server1.example.com ls`, "server1.example.com", "ls", true, "full path ssh"},
// Not SSH commands
{`ls -la`, "", "", false, "not ssh"},
{`sshpass -p password ssh server`, "", "", false, "sshpass wrapper"},
{`echo ssh server`, "", "", false, "ssh as argument"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
host, cmd, isSSH := parseSSHCommand(tt.command)
if isSSH != tt.wantIsSSH {
t.Errorf("parseSSHCommand(%q) isSSH = %v, want %v", tt.command, isSSH, tt.wantIsSSH)
}
if host != tt.wantHost {
t.Errorf("parseSSHCommand(%q) host = %q, want %q", tt.command, host, tt.wantHost)
}
if cmd != tt.wantCmd {
t.Errorf("parseSSHCommand(%q) cmd = %q, want %q", tt.command, cmd, tt.wantCmd)
}
})
}
}
func TestCheckSSHCommand_HostPolicy(t *testing.T) {
cfg := &config.Config{
SSH: config.SSHConfig{
AllowedHosts: []string{"*.example.com", "prod-*"},
DeniedHosts: []string{"prod-db.example.com"},
},
Command: config.CommandConfig{
UseDefaults: boolPtr(false),
},
}
tests := []struct {
command string
shouldBlock bool
desc string
}{
// Allowed hosts
{`ssh server1.example.com`, false, "allowed by wildcard"},
{`ssh api.example.com`, false, "allowed subdomain"},
{`ssh prod-web-01`, false, "allowed by prod-* pattern"},
// Denied hosts
{`ssh prod-db.example.com`, true, "explicitly denied"},
// Not in allowlist
{`ssh other.domain.com`, true, "not in allowedHosts"},
{`ssh dev-server`, true, "not matching any pattern"},
// Non-SSH commands (should pass through)
{`ls -la`, false, "not an SSH command"},
{`curl https://example.com`, false, "not an SSH command"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
err := CheckSSHCommand(tt.command, cfg)
if tt.shouldBlock && err == nil {
t.Errorf("expected SSH command %q to be blocked", tt.command)
}
if !tt.shouldBlock && err != nil {
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
}
})
}
}
func TestCheckSSHCommand_AllowlistMode(t *testing.T) {
cfg := &config.Config{
SSH: config.SSHConfig{
AllowedHosts: []string{"*.example.com"},
AllowedCommands: []string{"ls", "cat", "grep", "tail -f"},
},
Command: config.CommandConfig{
UseDefaults: boolPtr(false),
},
}
tests := []struct {
command string
shouldBlock bool
desc string
}{
// Allowed commands
{`ssh server.example.com ls`, false, "ls allowed"},
{`ssh server.example.com ls -la /var/log`, false, "ls with args"},
{`ssh server.example.com cat /etc/hosts`, false, "cat allowed"},
{`ssh server.example.com grep error /var/log/app.log`, false, "grep allowed"},
{`ssh server.example.com tail -f /var/log/app.log`, false, "tail -f allowed"},
// Not in allowlist
{`ssh server.example.com rm -rf /tmp/cache`, true, "rm not in allowlist"},
{`ssh server.example.com chmod 777 /tmp`, true, "chmod not in allowlist"},
{`ssh server.example.com shutdown now`, true, "shutdown not in allowlist"},
{`ssh server.example.com tail /var/log/app.log`, true, "tail without -f not allowed"},
// Interactive session (no command) - should be allowed
{`ssh server.example.com`, false, "interactive session allowed"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
err := CheckSSHCommand(tt.command, cfg)
if tt.shouldBlock && err == nil {
t.Errorf("expected SSH command %q to be blocked", tt.command)
}
if !tt.shouldBlock && err != nil {
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
}
})
}
}
func TestCheckSSHCommand_DenylistMode(t *testing.T) {
cfg := &config.Config{
SSH: config.SSHConfig{
AllowedHosts: []string{"*.example.com"},
AllowAllCommands: true, // denylist mode
DeniedCommands: []string{"rm -rf", "shutdown", "chmod"},
},
Command: config.CommandConfig{
UseDefaults: boolPtr(false),
},
}
tests := []struct {
command string
shouldBlock bool
desc string
}{
// Allowed (not in denylist)
{`ssh server.example.com ls -la`, false, "ls allowed"},
{`ssh server.example.com cat /etc/hosts`, false, "cat allowed"},
{`ssh server.example.com rm file.txt`, false, "rm single file allowed"},
{`ssh server.example.com apt-get update`, false, "apt-get allowed"},
// Denied
{`ssh server.example.com rm -rf /tmp/cache`, true, "rm -rf denied"},
{`ssh server.example.com shutdown now`, true, "shutdown denied"},
{`ssh server.example.com chmod 777 /tmp`, true, "chmod denied"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
err := CheckSSHCommand(tt.command, cfg)
if tt.shouldBlock && err == nil {
t.Errorf("expected SSH command %q to be blocked", tt.command)
}
if !tt.shouldBlock && err != nil {
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
}
})
}
}
func TestCheckSSHCommand_InheritDeny(t *testing.T) {
cfg := &config.Config{
SSH: config.SSHConfig{
AllowedHosts: []string{"*.example.com"},
AllowAllCommands: true, // denylist mode
InheritDeny: true, // inherit global denies
},
Command: config.CommandConfig{
Deny: []string{"git push", "npm publish"},
UseDefaults: boolPtr(true), // include default denies like shutdown
},
}
tests := []struct {
command string
shouldBlock bool
desc string
}{
// Inherited from global deny
{`ssh server.example.com git push origin main`, true, "git push from global deny"},
{`ssh server.example.com npm publish`, true, "npm publish from global deny"},
// Inherited from default deny list
{`ssh server.example.com shutdown now`, true, "shutdown from default deny"},
{`ssh server.example.com reboot`, true, "reboot from default deny"},
// Allowed (not in any deny list)
{`ssh server.example.com ls -la`, false, "ls allowed"},
{`ssh server.example.com git status`, false, "git status allowed"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
err := CheckSSHCommand(tt.command, cfg)
if tt.shouldBlock && err == nil {
t.Errorf("expected SSH command %q to be blocked", tt.command)
}
if !tt.shouldBlock && err != nil {
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
}
})
}
}
func TestCheckSSHCommand_NoSSHConfig(t *testing.T) {
// No SSH policy configured - all SSH commands should pass through
cfg := &config.Config{
Command: config.CommandConfig{
UseDefaults: boolPtr(false),
},
}
tests := []struct {
command string
desc string
}{
{`ssh server.example.com rm -rf /`, "dangerous command allowed when no SSH policy"},
{`ssh any-host.com shutdown`, "any host allowed"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
err := CheckSSHCommand(tt.command, cfg)
if err != nil {
t.Errorf("expected SSH command %q to be allowed (no SSH policy), got: %v", tt.command, err)
}
})
}
}
func TestCheckCommand_IntegratesSSH(t *testing.T) {
// Test that CheckCommand also checks SSH policies
cfg := &config.Config{
SSH: config.SSHConfig{
AllowedHosts: []string{"*.example.com"},
AllowedCommands: []string{"ls", "cat"},
},
Command: config.CommandConfig{
UseDefaults: boolPtr(false),
},
}
tests := []struct {
command string
shouldBlock bool
desc string
}{
// Via CheckCommand, SSH policy should be enforced
{`ssh server.example.com ls`, false, "allowed SSH command"},
{`ssh server.example.com rm -rf /`, true, "blocked SSH command"},
{`ssh other.com ls`, true, "blocked host"},
// Non-SSH commands unaffected
{`ls -la`, false, "local ls allowed"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
err := CheckCommand(tt.command, cfg)
if tt.shouldBlock && err == nil {
t.Errorf("expected command %q to be blocked", tt.command)
}
if !tt.shouldBlock && err != nil {
t.Errorf("expected command %q to be allowed, got: %v", tt.command, err)
}
})
}
}
func TestCheckSSHCommand_CommandChaining(t *testing.T) {
// Test that command chaining doesn't bypass allow/deny rules
cfg := &config.Config{
SSH: config.SSHConfig{
AllowedHosts: []string{"*.example.com"},
AllowedCommands: []string{"ls", "cat", "git status"},
},
Command: config.CommandConfig{
UseDefaults: boolPtr(false),
},
}
tests := []struct {
command string
shouldBlock bool
desc string
}{
// Chaining should NOT bypass allowlist
{`ssh server.example.com "ls && rm -rf /"`, true, "ls allowed but rm -rf not"},
{`ssh server.example.com "git status && rm -rf /"`, true, "git status allowed but rm -rf not"},
{`ssh server.example.com "cat file; shutdown"`, true, "cat allowed but shutdown not"},
{`ssh server.example.com "ls | xargs rm"`, true, "ls allowed but rm not"},
{`ssh server.example.com "ls || rm -rf /"`, true, "ls allowed but rm -rf not"},
// All subcommands allowed should work
{`ssh server.example.com "ls && cat file"`, false, "both ls and cat allowed"},
{`ssh server.example.com "ls; cat file"`, false, "semicolon chain with allowed commands"},
{`ssh server.example.com "ls | cat"`, false, "pipe with allowed commands"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
err := CheckSSHCommand(tt.command, cfg)
if tt.shouldBlock && err == nil {
t.Errorf("expected SSH command %q to be blocked", tt.command)
}
if !tt.shouldBlock && err != nil {
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
}
})
}
}
func TestCheckSSHCommand_CommandChainingDenylist(t *testing.T) {
// Test command chaining in denylist mode
cfg := &config.Config{
SSH: config.SSHConfig{
AllowedHosts: []string{"*.example.com"},
AllowAllCommands: true,
DeniedCommands: []string{"rm -rf", "shutdown"},
},
Command: config.CommandConfig{
UseDefaults: boolPtr(false),
},
}
tests := []struct {
command string
shouldBlock bool
desc string
}{
// Chaining should still catch denied commands
{`ssh server.example.com "ls && rm -rf /"`, true, "rm -rf in chain blocked"},
{`ssh server.example.com "cat file; shutdown"`, true, "shutdown in chain blocked"},
// Chains without denied commands should work
{`ssh server.example.com "ls && cat && grep foo"`, false, "chain without denied commands"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
err := CheckSSHCommand(tt.command, cfg)
if tt.shouldBlock && err == nil {
t.Errorf("expected SSH command %q to be blocked", tt.command)
}
if !tt.shouldBlock && err != nil {
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
}
})
}
}