feat: support ssh commands (#10)
This commit is contained in:
@@ -14,6 +14,7 @@ You can also think of Fence as a permission manager for your CLI coding agents.
|
|||||||
- **Domain Allowlisting**: Configure which domains are allowed
|
- **Domain Allowlisting**: Configure which domains are allowed
|
||||||
- **Filesystem Restrictions**: Control read/write access to paths
|
- **Filesystem Restrictions**: Control read/write access to paths
|
||||||
- **Command Blocking**: Block dangerous commands (e.g., `shutdown`, `rm -rf`) with configurable deny/allow lists
|
- **Command Blocking**: Block dangerous commands (e.g., `shutdown`, `rm -rf`) with configurable deny/allow lists
|
||||||
|
- **SSH Command Filtering**: Control which hosts and commands are allowed over SSH
|
||||||
- **Violation Monitoring**: Real-time logging of blocked requests and sandbox denials
|
- **Violation Monitoring**: Real-time logging of blocked requests and sandbox denials
|
||||||
- **Cross-Platform**: macOS (sandbox-exec) and Linux (bubblewrap)
|
- **Cross-Platform**: macOS (sandbox-exec) and Linux (bubblewrap)
|
||||||
- **HTTP/SOCKS5 Proxies**: Built-in filtering proxies for domain control
|
- **HTTP/SOCKS5 Proxies**: Built-in filtering proxies for domain control
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ Example config:
|
|||||||
},
|
},
|
||||||
"command": {
|
"command": {
|
||||||
"deny": ["git push", "npm publish"]
|
"deny": ["git push", "npm publish"]
|
||||||
|
},
|
||||||
|
"ssh": {
|
||||||
|
"allowedHosts": ["*.example.com"],
|
||||||
|
"allowedCommands": ["ls", "cat", "grep", "tail", "head"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
@@ -158,6 +162,96 @@ Fence detects blocked commands in:
|
|||||||
- Pipelines: `echo test | git push`
|
- Pipelines: `echo test | git push`
|
||||||
- Shell invocations: `bash -c "git push"` or `sh -lc "ls && git push"`
|
- Shell invocations: `bash -c "git push"` or `sh -lc "ls && git push"`
|
||||||
|
|
||||||
|
## SSH Configuration
|
||||||
|
|
||||||
|
Control which SSH commands are allowed. By default, SSH uses **allowlist mode** for security - only explicitly allowed hosts and commands can be used.
|
||||||
|
|
||||||
|
| Field | Description |
|
||||||
|
|-------|-------------|
|
||||||
|
| `allowedHosts` | Host patterns to allow SSH connections to (supports wildcards like `*.example.com`, `prod-*`) |
|
||||||
|
| `deniedHosts` | Host patterns to deny SSH connections to (checked before allowed) |
|
||||||
|
| `allowedCommands` | Commands allowed over SSH (allowlist mode) |
|
||||||
|
| `deniedCommands` | Commands denied over SSH (checked before allowed) |
|
||||||
|
| `allowAllCommands` | If `true`, use denylist mode instead of allowlist (allow all commands except denied) |
|
||||||
|
| `inheritDeny` | If `true`, also apply global `command.deny` rules to SSH commands |
|
||||||
|
|
||||||
|
### Basic Example (Allowlist Mode)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"ssh": {
|
||||||
|
"allowedHosts": ["*.example.com"],
|
||||||
|
"allowedCommands": ["ls", "cat", "grep", "tail", "head", "find"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This allows:
|
||||||
|
|
||||||
|
- SSH to any `*.example.com` host
|
||||||
|
- Only the listed commands (and their arguments)
|
||||||
|
- Interactive sessions (no remote command)
|
||||||
|
|
||||||
|
### Denylist Mode Example
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"ssh": {
|
||||||
|
"allowedHosts": ["dev-*.example.com"],
|
||||||
|
"allowAllCommands": true,
|
||||||
|
"deniedCommands": ["rm -rf", "shutdown", "chmod"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This allows:
|
||||||
|
|
||||||
|
- SSH to any `dev-*.example.com` host
|
||||||
|
- Any command except the denied ones
|
||||||
|
|
||||||
|
### Inheriting Global Denies
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"command": {
|
||||||
|
"deny": ["shutdown", "reboot", "rm -rf /"]
|
||||||
|
},
|
||||||
|
"ssh": {
|
||||||
|
"allowedHosts": ["*.example.com"],
|
||||||
|
"allowAllCommands": true,
|
||||||
|
"inheritDeny": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
With `inheritDeny: true`, SSH commands also check against:
|
||||||
|
|
||||||
|
- Global `command.deny` list
|
||||||
|
- Default denied commands (if `command.useDefaults` is true)
|
||||||
|
|
||||||
|
### Host Pattern Matching
|
||||||
|
|
||||||
|
SSH host patterns support wildcards anywhere:
|
||||||
|
|
||||||
|
| Pattern | Matches |
|
||||||
|
|---------|---------|
|
||||||
|
| `server1.example.com` | Exact match only |
|
||||||
|
| `*.example.com` | Any subdomain of example.com |
|
||||||
|
| `prod-*` | Any hostname starting with `prod-` |
|
||||||
|
| `prod-*.us-east.*` | Multiple wildcards |
|
||||||
|
| `*` | All hosts |
|
||||||
|
|
||||||
|
### Evaluation Order
|
||||||
|
|
||||||
|
1. Check if host matches `deniedHosts` → **DENY**
|
||||||
|
2. Check if host matches `allowedHosts` → continue (else **DENY**)
|
||||||
|
3. If no remote command (interactive session) → **ALLOW**
|
||||||
|
4. Check if command matches `deniedCommands` → **DENY**
|
||||||
|
5. If `inheritDeny`, check global `command.deny` → **DENY**
|
||||||
|
6. If `allowAllCommands` → **ALLOW**
|
||||||
|
7. Check if command matches `allowedCommands` → **ALLOW**
|
||||||
|
8. Default → **DENY**
|
||||||
|
|
||||||
## Other Options
|
## Other Options
|
||||||
|
|
||||||
| Field | Description |
|
| Field | Description |
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ type Config struct {
|
|||||||
Network NetworkConfig `json:"network"`
|
Network NetworkConfig `json:"network"`
|
||||||
Filesystem FilesystemConfig `json:"filesystem"`
|
Filesystem FilesystemConfig `json:"filesystem"`
|
||||||
Command CommandConfig `json:"command"`
|
Command CommandConfig `json:"command"`
|
||||||
|
SSH SSHConfig `json:"ssh"`
|
||||||
AllowPty bool `json:"allowPty,omitempty"`
|
AllowPty bool `json:"allowPty,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,6 +50,17 @@ type CommandConfig struct {
|
|||||||
UseDefaults *bool `json:"useDefaults,omitempty"`
|
UseDefaults *bool `json:"useDefaults,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSHConfig defines SSH command restrictions.
|
||||||
|
// SSH commands are filtered using an allowlist by default for security.
|
||||||
|
type SSHConfig struct {
|
||||||
|
AllowedHosts []string `json:"allowedHosts"` // Host patterns to allow SSH to (supports wildcards like *.example.com)
|
||||||
|
DeniedHosts []string `json:"deniedHosts"` // Host patterns to deny SSH to (checked before allowed)
|
||||||
|
AllowedCommands []string `json:"allowedCommands"` // Commands allowed over SSH (allowlist mode)
|
||||||
|
DeniedCommands []string `json:"deniedCommands"` // Commands denied over SSH (checked before allowed)
|
||||||
|
AllowAllCommands bool `json:"allowAllCommands,omitempty"` // If true, use denylist mode instead of allowlist
|
||||||
|
InheritDeny bool `json:"inheritDeny,omitempty"` // If true, also apply global command.deny rules
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultDeniedCommands returns commands that are blocked by default.
|
// DefaultDeniedCommands returns commands that are blocked by default.
|
||||||
// These are system-level dangerous commands that are rarely needed by AI agents.
|
// These are system-level dangerous commands that are rarely needed by AI agents.
|
||||||
var DefaultDeniedCommands = []string{
|
var DefaultDeniedCommands = []string{
|
||||||
@@ -109,6 +121,12 @@ func Default() *Config {
|
|||||||
Allow: []string{},
|
Allow: []string{},
|
||||||
// UseDefaults defaults to true (nil = true)
|
// UseDefaults defaults to true (nil = true)
|
||||||
},
|
},
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowedHosts: []string{},
|
||||||
|
DeniedHosts: []string{},
|
||||||
|
AllowedCommands: []string{},
|
||||||
|
DeniedCommands: []string{},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,6 +196,24 @@ func (c *Config) Validate() error {
|
|||||||
return errors.New("command.allow contains empty command")
|
return errors.New("command.allow contains empty command")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSH config
|
||||||
|
for _, host := range c.SSH.AllowedHosts {
|
||||||
|
if err := validateHostPattern(host); err != nil {
|
||||||
|
return fmt.Errorf("invalid ssh.allowedHosts %q: %w", host, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, host := range c.SSH.DeniedHosts {
|
||||||
|
if err := validateHostPattern(host); err != nil {
|
||||||
|
return fmt.Errorf("invalid ssh.deniedHosts %q: %w", host, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if slices.Contains(c.SSH.AllowedCommands, "") {
|
||||||
|
return errors.New("ssh.allowedCommands contains empty command")
|
||||||
|
}
|
||||||
|
if slices.Contains(c.SSH.DeniedCommands, "") {
|
||||||
|
return errors.New("ssh.deniedCommands contains empty command")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,6 +265,42 @@ func validateDomainPattern(pattern string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateHostPattern validates an SSH host pattern.
|
||||||
|
// Host patterns are more permissive than domain patterns:
|
||||||
|
// - Can contain wildcards anywhere (e.g., prod-*.example.com, *.example.com)
|
||||||
|
// - Can be IP addresses
|
||||||
|
// - Can be simple hostnames without dots
|
||||||
|
func validateHostPattern(pattern string) error {
|
||||||
|
if pattern == "" {
|
||||||
|
return errors.New("empty host pattern")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject patterns with protocol or path
|
||||||
|
if strings.Contains(pattern, "://") || strings.Contains(pattern, "/") {
|
||||||
|
return errors.New("host pattern cannot contain protocol or path")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject patterns with port (user@host:port style)
|
||||||
|
// But allow colons for IPv6 addresses
|
||||||
|
if strings.Contains(pattern, ":") && !strings.Contains(pattern, "::") && !isIPv6Pattern(pattern) {
|
||||||
|
return errors.New("host pattern cannot contain port; specify port in SSH command instead")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject patterns with @ (should be just the host, not user@host)
|
||||||
|
if strings.Contains(pattern, "@") {
|
||||||
|
return errors.New("host pattern should not contain username; specify just the host")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIPv6Pattern checks if a pattern looks like an IPv6 address.
|
||||||
|
func isIPv6Pattern(pattern string) bool {
|
||||||
|
// IPv6 addresses contain multiple colons
|
||||||
|
colonCount := strings.Count(pattern, ":")
|
||||||
|
return colonCount >= 2
|
||||||
|
}
|
||||||
|
|
||||||
// MatchesDomain checks if a hostname matches a domain pattern.
|
// MatchesDomain checks if a hostname matches a domain pattern.
|
||||||
func MatchesDomain(hostname, pattern string) bool {
|
func MatchesDomain(hostname, pattern string) bool {
|
||||||
hostname = strings.ToLower(hostname)
|
hostname = strings.ToLower(hostname)
|
||||||
@@ -249,6 +321,71 @@ func MatchesDomain(hostname, pattern string) bool {
|
|||||||
return hostname == pattern
|
return hostname == pattern
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MatchesHost checks if a hostname matches an SSH host pattern.
|
||||||
|
// SSH host patterns support wildcards anywhere in the pattern.
|
||||||
|
func MatchesHost(hostname, pattern string) bool {
|
||||||
|
hostname = strings.ToLower(hostname)
|
||||||
|
pattern = strings.ToLower(pattern)
|
||||||
|
|
||||||
|
// "*" matches all hosts
|
||||||
|
if pattern == "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If pattern contains no wildcards, do exact match
|
||||||
|
if !strings.Contains(pattern, "*") {
|
||||||
|
return hostname == pattern
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert glob pattern to a simple matcher
|
||||||
|
// Split pattern by * and check each part
|
||||||
|
return matchGlob(hostname, pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchGlob performs simple glob matching with * wildcards.
|
||||||
|
func matchGlob(s, pattern string) bool {
|
||||||
|
// Handle edge cases
|
||||||
|
if pattern == "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if pattern == "" {
|
||||||
|
return s == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split pattern by * and match parts
|
||||||
|
parts := strings.Split(pattern, "*")
|
||||||
|
|
||||||
|
// Check prefix (before first *)
|
||||||
|
if !strings.HasPrefix(s, parts[0]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s = s[len(parts[0]):]
|
||||||
|
|
||||||
|
// Check suffix (after last *)
|
||||||
|
if len(parts) > 1 {
|
||||||
|
last := parts[len(parts)-1]
|
||||||
|
if !strings.HasSuffix(s, last) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s = s[:len(s)-len(last)]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check middle parts (between *s)
|
||||||
|
for i := 1; i < len(parts)-1; i++ {
|
||||||
|
part := parts[i]
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idx := strings.Index(s, part)
|
||||||
|
if idx < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s = s[idx+len(part):]
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// Merge combines a base config with an override config.
|
// Merge combines a base config with an override config.
|
||||||
// Values in override take precedence. Slice fields are appended (base + override).
|
// Values in override take precedence. Slice fields are appended (base + override).
|
||||||
// The Extends field is cleared in the result since inheritance has been resolved.
|
// The Extends field is cleared in the result since inheritance has been resolved.
|
||||||
@@ -307,6 +444,18 @@ func Merge(base, override *Config) *Config {
|
|||||||
// Pointer field: override wins if set
|
// Pointer field: override wins if set
|
||||||
UseDefaults: mergeOptionalBool(base.Command.UseDefaults, override.Command.UseDefaults),
|
UseDefaults: mergeOptionalBool(base.Command.UseDefaults, override.Command.UseDefaults),
|
||||||
},
|
},
|
||||||
|
|
||||||
|
SSH: SSHConfig{
|
||||||
|
// Append slices
|
||||||
|
AllowedHosts: mergeStrings(base.SSH.AllowedHosts, override.SSH.AllowedHosts),
|
||||||
|
DeniedHosts: mergeStrings(base.SSH.DeniedHosts, override.SSH.DeniedHosts),
|
||||||
|
AllowedCommands: mergeStrings(base.SSH.AllowedCommands, override.SSH.AllowedCommands),
|
||||||
|
DeniedCommands: mergeStrings(base.SSH.DeniedCommands, override.SSH.DeniedCommands),
|
||||||
|
|
||||||
|
// Boolean fields: true if either enables it
|
||||||
|
AllowAllCommands: base.SSH.AllowAllCommands || override.SSH.AllowAllCommands,
|
||||||
|
InheritDeny: base.SSH.InheritDeny || override.SSH.InheritDeny,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -480,3 +480,210 @@ func TestMerge(t *testing.T) {
|
|||||||
func boolPtr(b bool) *bool {
|
func boolPtr(b bool) *bool {
|
||||||
return &b
|
return &b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateHostPattern(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pattern string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
// Valid patterns
|
||||||
|
{"simple hostname", "server1", false},
|
||||||
|
{"domain", "example.com", false},
|
||||||
|
{"subdomain", "prod.example.com", false},
|
||||||
|
{"wildcard prefix", "*.example.com", false},
|
||||||
|
{"wildcard middle", "prod-*.example.com", false},
|
||||||
|
{"ip address", "192.168.1.1", false},
|
||||||
|
{"ipv6 address", "::1", false},
|
||||||
|
{"ipv6 full", "2001:db8::1", false},
|
||||||
|
{"localhost", "localhost", false},
|
||||||
|
|
||||||
|
// Invalid patterns
|
||||||
|
{"empty", "", true},
|
||||||
|
{"with protocol", "ssh://example.com", true},
|
||||||
|
{"with path", "example.com/path", true},
|
||||||
|
{"with port", "example.com:22", true},
|
||||||
|
{"with username", "user@example.com", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validateHostPattern(tt.pattern)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("validateHostPattern(%q) error = %v, wantErr %v", tt.pattern, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchesHost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hostname string
|
||||||
|
pattern string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
// Exact matches
|
||||||
|
{"exact match", "server1.example.com", "server1.example.com", true},
|
||||||
|
{"exact match case insensitive", "Server1.Example.COM", "server1.example.com", true},
|
||||||
|
{"exact no match", "server2.example.com", "server1.example.com", false},
|
||||||
|
|
||||||
|
// Wildcard matches
|
||||||
|
{"wildcard prefix", "api.example.com", "*.example.com", true},
|
||||||
|
{"wildcard prefix deep", "deep.api.example.com", "*.example.com", true},
|
||||||
|
{"wildcard no match base", "example.com", "*.example.com", false},
|
||||||
|
{"wildcard middle", "prod-web-01.example.com", "prod-*.example.com", true},
|
||||||
|
{"wildcard middle no match", "dev-web-01.example.com", "prod-*.example.com", false},
|
||||||
|
{"wildcard suffix", "server1.prod", "server1.*", true},
|
||||||
|
{"multiple wildcards", "prod-web-01.us-east.example.com", "prod-*-*.example.com", true},
|
||||||
|
|
||||||
|
// Star matches all
|
||||||
|
{"star matches all", "anything.example.com", "*", true},
|
||||||
|
|
||||||
|
// IP addresses
|
||||||
|
{"ip exact match", "192.168.1.1", "192.168.1.1", true},
|
||||||
|
{"ip no match", "192.168.1.2", "192.168.1.1", false},
|
||||||
|
{"ip wildcard", "192.168.1.100", "192.168.1.*", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MatchesHost(tt.hostname, tt.pattern)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("MatchesHost(%q, %q) = %v, want %v", tt.hostname, tt.pattern, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHConfigValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid SSH config",
|
||||||
|
config: Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowedHosts: []string{"*.example.com", "prod-*.internal"},
|
||||||
|
AllowedCommands: []string{"ls", "cat", "grep"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid allowed host with protocol",
|
||||||
|
config: Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowedHosts: []string{"ssh://example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid denied host with username",
|
||||||
|
config: Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
DeniedHosts: []string{"user@example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty allowed command",
|
||||||
|
config: Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowedHosts: []string{"example.com"},
|
||||||
|
AllowedCommands: []string{"ls", ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty denied command",
|
||||||
|
config: Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowedHosts: []string{"example.com"},
|
||||||
|
DeniedCommands: []string{""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.config.Validate()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeSSHConfig(t *testing.T) {
|
||||||
|
t.Run("merge SSH allowed hosts", func(t *testing.T) {
|
||||||
|
base := &Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowedHosts: []string{"prod-*.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
override := &Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowedHosts: []string{"dev-*.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := Merge(base, override)
|
||||||
|
|
||||||
|
if len(result.SSH.AllowedHosts) != 2 {
|
||||||
|
t.Errorf("expected 2 allowed hosts, got %d: %v", len(result.SSH.AllowedHosts), result.SSH.AllowedHosts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("merge SSH commands", func(t *testing.T) {
|
||||||
|
base := &Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowedCommands: []string{"ls", "cat"},
|
||||||
|
DeniedCommands: []string{"rm -rf"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
override := &Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowedCommands: []string{"grep", "find"},
|
||||||
|
DeniedCommands: []string{"shutdown"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := Merge(base, override)
|
||||||
|
|
||||||
|
if len(result.SSH.AllowedCommands) != 4 {
|
||||||
|
t.Errorf("expected 4 allowed commands, got %d", len(result.SSH.AllowedCommands))
|
||||||
|
}
|
||||||
|
if len(result.SSH.DeniedCommands) != 2 {
|
||||||
|
t.Errorf("expected 2 denied commands, got %d", len(result.SSH.DeniedCommands))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("merge SSH boolean flags", func(t *testing.T) {
|
||||||
|
base := &Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowAllCommands: false,
|
||||||
|
InheritDeny: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
override := &Config{
|
||||||
|
SSH: SSHConfig{
|
||||||
|
AllowAllCommands: true,
|
||||||
|
InheritDeny: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := Merge(base, override)
|
||||||
|
|
||||||
|
if !result.SSH.AllowAllCommands {
|
||||||
|
t.Error("expected AllowAllCommands to be true (OR logic)")
|
||||||
|
}
|
||||||
|
if !result.SSH.InheritDeny {
|
||||||
|
t.Error("expected InheritDeny to be true (OR logic)")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -82,6 +82,11 @@ func checkSingleCommand(command string, cfg *config.Config) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check SSH-specific policies if this is an SSH command
|
||||||
|
if err := CheckSSHCommand(command, cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,3 +303,222 @@ func matchesPrefix(command, prefix string) bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSHBlockedError is returned when an SSH command is blocked by policy.
|
||||||
|
type SSHBlockedError struct {
|
||||||
|
Host string
|
||||||
|
RemoteCommand string
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SSHBlockedError) Error() string {
|
||||||
|
if e.RemoteCommand != "" {
|
||||||
|
return fmt.Sprintf("SSH command blocked: %s (host: %s, command: %s)", e.Reason, e.Host, e.RemoteCommand)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("SSH blocked: %s (host: %s)", e.Reason, e.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckSSHCommand checks if an SSH command is allowed by the configuration.
|
||||||
|
// Returns nil if allowed, or SSHBlockedError if blocked.
|
||||||
|
func CheckSSHCommand(command string, cfg *config.Config) error {
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = config.Default()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if SSH config is active (has any hosts configured)
|
||||||
|
// If no SSH policy is configured, allow by default
|
||||||
|
if len(cfg.SSH.AllowedHosts) == 0 && len(cfg.SSH.DeniedHosts) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
host, remoteCmd, isSSH := parseSSHCommand(command)
|
||||||
|
if !isSSH {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check host policy (denied then allowed)
|
||||||
|
for _, pattern := range cfg.SSH.DeniedHosts {
|
||||||
|
if config.MatchesHost(host, pattern) {
|
||||||
|
return &SSHBlockedError{
|
||||||
|
Host: host,
|
||||||
|
RemoteCommand: remoteCmd,
|
||||||
|
Reason: fmt.Sprintf("host matches denied pattern %q", pattern),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hostAllowed := false
|
||||||
|
for _, pattern := range cfg.SSH.AllowedHosts {
|
||||||
|
if config.MatchesHost(host, pattern) {
|
||||||
|
hostAllowed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.SSH.AllowedHosts) > 0 && !hostAllowed {
|
||||||
|
return &SSHBlockedError{
|
||||||
|
Host: host,
|
||||||
|
RemoteCommand: remoteCmd,
|
||||||
|
Reason: "host not in allowedHosts",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no remote command (interactive session), allow if host is allowed
|
||||||
|
if remoteCmd == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return checkSSHRemoteCommand(remoteCmd, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSSHRemoteCommand checks if a remote command is allowed by SSH policy.
|
||||||
|
// It parses the remote command into subcommands (handling &&, ||, ;, |) and validates each.
|
||||||
|
func checkSSHRemoteCommand(remoteCmd string, cfg *config.Config) error {
|
||||||
|
// Parse into subcommands just like local commands to prevent bypass via chaining
|
||||||
|
// e.g., "git status && rm -rf /" should check both "git status" and "rm -rf /"
|
||||||
|
subCommands := parseShellCommand(remoteCmd)
|
||||||
|
|
||||||
|
for _, subCmd := range subCommands {
|
||||||
|
if err := checkSSHSingleCommand(subCmd, remoteCmd, cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSSHSingleCommand checks a single SSH remote command against policy.
|
||||||
|
func checkSSHSingleCommand(subCmd, fullRemoteCmd string, cfg *config.Config) error {
|
||||||
|
normalized := normalizeCommand(subCmd)
|
||||||
|
if normalized == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check inherited global deny list first (if enabled)
|
||||||
|
// User-defined global then default deny list
|
||||||
|
if cfg.SSH.InheritDeny {
|
||||||
|
for _, deny := range cfg.Command.Deny {
|
||||||
|
if matchesPrefix(normalized, deny) {
|
||||||
|
return &SSHBlockedError{
|
||||||
|
RemoteCommand: fullRemoteCmd,
|
||||||
|
Reason: fmt.Sprintf("command %q matches inherited global deny %q", subCmd, deny),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Command.UseDefaultDeniedCommands() {
|
||||||
|
for _, deny := range config.DefaultDeniedCommands {
|
||||||
|
if matchesPrefix(normalized, deny) {
|
||||||
|
return &SSHBlockedError{
|
||||||
|
RemoteCommand: fullRemoteCmd,
|
||||||
|
Reason: fmt.Sprintf("command %q matches inherited default deny %q", subCmd, deny),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check SSH-specific denied commands
|
||||||
|
for _, deny := range cfg.SSH.DeniedCommands {
|
||||||
|
if matchesPrefix(normalized, deny) {
|
||||||
|
return &SSHBlockedError{
|
||||||
|
RemoteCommand: fullRemoteCmd,
|
||||||
|
Reason: fmt.Sprintf("command %q matches ssh.deniedCommands %q", subCmd, deny),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If allowAllCommands is true, we're in denylist mode - allow anything not denied
|
||||||
|
if cfg.SSH.AllowAllCommands {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allowlist mode: check if command is in allowedCommands
|
||||||
|
if len(cfg.SSH.AllowedCommands) > 0 {
|
||||||
|
for _, allow := range cfg.SSH.AllowedCommands {
|
||||||
|
if matchesPrefix(normalized, allow) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Not in allowlist
|
||||||
|
return &SSHBlockedError{
|
||||||
|
RemoteCommand: fullRemoteCmd,
|
||||||
|
Reason: fmt.Sprintf("command %q not in ssh.allowedCommands", subCmd),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No allowedCommands configured and not in denylist mode = deny all remote commands
|
||||||
|
return &SSHBlockedError{
|
||||||
|
RemoteCommand: fullRemoteCmd,
|
||||||
|
Reason: "no ssh.allowedCommands configured (allowlist mode requires explicit commands)",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSSHCommand parses an SSH command and extracts the host and remote command.
|
||||||
|
// Returns (host, remoteCommand, isSSH).
|
||||||
|
func parseSSHCommand(command string) (string, string, bool) {
|
||||||
|
command = strings.TrimSpace(command)
|
||||||
|
if command == "" {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := tokenizeCommand(command)
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdName := filepath.Base(tokens[0])
|
||||||
|
if cmdName != "ssh" {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse SSH arguments to find host and command
|
||||||
|
// SSH syntax: ssh [options] [user@]hostname [command]
|
||||||
|
var host string
|
||||||
|
var remoteCmd string
|
||||||
|
skipNext := false
|
||||||
|
|
||||||
|
for i := 1; i < len(tokens); i++ {
|
||||||
|
if skipNext {
|
||||||
|
skipNext = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
arg := tokens[i]
|
||||||
|
|
||||||
|
// Skip options that take arguments
|
||||||
|
if arg == "-p" || arg == "-l" || arg == "-i" || arg == "-o" ||
|
||||||
|
arg == "-F" || arg == "-J" || arg == "-W" || arg == "-b" ||
|
||||||
|
arg == "-c" || arg == "-D" || arg == "-E" || arg == "-e" ||
|
||||||
|
arg == "-I" || arg == "-L" || arg == "-m" || arg == "-O" ||
|
||||||
|
arg == "-Q" || arg == "-R" || arg == "-S" || arg == "-w" {
|
||||||
|
skipNext = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip single-char options (like -v, -t, -n, etc.)
|
||||||
|
if strings.HasPrefix(arg, "-") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// First non-option argument is the host
|
||||||
|
if host == "" {
|
||||||
|
host = arg
|
||||||
|
// Extract the hostname from user@host format
|
||||||
|
if atIdx := strings.LastIndex(host, "@"); atIdx >= 0 {
|
||||||
|
host = host[atIdx+1:]
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remaining arguments form the remote command
|
||||||
|
remoteCmd = strings.Join(tokens[i:], " ")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if host == "" {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return host, remoteCmd, true
|
||||||
|
}
|
||||||
|
|||||||
@@ -454,3 +454,373 @@ func TestMatchesPrefix(t *testing.T) {
|
|||||||
func boolPtr(b bool) *bool {
|
func boolPtr(b bool) *bool {
|
||||||
return &b
|
return &b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseSSHCommand(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
wantHost string
|
||||||
|
wantCmd string
|
||||||
|
wantIsSSH bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
// Basic SSH commands
|
||||||
|
{`ssh server1.example.com`, "server1.example.com", "", true, "simple host"},
|
||||||
|
{`ssh user@server1.example.com`, "server1.example.com", "", true, "user@host"},
|
||||||
|
{`ssh server1.example.com ls -la`, "server1.example.com", "ls -la", true, "host with command"},
|
||||||
|
{`ssh user@server1.example.com "cat /var/log/app.log"`, "server1.example.com", `cat /var/log/app.log`, true, "user@host with quoted command"},
|
||||||
|
|
||||||
|
// SSH with options
|
||||||
|
{`ssh -p 2222 server1.example.com`, "server1.example.com", "", true, "with port option"},
|
||||||
|
{`ssh -i ~/.ssh/key server1.example.com ls`, "server1.example.com", "ls", true, "with identity file"},
|
||||||
|
{`ssh -v -t server1.example.com`, "server1.example.com", "", true, "with flags"},
|
||||||
|
{`ssh -o StrictHostKeyChecking=no server1.example.com`, "server1.example.com", "", true, "with -o option"},
|
||||||
|
|
||||||
|
// Full path to ssh
|
||||||
|
{`/usr/bin/ssh server1.example.com ls`, "server1.example.com", "ls", true, "full path ssh"},
|
||||||
|
|
||||||
|
// Not SSH commands
|
||||||
|
{`ls -la`, "", "", false, "not ssh"},
|
||||||
|
{`sshpass -p password ssh server`, "", "", false, "sshpass wrapper"},
|
||||||
|
{`echo ssh server`, "", "", false, "ssh as argument"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
host, cmd, isSSH := parseSSHCommand(tt.command)
|
||||||
|
if isSSH != tt.wantIsSSH {
|
||||||
|
t.Errorf("parseSSHCommand(%q) isSSH = %v, want %v", tt.command, isSSH, tt.wantIsSSH)
|
||||||
|
}
|
||||||
|
if host != tt.wantHost {
|
||||||
|
t.Errorf("parseSSHCommand(%q) host = %q, want %q", tt.command, host, tt.wantHost)
|
||||||
|
}
|
||||||
|
if cmd != tt.wantCmd {
|
||||||
|
t.Errorf("parseSSHCommand(%q) cmd = %q, want %q", tt.command, cmd, tt.wantCmd)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckSSHCommand_HostPolicy(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
SSH: config.SSHConfig{
|
||||||
|
AllowedHosts: []string{"*.example.com", "prod-*"},
|
||||||
|
DeniedHosts: []string{"prod-db.example.com"},
|
||||||
|
},
|
||||||
|
Command: config.CommandConfig{
|
||||||
|
UseDefaults: boolPtr(false),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
shouldBlock bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
// Allowed hosts
|
||||||
|
{`ssh server1.example.com`, false, "allowed by wildcard"},
|
||||||
|
{`ssh api.example.com`, false, "allowed subdomain"},
|
||||||
|
{`ssh prod-web-01`, false, "allowed by prod-* pattern"},
|
||||||
|
|
||||||
|
// Denied hosts
|
||||||
|
{`ssh prod-db.example.com`, true, "explicitly denied"},
|
||||||
|
|
||||||
|
// Not in allowlist
|
||||||
|
{`ssh other.domain.com`, true, "not in allowedHosts"},
|
||||||
|
{`ssh dev-server`, true, "not matching any pattern"},
|
||||||
|
|
||||||
|
// Non-SSH commands (should pass through)
|
||||||
|
{`ls -la`, false, "not an SSH command"},
|
||||||
|
{`curl https://example.com`, false, "not an SSH command"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
err := CheckSSHCommand(tt.command, cfg)
|
||||||
|
if tt.shouldBlock && err == nil {
|
||||||
|
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||||
|
}
|
||||||
|
if !tt.shouldBlock && err != nil {
|
||||||
|
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckSSHCommand_AllowlistMode(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
SSH: config.SSHConfig{
|
||||||
|
AllowedHosts: []string{"*.example.com"},
|
||||||
|
AllowedCommands: []string{"ls", "cat", "grep", "tail -f"},
|
||||||
|
},
|
||||||
|
Command: config.CommandConfig{
|
||||||
|
UseDefaults: boolPtr(false),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
shouldBlock bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
// Allowed commands
|
||||||
|
{`ssh server.example.com ls`, false, "ls allowed"},
|
||||||
|
{`ssh server.example.com ls -la /var/log`, false, "ls with args"},
|
||||||
|
{`ssh server.example.com cat /etc/hosts`, false, "cat allowed"},
|
||||||
|
{`ssh server.example.com grep error /var/log/app.log`, false, "grep allowed"},
|
||||||
|
{`ssh server.example.com tail -f /var/log/app.log`, false, "tail -f allowed"},
|
||||||
|
|
||||||
|
// Not in allowlist
|
||||||
|
{`ssh server.example.com rm -rf /tmp/cache`, true, "rm not in allowlist"},
|
||||||
|
{`ssh server.example.com chmod 777 /tmp`, true, "chmod not in allowlist"},
|
||||||
|
{`ssh server.example.com shutdown now`, true, "shutdown not in allowlist"},
|
||||||
|
{`ssh server.example.com tail /var/log/app.log`, true, "tail without -f not allowed"},
|
||||||
|
|
||||||
|
// Interactive session (no command) - should be allowed
|
||||||
|
{`ssh server.example.com`, false, "interactive session allowed"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
err := CheckSSHCommand(tt.command, cfg)
|
||||||
|
if tt.shouldBlock && err == nil {
|
||||||
|
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||||
|
}
|
||||||
|
if !tt.shouldBlock && err != nil {
|
||||||
|
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckSSHCommand_DenylistMode(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
SSH: config.SSHConfig{
|
||||||
|
AllowedHosts: []string{"*.example.com"},
|
||||||
|
AllowAllCommands: true, // denylist mode
|
||||||
|
DeniedCommands: []string{"rm -rf", "shutdown", "chmod"},
|
||||||
|
},
|
||||||
|
Command: config.CommandConfig{
|
||||||
|
UseDefaults: boolPtr(false),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
shouldBlock bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
// Allowed (not in denylist)
|
||||||
|
{`ssh server.example.com ls -la`, false, "ls allowed"},
|
||||||
|
{`ssh server.example.com cat /etc/hosts`, false, "cat allowed"},
|
||||||
|
{`ssh server.example.com rm file.txt`, false, "rm single file allowed"},
|
||||||
|
{`ssh server.example.com apt-get update`, false, "apt-get allowed"},
|
||||||
|
|
||||||
|
// Denied
|
||||||
|
{`ssh server.example.com rm -rf /tmp/cache`, true, "rm -rf denied"},
|
||||||
|
{`ssh server.example.com shutdown now`, true, "shutdown denied"},
|
||||||
|
{`ssh server.example.com chmod 777 /tmp`, true, "chmod denied"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
err := CheckSSHCommand(tt.command, cfg)
|
||||||
|
if tt.shouldBlock && err == nil {
|
||||||
|
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||||
|
}
|
||||||
|
if !tt.shouldBlock && err != nil {
|
||||||
|
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckSSHCommand_InheritDeny(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
SSH: config.SSHConfig{
|
||||||
|
AllowedHosts: []string{"*.example.com"},
|
||||||
|
AllowAllCommands: true, // denylist mode
|
||||||
|
InheritDeny: true, // inherit global denies
|
||||||
|
},
|
||||||
|
Command: config.CommandConfig{
|
||||||
|
Deny: []string{"git push", "npm publish"},
|
||||||
|
UseDefaults: boolPtr(true), // include default denies like shutdown
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
shouldBlock bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
// Inherited from global deny
|
||||||
|
{`ssh server.example.com git push origin main`, true, "git push from global deny"},
|
||||||
|
{`ssh server.example.com npm publish`, true, "npm publish from global deny"},
|
||||||
|
|
||||||
|
// Inherited from default deny list
|
||||||
|
{`ssh server.example.com shutdown now`, true, "shutdown from default deny"},
|
||||||
|
{`ssh server.example.com reboot`, true, "reboot from default deny"},
|
||||||
|
|
||||||
|
// Allowed (not in any deny list)
|
||||||
|
{`ssh server.example.com ls -la`, false, "ls allowed"},
|
||||||
|
{`ssh server.example.com git status`, false, "git status allowed"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
err := CheckSSHCommand(tt.command, cfg)
|
||||||
|
if tt.shouldBlock && err == nil {
|
||||||
|
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||||
|
}
|
||||||
|
if !tt.shouldBlock && err != nil {
|
||||||
|
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckSSHCommand_NoSSHConfig(t *testing.T) {
|
||||||
|
// No SSH policy configured - all SSH commands should pass through
|
||||||
|
cfg := &config.Config{
|
||||||
|
Command: config.CommandConfig{
|
||||||
|
UseDefaults: boolPtr(false),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{`ssh server.example.com rm -rf /`, "dangerous command allowed when no SSH policy"},
|
||||||
|
{`ssh any-host.com shutdown`, "any host allowed"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
err := CheckSSHCommand(tt.command, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected SSH command %q to be allowed (no SSH policy), got: %v", tt.command, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckCommand_IntegratesSSH(t *testing.T) {
|
||||||
|
// Test that CheckCommand also checks SSH policies
|
||||||
|
cfg := &config.Config{
|
||||||
|
SSH: config.SSHConfig{
|
||||||
|
AllowedHosts: []string{"*.example.com"},
|
||||||
|
AllowedCommands: []string{"ls", "cat"},
|
||||||
|
},
|
||||||
|
Command: config.CommandConfig{
|
||||||
|
UseDefaults: boolPtr(false),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
shouldBlock bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
// Via CheckCommand, SSH policy should be enforced
|
||||||
|
{`ssh server.example.com ls`, false, "allowed SSH command"},
|
||||||
|
{`ssh server.example.com rm -rf /`, true, "blocked SSH command"},
|
||||||
|
{`ssh other.com ls`, true, "blocked host"},
|
||||||
|
|
||||||
|
// Non-SSH commands unaffected
|
||||||
|
{`ls -la`, false, "local ls allowed"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
err := CheckCommand(tt.command, cfg)
|
||||||
|
if tt.shouldBlock && err == nil {
|
||||||
|
t.Errorf("expected command %q to be blocked", tt.command)
|
||||||
|
}
|
||||||
|
if !tt.shouldBlock && err != nil {
|
||||||
|
t.Errorf("expected command %q to be allowed, got: %v", tt.command, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckSSHCommand_CommandChaining(t *testing.T) {
|
||||||
|
// Test that command chaining doesn't bypass allow/deny rules
|
||||||
|
cfg := &config.Config{
|
||||||
|
SSH: config.SSHConfig{
|
||||||
|
AllowedHosts: []string{"*.example.com"},
|
||||||
|
AllowedCommands: []string{"ls", "cat", "git status"},
|
||||||
|
},
|
||||||
|
Command: config.CommandConfig{
|
||||||
|
UseDefaults: boolPtr(false),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
shouldBlock bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
// Chaining should NOT bypass allowlist
|
||||||
|
{`ssh server.example.com "ls && rm -rf /"`, true, "ls allowed but rm -rf not"},
|
||||||
|
{`ssh server.example.com "git status && rm -rf /"`, true, "git status allowed but rm -rf not"},
|
||||||
|
{`ssh server.example.com "cat file; shutdown"`, true, "cat allowed but shutdown not"},
|
||||||
|
{`ssh server.example.com "ls | xargs rm"`, true, "ls allowed but rm not"},
|
||||||
|
{`ssh server.example.com "ls || rm -rf /"`, true, "ls allowed but rm -rf not"},
|
||||||
|
|
||||||
|
// All subcommands allowed should work
|
||||||
|
{`ssh server.example.com "ls && cat file"`, false, "both ls and cat allowed"},
|
||||||
|
{`ssh server.example.com "ls; cat file"`, false, "semicolon chain with allowed commands"},
|
||||||
|
{`ssh server.example.com "ls | cat"`, false, "pipe with allowed commands"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
err := CheckSSHCommand(tt.command, cfg)
|
||||||
|
if tt.shouldBlock && err == nil {
|
||||||
|
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||||
|
}
|
||||||
|
if !tt.shouldBlock && err != nil {
|
||||||
|
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckSSHCommand_CommandChainingDenylist(t *testing.T) {
|
||||||
|
// Test command chaining in denylist mode
|
||||||
|
cfg := &config.Config{
|
||||||
|
SSH: config.SSHConfig{
|
||||||
|
AllowedHosts: []string{"*.example.com"},
|
||||||
|
AllowAllCommands: true,
|
||||||
|
DeniedCommands: []string{"rm -rf", "shutdown"},
|
||||||
|
},
|
||||||
|
Command: config.CommandConfig{
|
||||||
|
UseDefaults: boolPtr(false),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
shouldBlock bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
// Chaining should still catch denied commands
|
||||||
|
{`ssh server.example.com "ls && rm -rf /"`, true, "rm -rf in chain blocked"},
|
||||||
|
{`ssh server.example.com "cat file; shutdown"`, true, "shutdown in chain blocked"},
|
||||||
|
|
||||||
|
// Chains without denied commands should work
|
||||||
|
{`ssh server.example.com "ls && cat && grep foo"`, false, "chain without denied commands"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
err := CheckSSHCommand(tt.command, cfg)
|
||||||
|
if tt.shouldBlock && err == nil {
|
||||||
|
t.Errorf("expected SSH command %q to be blocked", tt.command)
|
||||||
|
}
|
||||||
|
if !tt.shouldBlock && err != nil {
|
||||||
|
t.Errorf("expected SSH command %q to be allowed, got: %v", tt.command, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user