feat: support ssh commands (#10)
This commit is contained in:
@@ -82,6 +82,11 @@ func checkSingleCommand(command string, cfg *config.Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check SSH-specific policies if this is an SSH command
|
||||
if err := CheckSSHCommand(command, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -298,3 +303,222 @@ func matchesPrefix(command, prefix string) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SSHBlockedError is returned when an SSH command is blocked by policy.
|
||||
type SSHBlockedError struct {
|
||||
Host string
|
||||
RemoteCommand string
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e *SSHBlockedError) Error() string {
|
||||
if e.RemoteCommand != "" {
|
||||
return fmt.Sprintf("SSH command blocked: %s (host: %s, command: %s)", e.Reason, e.Host, e.RemoteCommand)
|
||||
}
|
||||
return fmt.Sprintf("SSH blocked: %s (host: %s)", e.Reason, e.Host)
|
||||
}
|
||||
|
||||
// CheckSSHCommand checks if an SSH command is allowed by the configuration.
|
||||
// Returns nil if allowed, or SSHBlockedError if blocked.
|
||||
func CheckSSHCommand(command string, cfg *config.Config) error {
|
||||
if cfg == nil {
|
||||
cfg = config.Default()
|
||||
}
|
||||
|
||||
// Check if SSH config is active (has any hosts configured)
|
||||
// If no SSH policy is configured, allow by default
|
||||
if len(cfg.SSH.AllowedHosts) == 0 && len(cfg.SSH.DeniedHosts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
host, remoteCmd, isSSH := parseSSHCommand(command)
|
||||
if !isSSH {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check host policy (denied then allowed)
|
||||
for _, pattern := range cfg.SSH.DeniedHosts {
|
||||
if config.MatchesHost(host, pattern) {
|
||||
return &SSHBlockedError{
|
||||
Host: host,
|
||||
RemoteCommand: remoteCmd,
|
||||
Reason: fmt.Sprintf("host matches denied pattern %q", pattern),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hostAllowed := false
|
||||
for _, pattern := range cfg.SSH.AllowedHosts {
|
||||
if config.MatchesHost(host, pattern) {
|
||||
hostAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(cfg.SSH.AllowedHosts) > 0 && !hostAllowed {
|
||||
return &SSHBlockedError{
|
||||
Host: host,
|
||||
RemoteCommand: remoteCmd,
|
||||
Reason: "host not in allowedHosts",
|
||||
}
|
||||
}
|
||||
|
||||
// If no remote command (interactive session), allow if host is allowed
|
||||
if remoteCmd == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return checkSSHRemoteCommand(remoteCmd, cfg)
|
||||
}
|
||||
|
||||
// checkSSHRemoteCommand checks if a remote command is allowed by SSH policy.
|
||||
// It parses the remote command into subcommands (handling &&, ||, ;, |) and validates each.
|
||||
func checkSSHRemoteCommand(remoteCmd string, cfg *config.Config) error {
|
||||
// Parse into subcommands just like local commands to prevent bypass via chaining
|
||||
// e.g., "git status && rm -rf /" should check both "git status" and "rm -rf /"
|
||||
subCommands := parseShellCommand(remoteCmd)
|
||||
|
||||
for _, subCmd := range subCommands {
|
||||
if err := checkSSHSingleCommand(subCmd, remoteCmd, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSSHSingleCommand checks a single SSH remote command against policy.
|
||||
func checkSSHSingleCommand(subCmd, fullRemoteCmd string, cfg *config.Config) error {
|
||||
normalized := normalizeCommand(subCmd)
|
||||
if normalized == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check inherited global deny list first (if enabled)
|
||||
// User-defined global then default deny list
|
||||
if cfg.SSH.InheritDeny {
|
||||
for _, deny := range cfg.Command.Deny {
|
||||
if matchesPrefix(normalized, deny) {
|
||||
return &SSHBlockedError{
|
||||
RemoteCommand: fullRemoteCmd,
|
||||
Reason: fmt.Sprintf("command %q matches inherited global deny %q", subCmd, deny),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Command.UseDefaultDeniedCommands() {
|
||||
for _, deny := range config.DefaultDeniedCommands {
|
||||
if matchesPrefix(normalized, deny) {
|
||||
return &SSHBlockedError{
|
||||
RemoteCommand: fullRemoteCmd,
|
||||
Reason: fmt.Sprintf("command %q matches inherited default deny %q", subCmd, deny),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check SSH-specific denied commands
|
||||
for _, deny := range cfg.SSH.DeniedCommands {
|
||||
if matchesPrefix(normalized, deny) {
|
||||
return &SSHBlockedError{
|
||||
RemoteCommand: fullRemoteCmd,
|
||||
Reason: fmt.Sprintf("command %q matches ssh.deniedCommands %q", subCmd, deny),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If allowAllCommands is true, we're in denylist mode - allow anything not denied
|
||||
if cfg.SSH.AllowAllCommands {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Allowlist mode: check if command is in allowedCommands
|
||||
if len(cfg.SSH.AllowedCommands) > 0 {
|
||||
for _, allow := range cfg.SSH.AllowedCommands {
|
||||
if matchesPrefix(normalized, allow) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Not in allowlist
|
||||
return &SSHBlockedError{
|
||||
RemoteCommand: fullRemoteCmd,
|
||||
Reason: fmt.Sprintf("command %q not in ssh.allowedCommands", subCmd),
|
||||
}
|
||||
}
|
||||
|
||||
// No allowedCommands configured and not in denylist mode = deny all remote commands
|
||||
return &SSHBlockedError{
|
||||
RemoteCommand: fullRemoteCmd,
|
||||
Reason: "no ssh.allowedCommands configured (allowlist mode requires explicit commands)",
|
||||
}
|
||||
}
|
||||
|
||||
// parseSSHCommand parses an SSH command and extracts the host and remote command.
|
||||
// Returns (host, remoteCommand, isSSH).
|
||||
func parseSSHCommand(command string) (string, string, bool) {
|
||||
command = strings.TrimSpace(command)
|
||||
if command == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
tokens := tokenizeCommand(command)
|
||||
if len(tokens) == 0 {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
cmdName := filepath.Base(tokens[0])
|
||||
if cmdName != "ssh" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Parse SSH arguments to find host and command
|
||||
// SSH syntax: ssh [options] [user@]hostname [command]
|
||||
var host string
|
||||
var remoteCmd string
|
||||
skipNext := false
|
||||
|
||||
for i := 1; i < len(tokens); i++ {
|
||||
if skipNext {
|
||||
skipNext = false
|
||||
continue
|
||||
}
|
||||
|
||||
arg := tokens[i]
|
||||
|
||||
// Skip options that take arguments
|
||||
if arg == "-p" || arg == "-l" || arg == "-i" || arg == "-o" ||
|
||||
arg == "-F" || arg == "-J" || arg == "-W" || arg == "-b" ||
|
||||
arg == "-c" || arg == "-D" || arg == "-E" || arg == "-e" ||
|
||||
arg == "-I" || arg == "-L" || arg == "-m" || arg == "-O" ||
|
||||
arg == "-Q" || arg == "-R" || arg == "-S" || arg == "-w" {
|
||||
skipNext = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip single-char options (like -v, -t, -n, etc.)
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
continue
|
||||
}
|
||||
|
||||
// First non-option argument is the host
|
||||
if host == "" {
|
||||
host = arg
|
||||
// Extract the hostname from user@host format
|
||||
if atIdx := strings.LastIndex(host, "@"); atIdx >= 0 {
|
||||
host = host[atIdx+1:]
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Remaining arguments form the remote command
|
||||
remoteCmd = strings.Join(tokens[i:], " ")
|
||||
break
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
return host, remoteCmd, true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user