diff --git a/cmd/greywall/main.go b/cmd/greywall/main.go index 20f1315..56e7d29 100644 --- a/cmd/greywall/main.go +++ b/cmd/greywall/main.go @@ -233,13 +233,15 @@ func runCommand(cmd *cobra.Command, args []string) error { // GreyHaven defaults: when no proxy or DNS is configured (neither via CLI // nor config file), use the standard GreyHaven infrastructure ports. - if cfg.Network.ProxyURL == "" { + // Skip GreyHaven proxy default when domain filtering is configured — + // the filtering proxy handles outbound instead. + if cfg.Network.ProxyURL == "" && !cfg.Network.HasDomainFiltering() { cfg.Network.ProxyURL = "socks5://localhost:42052" if debug { fmt.Fprintf(os.Stderr, "[greywall] Defaulting proxy to socks5://localhost:42052\n") } } - if cfg.Network.DnsAddr == "" { + if cfg.Network.DnsAddr == "" && !cfg.Network.HasDomainFiltering() { cfg.Network.DnsAddr = "localhost:42053" if debug { fmt.Fprintf(os.Stderr, "[greywall] Defaulting DNS to localhost:42053\n") diff --git a/internal/config/config.go b/internal/config/config.go index 352482c..a57fe2b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -26,8 +26,10 @@ type Config struct { // NetworkConfig defines network restrictions. type NetworkConfig struct { - ProxyURL string `json:"proxyUrl,omitempty"` // External SOCKS5 proxy (e.g. socks5://host:1080) - DnsAddr string `json:"dnsAddr,omitempty"` // DNS server address on host (e.g. localhost:3153) + ProxyURL string `json:"proxyUrl,omitempty"` // External SOCKS5 proxy (e.g. socks5://host:1080) + DnsAddr string `json:"dnsAddr,omitempty"` // DNS server address on host (e.g. localhost:3153) + AllowedDomains []string `json:"allowedDomains,omitempty"` // Domains to allow outbound connections to (supports wildcards) + DeniedDomains []string `json:"deniedDomains,omitempty"` // Domains to deny outbound connections to (checked before allowed) AllowUnixSockets []string `json:"allowUnixSockets,omitempty"` AllowAllUnixSockets bool `json:"allowAllUnixSockets,omitempty"` AllowLocalBinding bool `json:"allowLocalBinding,omitempty"` @@ -209,6 +211,17 @@ func (c *Config) Validate() error { } } + for _, domain := range c.Network.AllowedDomains { + if err := validateDomainPattern(domain); err != nil { + return fmt.Errorf("invalid network.allowedDomains %q: %w", domain, err) + } + } + for _, domain := range c.Network.DeniedDomains { + if err := validateDomainPattern(domain); err != nil { + return fmt.Errorf("invalid network.deniedDomains %q: %w", domain, err) + } + } + if slices.Contains(c.Filesystem.AllowRead, "") { return errors.New("filesystem.allowRead contains empty path") } @@ -384,6 +397,84 @@ func matchGlob(s, pattern string) bool { return true } +// validateDomainPattern validates a domain pattern for allowedDomains/deniedDomains. +// Rejects patterns with protocol, path, port, or empty strings. +func validateDomainPattern(pattern string) error { + if pattern == "" { + return errors.New("empty domain pattern") + } + + // Allow wildcard-all + if pattern == "*" { + return nil + } + + // Reject patterns with protocol + if strings.Contains(pattern, "://") { + return errors.New("domain pattern cannot contain protocol (remove http:// or https://)") + } + + // Reject patterns with path + if strings.Contains(pattern, "/") { + return errors.New("domain pattern cannot contain path") + } + + // Reject patterns with port + if strings.Contains(pattern, ":") { + return errors.New("domain pattern cannot contain port") + } + + // Reject patterns with @ + if strings.Contains(pattern, "@") { + return errors.New("domain pattern cannot contain username") + } + + return nil +} + +// IsDomainAllowed checks if a hostname is allowed by the domain filtering rules. +// Strips port from hostname. Deny rules are checked first (deny wins). +// If AllowedDomains is set, the domain must match at least one allowed pattern. +// If only DeniedDomains is set, all domains except denied ones are allowed. +func (n *NetworkConfig) IsDomainAllowed(hostname string) bool { + // Strip port if present + if host, _, found := strings.Cut(hostname, ":"); found { + hostname = host + } + + // Check denied domains first (deny wins) + for _, pattern := range n.DeniedDomains { + if MatchesHost(hostname, pattern) { + return false + } + } + + // If no allowed domains configured, allow all (only deny list is active) + if len(n.AllowedDomains) == 0 { + return true + } + + // Check allowed domains + for _, pattern := range n.AllowedDomains { + if MatchesHost(hostname, pattern) { + return true + } + } + + // Not in allow list + return false +} + +// HasDomainFiltering returns true when domain filtering is configured. +func (n *NetworkConfig) HasDomainFiltering() bool { + return len(n.AllowedDomains) > 0 || len(n.DeniedDomains) > 0 +} + +// IsWildcardAllow returns true when AllowedDomains contains "*" (allow all). +func (n *NetworkConfig) IsWildcardAllow() bool { + return slices.Contains(n.AllowedDomains, "*") +} + // Merge combines a base config with an override config. // Values in override take precedence. Slice fields are appended (base + override). // The Extends field is cleared in the result since inheritance has been resolved. @@ -411,6 +502,10 @@ func Merge(base, override *Config) *Config { ProxyURL: mergeString(base.Network.ProxyURL, override.Network.ProxyURL), DnsAddr: mergeString(base.Network.DnsAddr, override.Network.DnsAddr), + // Append domain slices + AllowedDomains: mergeStrings(base.Network.AllowedDomains, override.Network.AllowedDomains), + DeniedDomains: mergeStrings(base.Network.DeniedDomains, override.Network.DeniedDomains), + // Append slices (base first, then override additions) AllowUnixSockets: mergeStrings(base.Network.AllowUnixSockets, override.Network.AllowUnixSockets), diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 58d268c..b213db6 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -668,6 +668,265 @@ func TestMergeSSHConfig(t *testing.T) { }) } +func TestValidateDomainPattern(t *testing.T) { + tests := []struct { + name string + pattern string + wantErr bool + }{ + // Valid patterns + {"simple domain", "example.com", false}, + {"subdomain", "api.example.com", false}, + {"wildcard prefix", "*.example.com", false}, + {"wildcard all", "*", false}, + {"wildcard middle", "api-*.example.com", false}, + {"localhost", "localhost", false}, + + // Invalid patterns + {"empty", "", true}, + {"with protocol http", "http://example.com", true}, + {"with protocol https", "https://example.com", true}, + {"with path", "example.com/path", true}, + {"with port", "example.com:443", true}, + {"with username", "user@example.com", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateDomainPattern(tt.pattern) + if (err != nil) != tt.wantErr { + t.Errorf("validateDomainPattern(%q) error = %v, wantErr %v", tt.pattern, err, tt.wantErr) + } + }) + } +} + +func TestIsDomainAllowed(t *testing.T) { + tests := []struct { + name string + allowedDomains []string + deniedDomains []string + hostname string + want bool + }{ + // Allow list only + {"allowed exact match", []string{"api.example.com"}, nil, "api.example.com", true}, + {"allowed wildcard match", []string{"*.example.com"}, nil, "api.example.com", true}, + {"allowed no match", []string{"api.example.com"}, nil, "other.com", false}, + {"allowed wildcard all", []string{"*"}, nil, "anything.com", true}, + {"allowed multiple", []string{"api.example.com", "cdn.example.com"}, nil, "cdn.example.com", true}, + + // Deny list only (no allow list = allow all except denied) + {"denied only - not denied", nil, []string{"evil.com"}, "good.com", true}, + {"denied only - denied match", nil, []string{"evil.com"}, "evil.com", false}, + {"denied only - wildcard deny", nil, []string{"*.evil.com"}, "sub.evil.com", false}, + + // Both allow and deny (deny wins) + {"deny wins over allow", []string{"*.example.com"}, []string{"secret.example.com"}, "secret.example.com", false}, + {"allow works when not denied", []string{"*.example.com"}, []string{"secret.example.com"}, "api.example.com", true}, + {"not in allow after deny check", []string{"*.example.com"}, []string{"evil.com"}, "other.com", false}, + + // Port stripping + {"strips port", []string{"api.example.com"}, nil, "api.example.com:443", true}, + {"strips port denied", nil, []string{"evil.com"}, "evil.com:8080", false}, + + // Empty allow list = allow all (only deny active) + {"empty allow list allows all", nil, nil, "anything.com", true}, + {"empty both allows all", []string{}, []string{}, "anything.com", true}, + + // Case insensitive + {"case insensitive", []string{"API.Example.COM"}, nil, "api.example.com", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nc := &NetworkConfig{ + AllowedDomains: tt.allowedDomains, + DeniedDomains: tt.deniedDomains, + } + got := nc.IsDomainAllowed(tt.hostname) + if got != tt.want { + t.Errorf("IsDomainAllowed(%q) = %v, want %v (allowed=%v, denied=%v)", + tt.hostname, got, tt.want, tt.allowedDomains, tt.deniedDomains) + } + }) + } +} + +func TestHasDomainFiltering(t *testing.T) { + tests := []struct { + name string + allowedDomains []string + deniedDomains []string + want bool + }{ + {"no domains", nil, nil, false}, + {"empty slices", []string{}, []string{}, false}, + {"allowed only", []string{"example.com"}, nil, true}, + {"denied only", nil, []string{"evil.com"}, true}, + {"both", []string{"example.com"}, []string{"evil.com"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nc := &NetworkConfig{ + AllowedDomains: tt.allowedDomains, + DeniedDomains: tt.deniedDomains, + } + got := nc.HasDomainFiltering() + if got != tt.want { + t.Errorf("HasDomainFiltering() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsWildcardAllow(t *testing.T) { + tests := []struct { + name string + allowedDomains []string + want bool + }{ + {"nil", nil, false}, + {"empty", []string{}, false}, + {"specific domain", []string{"example.com"}, false}, + {"wildcard", []string{"*"}, true}, + {"wildcard among others", []string{"example.com", "*"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nc := &NetworkConfig{AllowedDomains: tt.allowedDomains} + got := nc.IsWildcardAllow() + if got != tt.want { + t.Errorf("IsWildcardAllow() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfigValidateDomains(t *testing.T) { + tests := []struct { + name string + config Config + wantErr bool + }{ + { + name: "valid allowed domains", + config: Config{ + Network: NetworkConfig{ + AllowedDomains: []string{"api.example.com", "*.cdn.example.com"}, + }, + }, + wantErr: false, + }, + { + name: "valid denied domains", + config: Config{ + Network: NetworkConfig{ + DeniedDomains: []string{"evil.com", "*.malware.com"}, + }, + }, + wantErr: false, + }, + { + name: "invalid allowed domain with protocol", + config: Config{ + Network: NetworkConfig{ + AllowedDomains: []string{"https://example.com"}, + }, + }, + wantErr: true, + }, + { + name: "invalid denied domain with port", + config: Config{ + Network: NetworkConfig{ + DeniedDomains: []string{"example.com:443"}, + }, + }, + wantErr: true, + }, + { + name: "invalid allowed domain empty", + config: Config{ + Network: NetworkConfig{ + AllowedDomains: []string{""}, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestMergeDomainConfig(t *testing.T) { + t.Run("merge allowed domains", func(t *testing.T) { + base := &Config{ + Network: NetworkConfig{ + AllowedDomains: []string{"api.example.com"}, + }, + } + override := &Config{ + Network: NetworkConfig{ + AllowedDomains: []string{"cdn.example.com"}, + }, + } + result := Merge(base, override) + + if len(result.Network.AllowedDomains) != 2 { + t.Errorf("expected 2 allowed domains, got %d: %v", + len(result.Network.AllowedDomains), result.Network.AllowedDomains) + } + }) + + t.Run("merge denied domains", func(t *testing.T) { + base := &Config{ + Network: NetworkConfig{ + DeniedDomains: []string{"evil.com"}, + }, + } + override := &Config{ + Network: NetworkConfig{ + DeniedDomains: []string{"malware.com"}, + }, + } + result := Merge(base, override) + + if len(result.Network.DeniedDomains) != 2 { + t.Errorf("expected 2 denied domains, got %d: %v", + len(result.Network.DeniedDomains), result.Network.DeniedDomains) + } + }) + + t.Run("merge deduplicates domains", func(t *testing.T) { + base := &Config{ + Network: NetworkConfig{ + AllowedDomains: []string{"api.example.com"}, + }, + } + override := &Config{ + Network: NetworkConfig{ + AllowedDomains: []string{"api.example.com"}, + }, + } + result := Merge(base, override) + + if len(result.Network.AllowedDomains) != 1 { + t.Errorf("expected 1 allowed domain (deduped), got %d: %v", + len(result.Network.AllowedDomains), result.Network.AllowedDomains) + } + }) +} + func TestValidateProxyURL(t *testing.T) { tests := []struct { name string diff --git a/internal/sandbox/linux.go b/internal/sandbox/linux.go index 23be2c2..4b9359e 100644 --- a/internal/sandbox/linux.go +++ b/internal/sandbox/linux.go @@ -594,8 +594,8 @@ func isSystemMountPoint(path string) bool { // WrapCommandLinux wraps a command with Linux bubblewrap sandbox. // It uses available security features (Landlock, seccomp) with graceful fallback. -func WrapCommandLinux(cfg *config.Config, command string, proxyBridge *ProxyBridge, dnsBridge *DnsBridge, reverseBridge *ReverseBridge, tun2socksPath string, debug bool) (string, error) { - return WrapCommandLinuxWithOptions(cfg, command, proxyBridge, dnsBridge, reverseBridge, tun2socksPath, LinuxSandboxOptions{ +func WrapCommandLinux(cfg *config.Config, command string, proxyBridge *ProxyBridge, dnsBridge *DnsBridge, reverseBridge *ReverseBridge, tun2socksPath string, filterProxy *FilteringProxy, debug bool) (string, error) { + return WrapCommandLinuxWithOptions(cfg, command, proxyBridge, dnsBridge, reverseBridge, tun2socksPath, filterProxy, LinuxSandboxOptions{ UseLandlock: true, // Enabled by default, will fall back if not available UseSeccomp: true, // Enabled by default UseEBPF: true, // Enabled by default if available @@ -604,7 +604,7 @@ func WrapCommandLinux(cfg *config.Config, command string, proxyBridge *ProxyBrid } // WrapCommandLinuxWithOptions wraps a command with configurable sandbox options. -func WrapCommandLinuxWithOptions(cfg *config.Config, command string, proxyBridge *ProxyBridge, dnsBridge *DnsBridge, reverseBridge *ReverseBridge, tun2socksPath string, opts LinuxSandboxOptions) (string, error) { +func WrapCommandLinuxWithOptions(cfg *config.Config, command string, proxyBridge *ProxyBridge, dnsBridge *DnsBridge, reverseBridge *ReverseBridge, tun2socksPath string, filterProxy *FilteringProxy, opts LinuxSandboxOptions) (string, error) { if _, err := exec.LookPath("bwrap"); err != nil { return "", fmt.Errorf("bubblewrap (bwrap) is required on Linux but not found: %w", err) } @@ -636,11 +636,18 @@ func WrapCommandLinuxWithOptions(cfg *config.Config, command string, proxyBridge bwrapArgs = append(bwrapArgs, "--die-with-parent") // Always use --unshare-net when available (network namespace isolation) - // Inside the namespace, tun2socks will provide transparent proxy access - if features.CanUnshareNet { + // Inside the namespace, tun2socks will provide transparent proxy access. + // Skip network namespace when domain filtering with wildcard allow is active + // (the filtering proxy handles domain enforcement via env vars). + skipUnshareNet := filterProxy != nil && cfg != nil && cfg.Network.IsWildcardAllow() + if features.CanUnshareNet && !skipUnshareNet { bwrapArgs = append(bwrapArgs, "--unshare-net") // Network namespace isolation } else if opts.Debug { - fmt.Fprintf(os.Stderr, "[greywall:linux] Skipping --unshare-net (network namespace unavailable in this environment)\n") + if skipUnshareNet { + fmt.Fprintf(os.Stderr, "[greywall:linux] Skipping --unshare-net (wildcard allow with domain filtering)\n") + } else { + fmt.Fprintf(os.Stderr, "[greywall:linux] Skipping --unshare-net (network namespace unavailable in this environment)\n") + } } bwrapArgs = append(bwrapArgs, "--unshare-pid") // PID namespace isolation @@ -1042,6 +1049,23 @@ export no_proxy=localhost,127.0.0.1 `, proxyBridge.SocketPath)) } + // Set up domain filtering proxy env vars inside the sandbox. + // When filterProxy is active, skip tun2socks and use env-var-based proxying + // through a socat bridge to the host-side filtering proxy. + if filterProxy != nil && proxyBridge == nil { + filterProxyAddr := filterProxy.Addr() + innerScript.WriteString(fmt.Sprintf(` +# Domain filtering proxy: bridge to host-side filtering proxy +export HTTP_PROXY=http://%s +export HTTPS_PROXY=http://%s +export http_proxy=http://%s +export https_proxy=http://%s +export NO_PROXY=localhost,127.0.0.1 +export no_proxy=localhost,127.0.0.1 + +`, filterProxyAddr, filterProxyAddr, filterProxyAddr, filterProxyAddr)) + } + // Set up reverse (inbound) socat listeners inside the sandbox if reverseBridge != nil && len(reverseBridge.Ports) > 0 { innerScript.WriteString("\n# Start reverse bridge listeners for inbound connections\n") diff --git a/internal/sandbox/linux_stub.go b/internal/sandbox/linux_stub.go index 713e72f..ad86da1 100644 --- a/internal/sandbox/linux_stub.go +++ b/internal/sandbox/linux_stub.go @@ -63,12 +63,12 @@ func NewReverseBridge(ports []int, debug bool) (*ReverseBridge, error) { func (b *ReverseBridge) Cleanup() {} // WrapCommandLinux returns an error on non-Linux platforms. -func WrapCommandLinux(cfg *config.Config, command string, proxyBridge *ProxyBridge, dnsBridge *DnsBridge, reverseBridge *ReverseBridge, tun2socksPath string, debug bool) (string, error) { +func WrapCommandLinux(cfg *config.Config, command string, proxyBridge *ProxyBridge, dnsBridge *DnsBridge, reverseBridge *ReverseBridge, tun2socksPath string, filterProxy *FilteringProxy, debug bool) (string, error) { return "", fmt.Errorf("Linux sandbox not available on this platform") } // WrapCommandLinuxWithOptions returns an error on non-Linux platforms. -func WrapCommandLinuxWithOptions(cfg *config.Config, command string, proxyBridge *ProxyBridge, dnsBridge *DnsBridge, reverseBridge *ReverseBridge, tun2socksPath string, opts LinuxSandboxOptions) (string, error) { +func WrapCommandLinuxWithOptions(cfg *config.Config, command string, proxyBridge *ProxyBridge, dnsBridge *DnsBridge, reverseBridge *ReverseBridge, tun2socksPath string, filterProxy *FilteringProxy, opts LinuxSandboxOptions) (string, error) { return "", fmt.Errorf("Linux sandbox not available on this platform") } diff --git a/internal/sandbox/macos.go b/internal/sandbox/macos.go index c13910a..7decc67 100644 --- a/internal/sandbox/macos.go +++ b/internal/sandbox/macos.go @@ -631,7 +631,7 @@ func GenerateSandboxProfile(params MacOSSandboxParams) string { } // WrapCommandMacOS wraps a command with macOS sandbox restrictions. -func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, debug bool) (string, error) { +func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, filterProxy *FilteringProxy, debug bool) (string, error) { cwd, _ := os.Getwd() // Build allow paths: default + configured @@ -650,7 +650,12 @@ func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, de // Parse proxy URL for network rules var proxyHost, proxyPort string - if cfg.Network.ProxyURL != "" { + if filterProxy != nil { + // Domain filtering proxy: point at the local filtering proxy. + // Seatbelt only accepts "localhost" or "*" in (remote ip ...) filters. + proxyHost = "localhost" + proxyPort = filterProxy.Port() + } else if cfg.Network.ProxyURL != "" { if u, err := url.Parse(cfg.Network.ProxyURL); err == nil { proxyHost = u.Hostname() proxyPort = u.Port() @@ -659,7 +664,11 @@ func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, de // Restrict network unless proxy is configured to an external host // If no proxy: block all outbound. If proxy: allow outbound only to proxy. + // When wildcard allow with domain filtering, allow direct outbound (proxy enforces via env vars) needsNetworkRestriction := true + if filterProxy != nil && cfg.Network.IsWildcardAllow() { + needsNetworkRestriction = false + } params := MacOSSandboxParams{ Command: command, @@ -700,7 +709,12 @@ func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, de return "", fmt.Errorf("shell %q not found: %w", shell, err) } - proxyEnvs := GenerateProxyEnvVars(cfg.Network.ProxyURL) + var proxyEnvs []string + if filterProxy != nil { + proxyEnvs = GenerateHTTPProxyEnvVars(fmt.Sprintf("http://127.0.0.1:%s", filterProxy.Port())) + } else { + proxyEnvs = GenerateProxyEnvVars(cfg.Network.ProxyURL) + } // Build the command // env VAR1=val1 VAR2=val2 sandbox-exec -p 'profile' shell -c 'command' diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 8ae7ce4..a3ca801 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -14,6 +14,7 @@ type Manager struct { proxyBridge *ProxyBridge dnsBridge *DnsBridge reverseBridge *ReverseBridge + filterProxy *FilteringProxy tun2socksPath string // path to extracted tun2socks binary on host exposedPorts []int debug bool @@ -118,6 +119,36 @@ func (m *Manager) Initialize() error { } } + // Start domain filtering proxy if allowedDomains/deniedDomains are configured + if m.config.Network.HasDomainFiltering() { + fp, err := NewFilteringProxy(&m.config.Network, m.debug) + if err != nil { + // Clean up any bridges that were already started + if m.reverseBridge != nil { + m.reverseBridge.Cleanup() + } + if m.dnsBridge != nil { + m.dnsBridge.Cleanup() + } + if m.proxyBridge != nil { + m.proxyBridge.Cleanup() + } + if m.tun2socksPath != "" { + _ = os.Remove(m.tun2socksPath) + } + return fmt.Errorf("failed to start filtering proxy: %w", err) + } + m.filterProxy = fp + m.logDebug("Domain filtering proxy started on %s", fp.Addr()) + + // Write Node.js proxy bootstrap script so fetch() honors HTTP_PROXY + if bootstrapPath, err := WriteNodeProxyBootstrap(); err != nil { + m.logDebug("Warning: failed to write Node.js proxy bootstrap: %v", err) + } else { + m.logDebug("Node.js proxy bootstrap written to %s", bootstrapPath) + } + } + m.initialized = true if m.config.Network.ProxyURL != "" { dnsInfo := "none" @@ -148,12 +179,12 @@ func (m *Manager) WrapCommand(command string) (string, error) { plat := platform.Detect() switch plat { case platform.MacOS: - return WrapCommandMacOS(m.config, command, m.exposedPorts, m.debug) + return WrapCommandMacOS(m.config, command, m.exposedPorts, m.filterProxy, m.debug) case platform.Linux: if m.learning { return m.wrapCommandLearning(command) } - return WrapCommandLinux(m.config, command, m.proxyBridge, m.dnsBridge, m.reverseBridge, m.tun2socksPath, m.debug) + return WrapCommandLinux(m.config, command, m.proxyBridge, m.dnsBridge, m.reverseBridge, m.tun2socksPath, m.filterProxy, m.debug) default: return "", fmt.Errorf("unsupported platform: %s", plat) } @@ -171,7 +202,7 @@ func (m *Manager) wrapCommandLearning(command string) (string, error) { m.logDebug("Strace log file: %s", m.straceLogPath) - return WrapCommandLinuxWithOptions(m.config, command, m.proxyBridge, m.dnsBridge, m.reverseBridge, m.tun2socksPath, LinuxSandboxOptions{ + return WrapCommandLinuxWithOptions(m.config, command, m.proxyBridge, m.dnsBridge, m.reverseBridge, m.tun2socksPath, m.filterProxy, LinuxSandboxOptions{ UseLandlock: false, // Disabled: seccomp blocks ptrace which strace needs UseSeccomp: false, // Disabled: conflicts with strace UseEBPF: false, @@ -201,6 +232,9 @@ func (m *Manager) GenerateLearnedTemplate(cmdName string) (string, error) { // Cleanup stops the proxies and cleans up resources. func (m *Manager) Cleanup() { + if m.filterProxy != nil { + m.filterProxy.Shutdown() + } if m.reverseBridge != nil { m.reverseBridge.Cleanup() } diff --git a/internal/sandbox/monitor.go b/internal/sandbox/monitor.go index 5cc33a8..1eab6e9 100644 --- a/internal/sandbox/monitor.go +++ b/internal/sandbox/monitor.go @@ -138,9 +138,9 @@ func parseViolation(line string) string { timestamp := time.Now().Format("15:04:05") if details != "" { - return fmt.Sprintf("[greywall:logstream] %s ✗ %s %s (%s:%s)", timestamp, operation, details, process, pid) + return fmt.Sprintf("\033[31m[greywall:logstream] %s ✗ %s %s (%s:%s)\033[0m", timestamp, operation, details, process, pid) } - return fmt.Sprintf("[greywall:logstream] %s ✗ %s (%s:%s)", timestamp, operation, process, pid) + return fmt.Sprintf("\033[31m[greywall:logstream] %s ✗ %s (%s:%s)\033[0m", timestamp, operation, process, pid) } // shouldShowViolation returns true if this violation type should be displayed. diff --git a/internal/sandbox/proxy.go b/internal/sandbox/proxy.go new file mode 100644 index 0000000..4a9f966 --- /dev/null +++ b/internal/sandbox/proxy.go @@ -0,0 +1,245 @@ +package sandbox + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "os" + "strings" + "sync" + "time" + + "gitea.app.monadical.io/monadical/greywall/internal/config" +) + +// FilteringProxy is an HTTP CONNECT proxy that filters outbound connections by domain. +// It runs on the host and is the only outbound target the sandbox allows. +type FilteringProxy struct { + listener net.Listener + server *http.Server + network *config.NetworkConfig + debug bool + mu sync.Mutex + closed bool +} + +// NewFilteringProxy creates and starts a new domain-filtering HTTP proxy. +// It listens on 127.0.0.1 with a random available port. +func NewFilteringProxy(network *config.NetworkConfig, debug bool) (*FilteringProxy, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("failed to listen: %w", err) + } + + fp := &FilteringProxy{ + listener: listener, + network: network, + debug: debug, + } + + fp.server = &http.Server{ + Handler: http.HandlerFunc(fp.serveHTTP), + ReadHeaderTimeout: 30 * time.Second, + } + + go func() { + if err := fp.server.Serve(listener); err != nil && err != http.ErrServerClosed { + fp.logDebug("Proxy server error: %v", err) + } + }() + + if debug { + fmt.Fprintf(os.Stderr, "[greywall:proxy] Filtering proxy started on %s\n", listener.Addr().String()) + } + + return fp, nil +} + +func (fp *FilteringProxy) serveHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + fp.handleConnect(w, r) + } else if r.Method == http.MethodGet && r.URL.Path == "/__greywall_dns" { + fp.handleDNS(w, r) + } else { + fp.handleHTTP(w, r) + } +} + +// handleDNS resolves a hostname and returns the IP addresses as JSON. +// Used by the Node.js bootstrap to patch dns.lookup inside the sandbox. +func (fp *FilteringProxy) handleDNS(w http.ResponseWriter, r *http.Request) { + host := r.URL.Query().Get("host") + if host == "" { + w.Header().Set("Content-Type", "application/json") + http.Error(w, `{"error":"missing host parameter"}`, http.StatusBadRequest) + return + } + + if !fp.network.IsDomainAllowed(host) { + fp.logDenied(host) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + fmt.Fprintf(w, `{"error":"domain denied: %s"}`, host) + return + } + + addrs, err := net.LookupHost(host) + if err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"%s"}`, err.Error()) + return + } + + type addrEntry struct { + Address string `json:"address"` + Family int `json:"family"` + } + + var entries []addrEntry + for _, addr := range addrs { + family := 4 + if strings.Contains(addr, ":") { + family = 6 + } + entries = append(entries, addrEntry{Address: addr, Family: family}) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "addresses": entries, + }) +} + +func (fp *FilteringProxy) handleConnect(w http.ResponseWriter, r *http.Request) { + host := extractHost(r.Host) + + if !fp.network.IsDomainAllowed(host) { + fp.logDenied(host) + http.Error(w, fmt.Sprintf("[greywall] domain denied: %s", host), http.StatusForbidden) + return + } + + // Dial the target + target := r.Host + if !strings.Contains(target, ":") { + target = target + ":443" + } + + destConn, err := net.DialTimeout("tcp", target, 10*time.Second) + if err != nil { + http.Error(w, fmt.Sprintf("[greywall] failed to connect to %s: %v", target, err), http.StatusBadGateway) + return + } + + // Hijack the client connection + hijacker, ok := w.(http.Hijacker) + if !ok { + destConn.Close() + http.Error(w, "[greywall] hijacking not supported", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + + clientConn, _, err := hijacker.Hijack() + if err != nil { + destConn.Close() + return + } + + // Bidirectional copy + go func() { + defer destConn.Close() + defer clientConn.Close() + _, _ = io.Copy(destConn, clientConn) + }() + go func() { + defer destConn.Close() + defer clientConn.Close() + _, _ = io.Copy(clientConn, destConn) + }() +} + +func (fp *FilteringProxy) handleHTTP(w http.ResponseWriter, r *http.Request) { + host := extractHost(r.Host) + + if !fp.network.IsDomainAllowed(host) { + fp.logDenied(host) + http.Error(w, fmt.Sprintf("[greywall] domain denied: %s", host), http.StatusForbidden) + return + } + + // Forward the request + r.RequestURI = "" + + resp, err := http.DefaultTransport.RoundTrip(r) + if err != nil { + http.Error(w, fmt.Sprintf("[greywall] failed to forward request: %v", err), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // Copy response headers + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(resp.StatusCode) + _, _ = io.Copy(w, resp.Body) +} + +// Addr returns the listener address as a string (e.g. "127.0.0.1:12345"). +func (fp *FilteringProxy) Addr() string { + return fp.listener.Addr().String() +} + +// Port returns the listener port as a string. +func (fp *FilteringProxy) Port() string { + _, port, _ := net.SplitHostPort(fp.listener.Addr().String()) + return port +} + +// Shutdown gracefully stops the proxy. +func (fp *FilteringProxy) Shutdown() { + fp.mu.Lock() + defer fp.mu.Unlock() + + if fp.closed { + return + } + fp.closed = true + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _ = fp.server.Shutdown(ctx) + + if fp.debug { + fmt.Fprintf(os.Stderr, "[greywall:proxy] Filtering proxy stopped\n") + } +} + +func (fp *FilteringProxy) logDenied(host string) { + fmt.Fprintf(os.Stderr, "\033[31m[greywall:proxy] domain denied: %s\033[0m\n", host) +} + +func (fp *FilteringProxy) logDebug(format string, args ...interface{}) { + if fp.debug { + fmt.Fprintf(os.Stderr, "[greywall:proxy] "+format+"\n", args...) + } +} + +// extractHost extracts the hostname from a host:port string, stripping the port. +func extractHost(hostport string) string { + host, _, err := net.SplitHostPort(hostport) + if err != nil { + // No port present + return hostport + } + return host +} diff --git a/internal/sandbox/proxy_test.go b/internal/sandbox/proxy_test.go new file mode 100644 index 0000000..b1a317c --- /dev/null +++ b/internal/sandbox/proxy_test.go @@ -0,0 +1,193 @@ +package sandbox + +import ( + "fmt" + "io" + "net/http" + "net/url" + "testing" + "time" + + "gitea.app.monadical.io/monadical/greywall/internal/config" +) + +func TestFilteringProxy_AllowedDomain(t *testing.T) { + nc := &config.NetworkConfig{ + AllowedDomains: []string{"httpbin.org"}, + } + + fp, err := NewFilteringProxy(nc, false) + if err != nil { + t.Fatalf("NewFilteringProxy() error = %v", err) + } + defer fp.Shutdown() + + if fp.Port() == "" { + t.Fatal("expected non-empty port") + } + if fp.Addr() == "" { + t.Fatal("expected non-empty addr") + } +} + +func TestFilteringProxy_DeniedDomain_HTTP(t *testing.T) { + nc := &config.NetworkConfig{ + AllowedDomains: []string{"allowed.example.com"}, + } + + fp, err := NewFilteringProxy(nc, false) + if err != nil { + t.Fatalf("NewFilteringProxy() error = %v", err) + } + defer fp.Shutdown() + + // Make a plain HTTP request to a denied domain through the proxy + proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", fp.Addr())) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + Timeout: 5 * time.Second, + } + + resp, err := client.Get("http://denied.example.com/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + body, _ := io.ReadAll(resp.Body) + t.Errorf("expected 403 Forbidden, got %d: %s", resp.StatusCode, string(body)) + } +} + +func TestFilteringProxy_DeniedDomain_CONNECT(t *testing.T) { + nc := &config.NetworkConfig{ + AllowedDomains: []string{"allowed.example.com"}, + } + + fp, err := NewFilteringProxy(nc, false) + if err != nil { + t.Fatalf("NewFilteringProxy() error = %v", err) + } + defer fp.Shutdown() + + // Make a CONNECT request to a denied domain + proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", fp.Addr())) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + Timeout: 5 * time.Second, + } + + // HTTPS triggers CONNECT method through the proxy + _, err = client.Get("https://denied.example.com/test") + if err == nil { + t.Error("expected error for denied CONNECT, got nil") + } + // The error should indicate the proxy rejected the connection (403) +} + +func TestFilteringProxy_DenyList_Only(t *testing.T) { + nc := &config.NetworkConfig{ + DeniedDomains: []string{"evil.com"}, + } + + fp, err := NewFilteringProxy(nc, false) + if err != nil { + t.Fatalf("NewFilteringProxy() error = %v", err) + } + defer fp.Shutdown() + + proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", fp.Addr())) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + Timeout: 5 * time.Second, + } + + // Denied domain should be blocked + resp, err := client.Get("http://evil.com/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Errorf("expected 403 for denied domain, got %d", resp.StatusCode) + } +} + +func TestFilteringProxy_WildcardAllow(t *testing.T) { + nc := &config.NetworkConfig{ + AllowedDomains: []string{"*"}, + DeniedDomains: []string{"evil.com"}, + } + + fp, err := NewFilteringProxy(nc, false) + if err != nil { + t.Fatalf("NewFilteringProxy() error = %v", err) + } + defer fp.Shutdown() + + proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", fp.Addr())) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + Timeout: 5 * time.Second, + } + + // Denied domain should still be blocked even with wildcard allow + resp, err := client.Get("http://evil.com/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Errorf("expected 403 for denied domain with wildcard allow, got %d", resp.StatusCode) + } +} + +func TestFilteringProxy_Shutdown(t *testing.T) { + nc := &config.NetworkConfig{ + AllowedDomains: []string{"example.com"}, + } + + fp, err := NewFilteringProxy(nc, false) + if err != nil { + t.Fatalf("NewFilteringProxy() error = %v", err) + } + + // Shutdown should not panic + fp.Shutdown() + + // Double shutdown should not panic + fp.Shutdown() +} + +func TestExtractHost(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"example.com:443", "example.com"}, + {"example.com:80", "example.com"}, + {"example.com", "example.com"}, + {"127.0.0.1:8080", "127.0.0.1"}, + {"[::1]:443", "::1"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := extractHost(tt.input) + if got != tt.want { + t.Errorf("extractHost(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/sandbox/utils.go b/internal/sandbox/utils.go index 4b888cd..ed8fe40 100644 --- a/internal/sandbox/utils.go +++ b/internal/sandbox/utils.go @@ -86,6 +86,353 @@ func GenerateProxyEnvVars(proxyURL string) []string { return envVars } +// GenerateHTTPProxyEnvVars creates environment variables for an HTTP proxy. +// Used when domain filtering is active (HTTP CONNECT proxy, not SOCKS5). +func GenerateHTTPProxyEnvVars(httpProxyURL string) []string { + envVars := []string{ + "GREYWALL_SANDBOX=1", + "TMPDIR=/tmp/greywall", + } + + if httpProxyURL == "" { + return envVars + } + + // NO_PROXY for localhost and private networks + noProxy := strings.Join([]string{ + "localhost", + "127.0.0.1", + "::1", + "*.local", + ".local", + "169.254.0.0/16", + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + }, ",") + + envVars = append(envVars, + "NO_PROXY="+noProxy, + "no_proxy="+noProxy, + "HTTP_PROXY="+httpProxyURL, + "HTTPS_PROXY="+httpProxyURL, + "http_proxy="+httpProxyURL, + "https_proxy="+httpProxyURL, + ) + + // Inject Node.js proxy bootstrap so fetch() honors HTTP_PROXY. + // Appends to existing NODE_OPTIONS if set. + nodeOpts := "--require " + nodeProxyBootstrapPath + if existing := os.Getenv("NODE_OPTIONS"); existing != "" { + nodeOpts = existing + " " + nodeOpts + } + envVars = append(envVars, "NODE_OPTIONS="+nodeOpts) + + return envVars +} + +// nodeProxyBootstrapJS is the Node.js bootstrap script that makes both +// fetch() and http/https.request() respect HTTP_PROXY/HTTPS_PROXY env vars. +// +// Two mechanisms are needed because Node.js has two separate HTTP stacks: +// 1. fetch() — powered by undici, patched via EnvHttpProxyAgent +// 2. http.request()/https.request() — built-in modules, patched via +// Agent.prototype.createConnection on BOTH http.Agent and https.Agent. +// This patches the prototype so ALL agent instances (including custom ones +// created by libraries like node-fetch, axios, got, etc.) tunnel through +// the proxy — not just globalAgent. +// +// The undici setup tries multiple strategies to find the module: +// 1. require('undici') from CWD +// 2. createRequire from the main script's path (finds it in the app's node_modules) +const nodeProxyBootstrapJS = `'use strict'; +(function() { + var proxyUrl = process.env.HTTPS_PROXY || process.env.HTTP_PROXY || + process.env.https_proxy || process.env.http_proxy; + if (!proxyUrl) return; + + // --- Part 1: Patch fetch() via undici EnvHttpProxyAgent --- + // Strategy: set global dispatcher AND wrap globalThis.fetch to force proxy. + // This prevents openclaw or other code from overriding the global dispatcher. + var undiciModule = null; + var proxyAgent = null; + + function tryGetUndici(undici) { + if (undici && typeof undici.EnvHttpProxyAgent === 'function' && + typeof undici.setGlobalDispatcher === 'function') { + return undici; + } + return null; + } + + try { undiciModule = tryGetUndici(require('undici')); } catch (e) {} + + if (!undiciModule) { + try { + var mainScript = process.argv[1]; + if (mainScript) { + var createRequire = require('module').createRequire; + var requireFrom = createRequire(require('path').resolve(mainScript)); + undiciModule = tryGetUndici(requireFrom('undici')); + } + } catch (e) {} + } + + if (undiciModule) { + proxyAgent = new undiciModule.EnvHttpProxyAgent(); + undiciModule.setGlobalDispatcher(proxyAgent); + + // Wrap globalThis.fetch to force proxy dispatcher on every call. + // This prevents code that overrides the global dispatcher from bypassing the proxy. + if (typeof globalThis.fetch === 'function') { + var _origFetch = globalThis.fetch; + globalThis.fetch = function(input, init) { + process.stderr.write('[greywall:node-bootstrap] fetch: ' + (typeof input === 'string' ? input : (input && input.url ? input.url : '?')) + '\n'); + if (!init) init = {}; + init.dispatcher = proxyAgent; + return _origFetch.call(globalThis, input, init); + }; + } + + // Also wrap undici.fetch and undici.request to catch direct usage + if (typeof undiciModule.fetch === 'function') { + var _origUndFetch = undiciModule.fetch; + undiciModule.fetch = function(input, init) { + if (!init) init = {}; + init.dispatcher = proxyAgent; + return _origUndFetch.call(undiciModule, input, init); + }; + } + if (typeof undiciModule.request === 'function') { + var _origUndRequest = undiciModule.request; + undiciModule.request = function(url, opts) { + if (!opts) opts = {}; + opts.dispatcher = proxyAgent; + return _origUndRequest.call(undiciModule, url, opts); + }; + } + } + // --- Shared setup for Parts 2 and 3 --- + var url = require('url'); + var http = require('http'); + var https = require('https'); + var tls = require('tls'); + + var parsed = new url.URL(proxyUrl); + var proxyHost = parsed.hostname; + var proxyPort = parseInt(parsed.port, 10); + + var noProxyRaw = process.env.NO_PROXY || process.env.no_proxy || ''; + var noProxyList = noProxyRaw.split(',').map(function(s) { return s.trim(); }).filter(Boolean); + + function isIPAddress(h) { + // IPv4 or IPv6 — skip DNS proxy for raw IPs + return /^\d{1,3}(\.\d{1,3}){3}$/.test(h) || h.indexOf(':') !== -1; + } + + function shouldProxy(hostname) { + if (!hostname || isIPAddress(hostname)) return false; + for (var i = 0; i < noProxyList.length; i++) { + var p = noProxyList[i]; + if (p === hostname) return false; + if (p.charAt(0) === '.' && hostname.length > p.length && + hostname.indexOf(p, hostname.length - p.length) !== -1) return false; + if (p.charAt(0) === '*' && hostname.length >= p.length - 1 && + hostname.indexOf(p.slice(1), hostname.length - p.length + 1) !== -1) return false; + } + return true; + } + + // Save originals before patching + var origHttpCreateConnection = http.Agent.prototype.createConnection; + var origHttpsCreateConnection = https.Agent.prototype.createConnection; + + // Direct agent for CONNECT requests to the proxy itself (avoids recursion) + var directAgent = new http.Agent({ keepAlive: false }); + directAgent.createConnection = origHttpCreateConnection; + + // --- Part 2: Patch Agent.prototype.createConnection on both http and https --- + // This ensures ALL agent instances tunnel through the proxy, not just globalAgent. + // Libraries like node-fetch, axios, got create their own agents — patching the + // prototype catches them all. + try { + // Patch https.Agent.prototype — affects ALL https.Agent instances + https.Agent.prototype.createConnection = function(options, callback) { + var targetHost = options.host || options.hostname; + var targetPort = options.port || 443; + + if (!shouldProxy(targetHost)) { + return origHttpsCreateConnection.call(this, options, callback); + } + + var connectReq = http.request({ + host: proxyHost, + port: proxyPort, + method: 'CONNECT', + path: targetHost + ':' + targetPort, + agent: directAgent, + }); + + connectReq.on('connect', function(res, socket) { + if (res.statusCode === 200) { + var tlsSocket = tls.connect({ + socket: socket, + servername: options.servername || targetHost, + rejectUnauthorized: options.rejectUnauthorized !== false, + }); + callback(null, tlsSocket); + } else { + socket.destroy(); + callback(new Error('Proxy CONNECT failed: ' + res.statusCode)); + } + }); + + connectReq.on('error', function(err) { callback(err); }); + connectReq.end(); + }; + + // Patch http.Agent.prototype — affects ALL http.Agent instances + http.Agent.prototype.createConnection = function(options, callback) { + var targetHost = options.host || options.hostname; + var targetPort = options.port || 80; + + if (!shouldProxy(targetHost)) { + return origHttpCreateConnection.call(this, options, callback); + } + + var connectReq = http.request({ + host: proxyHost, + port: proxyPort, + method: 'CONNECT', + path: targetHost + ':' + targetPort, + agent: directAgent, + }); + + connectReq.on('connect', function(res, socket) { + if (res.statusCode === 200) { + callback(null, socket); + } else { + socket.destroy(); + callback(new Error('Proxy CONNECT failed: ' + res.statusCode)); + } + }); + + connectReq.on('error', function(err) { callback(err); }); + connectReq.end(); + }; + } catch (e) {} + + // --- Part 3: Patch dns.lookup / dns.promises.lookup to resolve through proxy --- + // OpenClaw (and other apps) do DNS resolution before fetch for SSRF protection. + // Inside the sandbox, DNS is blocked. Route lookups through the proxy's + // /__greywall_dns endpoint which resolves on the host side. + try { + var dns = require('dns'); + var dnsPromises = require('dns/promises'); + var origDnsLookup = dns.lookup; + var origDnsPromisesLookup = dnsPromises.lookup; + + function proxyDnsResolve(hostname) { + return new Promise(function(resolve, reject) { + var req = http.request({ + host: proxyHost, + port: proxyPort, + path: '/__greywall_dns?host=' + encodeURIComponent(hostname), + method: 'GET', + agent: directAgent, + }, function(res) { + var data = ''; + res.on('data', function(chunk) { data += chunk; }); + res.on('end', function() { + try { + var parsed = JSON.parse(data); + if (parsed.error) { + var err = new Error(parsed.error); + err.code = 'ENOTFOUND'; + reject(err); + } else { + resolve(parsed.addresses || []); + } + } catch(e) { + reject(e); + } + }); + }); + req.on('error', reject); + req.end(); + }); + } + + dnsPromises.lookup = function(hostname, options) { + if (!shouldProxy(hostname)) { + return origDnsPromisesLookup.call(dnsPromises, hostname, options); + } + + return proxyDnsResolve(hostname).then(function(addresses) { + if (!addresses || addresses.length === 0) { + var err = new Error('getaddrinfo ENOTFOUND ' + hostname); + err.code = 'ENOTFOUND'; + throw err; + } + + var opts = (typeof options === 'object' && options !== null) ? options : {}; + var family = typeof options === 'number' ? options : (opts.family || 0); + + var filtered = addresses; + if (family === 4 || family === 6) { + filtered = addresses.filter(function(a) { return a.family === family; }); + if (filtered.length === 0) filtered = addresses; + } + + if (opts.all) { + return filtered; + } + + return filtered[0]; + }); + }; + + dns.lookup = function(hostname, options, callback) { + if (typeof options === 'function') { + callback = options; + options = {}; + } + + if (!shouldProxy(hostname)) { + return origDnsLookup.call(dns, hostname, options, callback); + } + + dnsPromises.lookup(hostname, options).then(function(result) { + if (Array.isArray(result)) { + callback(null, result); + } else { + callback(null, result.address, result.family); + } + }, function(err) { + callback(err); + }); + }; + + } catch (e) {} +})(); +` + +// nodeProxyBootstrapPath is the path where the bootstrap script is written. +const nodeProxyBootstrapPath = "/tmp/greywall/node-proxy-bootstrap.js" + +// WriteNodeProxyBootstrap writes the Node.js proxy bootstrap script to disk. +// Returns the path to the script, or an error if it couldn't be written. +func WriteNodeProxyBootstrap() (string, error) { + dir := filepath.Dir(nodeProxyBootstrapPath) + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", err + } + if err := os.WriteFile(nodeProxyBootstrapPath, []byte(nodeProxyBootstrapJS), 0o644); err != nil { + return "", err + } + return nodeProxyBootstrapPath, nil +} + // EncodeSandboxedCommand encodes a command for sandbox monitoring. func EncodeSandboxedCommand(command string) string { if len(command) > 100 { diff --git a/internal/sandbox/utils_test.go b/internal/sandbox/utils_test.go index a0fbd91..0b8b5c3 100644 --- a/internal/sandbox/utils_test.go +++ b/internal/sandbox/utils_test.go @@ -199,6 +199,72 @@ func TestGenerateProxyEnvVars(t *testing.T) { } } +func TestGenerateHTTPProxyEnvVars(t *testing.T) { + tests := []struct { + name string + httpProxyURL string + wantEnvs []string + dontWant []string + }{ + { + name: "no proxy", + httpProxyURL: "", + wantEnvs: []string{ + "GREYWALL_SANDBOX=1", + "TMPDIR=/tmp/greywall", + }, + dontWant: []string{ + "HTTP_PROXY=", + "HTTPS_PROXY=", + }, + }, + { + name: "http proxy", + httpProxyURL: "http://127.0.0.1:12345", + wantEnvs: []string{ + "GREYWALL_SANDBOX=1", + "HTTP_PROXY=http://127.0.0.1:12345", + "HTTPS_PROXY=http://127.0.0.1:12345", + "http_proxy=http://127.0.0.1:12345", + "https_proxy=http://127.0.0.1:12345", + "NO_PROXY=", + "no_proxy=", + }, + dontWant: []string{ + "ALL_PROXY=", + "all_proxy=", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GenerateHTTPProxyEnvVars(tt.httpProxyURL) + + for _, want := range tt.wantEnvs { + found := false + for _, env := range got { + if strings.HasPrefix(env, want) || env == want { + found = true + break + } + } + if !found { + t.Errorf("GenerateHTTPProxyEnvVars(%q) missing %q", tt.httpProxyURL, want) + } + } + + for _, dontWant := range tt.dontWant { + for _, env := range got { + if strings.HasPrefix(env, dontWant) { + t.Errorf("GenerateHTTPProxyEnvVars(%q) should not contain %q, got %q", tt.httpProxyURL, dontWant, env) + } + } + } + }) + } +} + func TestEncodeSandboxedCommand(t *testing.T) { tests := []struct { name string