Add support for config inheritance
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
// Config is the main configuration for fence.
|
||||
type Config struct {
|
||||
Extends string `json:"extends,omitempty"`
|
||||
Network NetworkConfig `json:"network"`
|
||||
Filesystem FilesystemConfig `json:"filesystem"`
|
||||
Command CommandConfig `json:"command"`
|
||||
@@ -247,3 +248,109 @@ func MatchesDomain(hostname, pattern string) bool {
|
||||
// Exact match
|
||||
return hostname == pattern
|
||||
}
|
||||
|
||||
// Merge combines a base config with an override config.
|
||||
// Values in override take precedence. Slice fields are appended (base + override).
|
||||
// The Extends field is cleared in the result since inheritance has been resolved.
|
||||
func Merge(base, override *Config) *Config {
|
||||
if base == nil {
|
||||
if override == nil {
|
||||
return Default()
|
||||
}
|
||||
result := *override
|
||||
result.Extends = ""
|
||||
return &result
|
||||
}
|
||||
if override == nil {
|
||||
result := *base
|
||||
result.Extends = ""
|
||||
return &result
|
||||
}
|
||||
|
||||
result := &Config{
|
||||
// AllowPty: true if either config enables it
|
||||
AllowPty: base.AllowPty || override.AllowPty,
|
||||
|
||||
Network: NetworkConfig{
|
||||
// Append slices (base first, then override additions)
|
||||
AllowedDomains: mergeStrings(base.Network.AllowedDomains, override.Network.AllowedDomains),
|
||||
DeniedDomains: mergeStrings(base.Network.DeniedDomains, override.Network.DeniedDomains),
|
||||
AllowUnixSockets: mergeStrings(base.Network.AllowUnixSockets, override.Network.AllowUnixSockets),
|
||||
|
||||
// Boolean fields: override wins if set, otherwise base
|
||||
AllowAllUnixSockets: base.Network.AllowAllUnixSockets || override.Network.AllowAllUnixSockets,
|
||||
AllowLocalBinding: base.Network.AllowLocalBinding || override.Network.AllowLocalBinding,
|
||||
|
||||
// Pointer fields: override wins if set, otherwise base
|
||||
AllowLocalOutbound: mergeOptionalBool(base.Network.AllowLocalOutbound, override.Network.AllowLocalOutbound),
|
||||
|
||||
// Port fields: override wins if non-zero
|
||||
HTTPProxyPort: mergeInt(base.Network.HTTPProxyPort, override.Network.HTTPProxyPort),
|
||||
SOCKSProxyPort: mergeInt(base.Network.SOCKSProxyPort, override.Network.SOCKSProxyPort),
|
||||
},
|
||||
|
||||
Filesystem: FilesystemConfig{
|
||||
// Append slices
|
||||
DenyRead: mergeStrings(base.Filesystem.DenyRead, override.Filesystem.DenyRead),
|
||||
AllowWrite: mergeStrings(base.Filesystem.AllowWrite, override.Filesystem.AllowWrite),
|
||||
DenyWrite: mergeStrings(base.Filesystem.DenyWrite, override.Filesystem.DenyWrite),
|
||||
|
||||
// Boolean fields: override wins if set
|
||||
AllowGitConfig: base.Filesystem.AllowGitConfig || override.Filesystem.AllowGitConfig,
|
||||
},
|
||||
|
||||
Command: CommandConfig{
|
||||
// Append slices
|
||||
Deny: mergeStrings(base.Command.Deny, override.Command.Deny),
|
||||
Allow: mergeStrings(base.Command.Allow, override.Command.Allow),
|
||||
|
||||
// Pointer field: override wins if set
|
||||
UseDefaults: mergeOptionalBool(base.Command.UseDefaults, override.Command.UseDefaults),
|
||||
},
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeStrings appends two string slices, removing duplicates.
|
||||
func mergeStrings(base, override []string) []string {
|
||||
if len(base) == 0 {
|
||||
return override
|
||||
}
|
||||
if len(override) == 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
seen := make(map[string]bool, len(base))
|
||||
result := make([]string, 0, len(base)+len(override))
|
||||
|
||||
for _, s := range base {
|
||||
if !seen[s] {
|
||||
seen[s] = true
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
for _, s := range override {
|
||||
if !seen[s] {
|
||||
seen[s] = true
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeOptionalBool returns override if non-nil, otherwise base.
|
||||
func mergeOptionalBool(base, override *bool) *bool {
|
||||
if override != nil {
|
||||
return override
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
// mergeInt returns override if non-zero, otherwise base.
|
||||
func mergeInt(base, override int) int {
|
||||
if override != 0 {
|
||||
return override
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
@@ -291,3 +291,192 @@ func TestDefaultConfigPath(t *testing.T) {
|
||||
t.Errorf("DefaultConfigPath() = %q, expected to end with .fence.json", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
t.Run("nil base", func(t *testing.T) {
|
||||
override := &Config{
|
||||
AllowPty: true,
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"example.com"},
|
||||
},
|
||||
}
|
||||
result := Merge(nil, override)
|
||||
if !result.AllowPty {
|
||||
t.Error("expected AllowPty to be true")
|
||||
}
|
||||
if len(result.Network.AllowedDomains) != 1 || result.Network.AllowedDomains[0] != "example.com" {
|
||||
t.Error("expected AllowedDomains to be [example.com]")
|
||||
}
|
||||
if result.Extends != "" {
|
||||
t.Error("expected Extends to be cleared")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil override", func(t *testing.T) {
|
||||
base := &Config{
|
||||
AllowPty: true,
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"example.com"},
|
||||
},
|
||||
}
|
||||
result := Merge(base, nil)
|
||||
if !result.AllowPty {
|
||||
t.Error("expected AllowPty to be true")
|
||||
}
|
||||
if len(result.Network.AllowedDomains) != 1 {
|
||||
t.Error("expected AllowedDomains to be [example.com]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("both nil", func(t *testing.T) {
|
||||
result := Merge(nil, nil)
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("merge allowed domains", func(t *testing.T) {
|
||||
base := &Config{
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"github.com", "api.github.com"},
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
Extends: "base-template",
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"private-registry.company.com"},
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
// Should have all three domains
|
||||
if len(result.Network.AllowedDomains) != 3 {
|
||||
t.Errorf("expected 3 allowed domains, got %d: %v", len(result.Network.AllowedDomains), result.Network.AllowedDomains)
|
||||
}
|
||||
|
||||
// Extends should be cleared
|
||||
if result.Extends != "" {
|
||||
t.Errorf("expected Extends to be cleared, got %q", result.Extends)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deduplicate merged domains", func(t *testing.T) {
|
||||
base := &Config{
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"github.com", "example.com"},
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"github.com", "new.com"},
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
// Should deduplicate
|
||||
if len(result.Network.AllowedDomains) != 3 {
|
||||
t.Errorf("expected 3 domains (deduped), got %d: %v", len(result.Network.AllowedDomains), result.Network.AllowedDomains)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("merge boolean flags", func(t *testing.T) {
|
||||
base := &Config{
|
||||
AllowPty: false,
|
||||
Network: NetworkConfig{
|
||||
AllowLocalBinding: true,
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
AllowPty: true,
|
||||
Network: NetworkConfig{
|
||||
AllowLocalOutbound: boolPtr(true),
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
if !result.AllowPty {
|
||||
t.Error("expected AllowPty to be true (from override)")
|
||||
}
|
||||
if !result.Network.AllowLocalBinding {
|
||||
t.Error("expected AllowLocalBinding to be true (from base)")
|
||||
}
|
||||
if result.Network.AllowLocalOutbound == nil || !*result.Network.AllowLocalOutbound {
|
||||
t.Error("expected AllowLocalOutbound to be true (from override)")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("merge command config", func(t *testing.T) {
|
||||
base := &Config{
|
||||
Command: CommandConfig{
|
||||
Deny: []string{"git push", "rm -rf"},
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
Command: CommandConfig{
|
||||
Deny: []string{"sudo"},
|
||||
Allow: []string{"git status"},
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
if len(result.Command.Deny) != 3 {
|
||||
t.Errorf("expected 3 denied commands, got %d", len(result.Command.Deny))
|
||||
}
|
||||
if len(result.Command.Allow) != 1 {
|
||||
t.Errorf("expected 1 allowed command, got %d", len(result.Command.Allow))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("merge filesystem config", func(t *testing.T) {
|
||||
base := &Config{
|
||||
Filesystem: FilesystemConfig{
|
||||
AllowWrite: []string{"."},
|
||||
DenyRead: []string{"~/.ssh/**"},
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
Filesystem: FilesystemConfig{
|
||||
AllowWrite: []string{"/tmp"},
|
||||
DenyWrite: []string{".env"},
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
if len(result.Filesystem.AllowWrite) != 2 {
|
||||
t.Errorf("expected 2 write paths, got %d", len(result.Filesystem.AllowWrite))
|
||||
}
|
||||
if len(result.Filesystem.DenyRead) != 1 {
|
||||
t.Errorf("expected 1 deny read path, got %d", len(result.Filesystem.DenyRead))
|
||||
}
|
||||
if len(result.Filesystem.DenyWrite) != 1 {
|
||||
t.Errorf("expected 1 deny write path, got %d", len(result.Filesystem.DenyWrite))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("override ports", func(t *testing.T) {
|
||||
base := &Config{
|
||||
Network: NetworkConfig{
|
||||
HTTPProxyPort: 8080,
|
||||
SOCKSProxyPort: 1080,
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
Network: NetworkConfig{
|
||||
HTTPProxyPort: 9090, // override
|
||||
// SOCKSProxyPort not set, should keep base
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
if result.Network.HTTPProxyPort != 9090 {
|
||||
t.Errorf("expected HTTPProxyPort 9090, got %d", result.Network.HTTPProxyPort)
|
||||
}
|
||||
if result.Network.SOCKSProxyPort != 1080 {
|
||||
t.Errorf("expected SOCKSProxyPort 1080, got %d", result.Network.SOCKSProxyPort)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user