feat: add domain-based outbound filtering with allowedDomains/deniedDomains
Add NetworkConfig.AllowedDomains and DeniedDomains fields for controlling outbound connections by hostname. Deny rules are checked first (deny wins). When AllowedDomains is set, only matching domains are permitted. When only DeniedDomains is set, all domains except denied ones are allowed. Implement FilteringProxy that wraps gost HTTP proxy with domain enforcement via AllowConnect callback. Skip GreyHaven proxy/DNS defaults
This commit is contained in:
@@ -26,8 +26,10 @@ type Config struct {
|
||||
|
||||
// NetworkConfig defines network restrictions.
|
||||
type NetworkConfig struct {
|
||||
ProxyURL string `json:"proxyUrl,omitempty"` // External SOCKS5 proxy (e.g. socks5://host:1080)
|
||||
DnsAddr string `json:"dnsAddr,omitempty"` // DNS server address on host (e.g. localhost:3153)
|
||||
ProxyURL string `json:"proxyUrl,omitempty"` // External SOCKS5 proxy (e.g. socks5://host:1080)
|
||||
DnsAddr string `json:"dnsAddr,omitempty"` // DNS server address on host (e.g. localhost:3153)
|
||||
AllowedDomains []string `json:"allowedDomains,omitempty"` // Domains to allow outbound connections to (supports wildcards)
|
||||
DeniedDomains []string `json:"deniedDomains,omitempty"` // Domains to deny outbound connections to (checked before allowed)
|
||||
AllowUnixSockets []string `json:"allowUnixSockets,omitempty"`
|
||||
AllowAllUnixSockets bool `json:"allowAllUnixSockets,omitempty"`
|
||||
AllowLocalBinding bool `json:"allowLocalBinding,omitempty"`
|
||||
@@ -209,6 +211,17 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, domain := range c.Network.AllowedDomains {
|
||||
if err := validateDomainPattern(domain); err != nil {
|
||||
return fmt.Errorf("invalid network.allowedDomains %q: %w", domain, err)
|
||||
}
|
||||
}
|
||||
for _, domain := range c.Network.DeniedDomains {
|
||||
if err := validateDomainPattern(domain); err != nil {
|
||||
return fmt.Errorf("invalid network.deniedDomains %q: %w", domain, err)
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains(c.Filesystem.AllowRead, "") {
|
||||
return errors.New("filesystem.allowRead contains empty path")
|
||||
}
|
||||
@@ -384,6 +397,84 @@ func matchGlob(s, pattern string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// validateDomainPattern validates a domain pattern for allowedDomains/deniedDomains.
|
||||
// Rejects patterns with protocol, path, port, or empty strings.
|
||||
func validateDomainPattern(pattern string) error {
|
||||
if pattern == "" {
|
||||
return errors.New("empty domain pattern")
|
||||
}
|
||||
|
||||
// Allow wildcard-all
|
||||
if pattern == "*" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reject patterns with protocol
|
||||
if strings.Contains(pattern, "://") {
|
||||
return errors.New("domain pattern cannot contain protocol (remove http:// or https://)")
|
||||
}
|
||||
|
||||
// Reject patterns with path
|
||||
if strings.Contains(pattern, "/") {
|
||||
return errors.New("domain pattern cannot contain path")
|
||||
}
|
||||
|
||||
// Reject patterns with port
|
||||
if strings.Contains(pattern, ":") {
|
||||
return errors.New("domain pattern cannot contain port")
|
||||
}
|
||||
|
||||
// Reject patterns with @
|
||||
if strings.Contains(pattern, "@") {
|
||||
return errors.New("domain pattern cannot contain username")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsDomainAllowed checks if a hostname is allowed by the domain filtering rules.
|
||||
// Strips port from hostname. Deny rules are checked first (deny wins).
|
||||
// If AllowedDomains is set, the domain must match at least one allowed pattern.
|
||||
// If only DeniedDomains is set, all domains except denied ones are allowed.
|
||||
func (n *NetworkConfig) IsDomainAllowed(hostname string) bool {
|
||||
// Strip port if present
|
||||
if host, _, found := strings.Cut(hostname, ":"); found {
|
||||
hostname = host
|
||||
}
|
||||
|
||||
// Check denied domains first (deny wins)
|
||||
for _, pattern := range n.DeniedDomains {
|
||||
if MatchesHost(hostname, pattern) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// If no allowed domains configured, allow all (only deny list is active)
|
||||
if len(n.AllowedDomains) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check allowed domains
|
||||
for _, pattern := range n.AllowedDomains {
|
||||
if MatchesHost(hostname, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Not in allow list
|
||||
return false
|
||||
}
|
||||
|
||||
// HasDomainFiltering returns true when domain filtering is configured.
|
||||
func (n *NetworkConfig) HasDomainFiltering() bool {
|
||||
return len(n.AllowedDomains) > 0 || len(n.DeniedDomains) > 0
|
||||
}
|
||||
|
||||
// IsWildcardAllow returns true when AllowedDomains contains "*" (allow all).
|
||||
func (n *NetworkConfig) IsWildcardAllow() bool {
|
||||
return slices.Contains(n.AllowedDomains, "*")
|
||||
}
|
||||
|
||||
// Merge combines a base config with an override config.
|
||||
// Values in override take precedence. Slice fields are appended (base + override).
|
||||
// The Extends field is cleared in the result since inheritance has been resolved.
|
||||
@@ -411,6 +502,10 @@ func Merge(base, override *Config) *Config {
|
||||
ProxyURL: mergeString(base.Network.ProxyURL, override.Network.ProxyURL),
|
||||
DnsAddr: mergeString(base.Network.DnsAddr, override.Network.DnsAddr),
|
||||
|
||||
// Append domain slices
|
||||
AllowedDomains: mergeStrings(base.Network.AllowedDomains, override.Network.AllowedDomains),
|
||||
DeniedDomains: mergeStrings(base.Network.DeniedDomains, override.Network.DeniedDomains),
|
||||
|
||||
// Append slices (base first, then override additions)
|
||||
AllowUnixSockets: mergeStrings(base.Network.AllowUnixSockets, override.Network.AllowUnixSockets),
|
||||
|
||||
|
||||
@@ -668,6 +668,265 @@ func TestMergeSSHConfig(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateDomainPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
wantErr bool
|
||||
}{
|
||||
// Valid patterns
|
||||
{"simple domain", "example.com", false},
|
||||
{"subdomain", "api.example.com", false},
|
||||
{"wildcard prefix", "*.example.com", false},
|
||||
{"wildcard all", "*", false},
|
||||
{"wildcard middle", "api-*.example.com", false},
|
||||
{"localhost", "localhost", false},
|
||||
|
||||
// Invalid patterns
|
||||
{"empty", "", true},
|
||||
{"with protocol http", "http://example.com", true},
|
||||
{"with protocol https", "https://example.com", true},
|
||||
{"with path", "example.com/path", true},
|
||||
{"with port", "example.com:443", true},
|
||||
{"with username", "user@example.com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateDomainPattern(tt.pattern)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateDomainPattern(%q) error = %v, wantErr %v", tt.pattern, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDomainAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedDomains []string
|
||||
deniedDomains []string
|
||||
hostname string
|
||||
want bool
|
||||
}{
|
||||
// Allow list only
|
||||
{"allowed exact match", []string{"api.example.com"}, nil, "api.example.com", true},
|
||||
{"allowed wildcard match", []string{"*.example.com"}, nil, "api.example.com", true},
|
||||
{"allowed no match", []string{"api.example.com"}, nil, "other.com", false},
|
||||
{"allowed wildcard all", []string{"*"}, nil, "anything.com", true},
|
||||
{"allowed multiple", []string{"api.example.com", "cdn.example.com"}, nil, "cdn.example.com", true},
|
||||
|
||||
// Deny list only (no allow list = allow all except denied)
|
||||
{"denied only - not denied", nil, []string{"evil.com"}, "good.com", true},
|
||||
{"denied only - denied match", nil, []string{"evil.com"}, "evil.com", false},
|
||||
{"denied only - wildcard deny", nil, []string{"*.evil.com"}, "sub.evil.com", false},
|
||||
|
||||
// Both allow and deny (deny wins)
|
||||
{"deny wins over allow", []string{"*.example.com"}, []string{"secret.example.com"}, "secret.example.com", false},
|
||||
{"allow works when not denied", []string{"*.example.com"}, []string{"secret.example.com"}, "api.example.com", true},
|
||||
{"not in allow after deny check", []string{"*.example.com"}, []string{"evil.com"}, "other.com", false},
|
||||
|
||||
// Port stripping
|
||||
{"strips port", []string{"api.example.com"}, nil, "api.example.com:443", true},
|
||||
{"strips port denied", nil, []string{"evil.com"}, "evil.com:8080", false},
|
||||
|
||||
// Empty allow list = allow all (only deny active)
|
||||
{"empty allow list allows all", nil, nil, "anything.com", true},
|
||||
{"empty both allows all", []string{}, []string{}, "anything.com", true},
|
||||
|
||||
// Case insensitive
|
||||
{"case insensitive", []string{"API.Example.COM"}, nil, "api.example.com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
nc := &NetworkConfig{
|
||||
AllowedDomains: tt.allowedDomains,
|
||||
DeniedDomains: tt.deniedDomains,
|
||||
}
|
||||
got := nc.IsDomainAllowed(tt.hostname)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsDomainAllowed(%q) = %v, want %v (allowed=%v, denied=%v)",
|
||||
tt.hostname, got, tt.want, tt.allowedDomains, tt.deniedDomains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasDomainFiltering(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedDomains []string
|
||||
deniedDomains []string
|
||||
want bool
|
||||
}{
|
||||
{"no domains", nil, nil, false},
|
||||
{"empty slices", []string{}, []string{}, false},
|
||||
{"allowed only", []string{"example.com"}, nil, true},
|
||||
{"denied only", nil, []string{"evil.com"}, true},
|
||||
{"both", []string{"example.com"}, []string{"evil.com"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
nc := &NetworkConfig{
|
||||
AllowedDomains: tt.allowedDomains,
|
||||
DeniedDomains: tt.deniedDomains,
|
||||
}
|
||||
got := nc.HasDomainFiltering()
|
||||
if got != tt.want {
|
||||
t.Errorf("HasDomainFiltering() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsWildcardAllow(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedDomains []string
|
||||
want bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"empty", []string{}, false},
|
||||
{"specific domain", []string{"example.com"}, false},
|
||||
{"wildcard", []string{"*"}, true},
|
||||
{"wildcard among others", []string{"example.com", "*"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
nc := &NetworkConfig{AllowedDomains: tt.allowedDomains}
|
||||
got := nc.IsWildcardAllow()
|
||||
if got != tt.want {
|
||||
t.Errorf("IsWildcardAllow() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidateDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid allowed domains",
|
||||
config: Config{
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"api.example.com", "*.cdn.example.com"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid denied domains",
|
||||
config: Config{
|
||||
Network: NetworkConfig{
|
||||
DeniedDomains: []string{"evil.com", "*.malware.com"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid allowed domain with protocol",
|
||||
config: Config{
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"https://example.com"},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid denied domain with port",
|
||||
config: Config{
|
||||
Network: NetworkConfig{
|
||||
DeniedDomains: []string{"example.com:443"},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid allowed domain empty",
|
||||
config: Config{
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{""},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeDomainConfig(t *testing.T) {
|
||||
t.Run("merge allowed domains", func(t *testing.T) {
|
||||
base := &Config{
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"api.example.com"},
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"cdn.example.com"},
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
if len(result.Network.AllowedDomains) != 2 {
|
||||
t.Errorf("expected 2 allowed domains, got %d: %v",
|
||||
len(result.Network.AllowedDomains), result.Network.AllowedDomains)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("merge denied domains", func(t *testing.T) {
|
||||
base := &Config{
|
||||
Network: NetworkConfig{
|
||||
DeniedDomains: []string{"evil.com"},
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
Network: NetworkConfig{
|
||||
DeniedDomains: []string{"malware.com"},
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
if len(result.Network.DeniedDomains) != 2 {
|
||||
t.Errorf("expected 2 denied domains, got %d: %v",
|
||||
len(result.Network.DeniedDomains), result.Network.DeniedDomains)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("merge deduplicates domains", func(t *testing.T) {
|
||||
base := &Config{
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"api.example.com"},
|
||||
},
|
||||
}
|
||||
override := &Config{
|
||||
Network: NetworkConfig{
|
||||
AllowedDomains: []string{"api.example.com"},
|
||||
},
|
||||
}
|
||||
result := Merge(base, override)
|
||||
|
||||
if len(result.Network.AllowedDomains) != 1 {
|
||||
t.Errorf("expected 1 allowed domain (deduped), got %d: %v",
|
||||
len(result.Network.AllowedDomains), result.Network.AllowedDomains)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateProxyURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user