diff --git a/cmd/greywall/daemon.go b/cmd/greywall/daemon.go new file mode 100644 index 0000000..043ed75 --- /dev/null +++ b/cmd/greywall/daemon.go @@ -0,0 +1,229 @@ +package main + +import ( + "bufio" + "fmt" + "os" + "os/signal" + "path/filepath" + "runtime" + "strings" + "syscall" + + "github.com/spf13/cobra" + + "gitea.app.monadical.io/monadical/greywall/internal/daemon" + "gitea.app.monadical.io/monadical/greywall/internal/sandbox" +) + +// newDaemonCmd creates the daemon subcommand tree: +// +// greywall daemon +// install - Install the LaunchDaemon (requires root) +// uninstall - Uninstall the LaunchDaemon (requires root) +// run - Run the daemon (called by LaunchDaemon plist) +// status - Show daemon status +func newDaemonCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "daemon", + Short: "Manage the greywall background daemon", + Long: `Manage the greywall LaunchDaemon for transparent network sandboxing on macOS. + +The daemon runs as a system service and manages the tun2socks tunnel, DNS relay, +and pf rules that enable transparent proxy routing for sandboxed processes. + +Commands: + sudo greywall daemon install Install and start the daemon + sudo greywall daemon uninstall Stop and remove the daemon + greywall daemon status Check daemon status + greywall daemon run Run the daemon (used by LaunchDaemon)`, + } + + cmd.AddCommand( + newDaemonInstallCmd(), + newDaemonUninstallCmd(), + newDaemonRunCmd(), + newDaemonStatusCmd(), + ) + + return cmd +} + +// newDaemonInstallCmd creates the "daemon install" subcommand. +func newDaemonInstallCmd() *cobra.Command { + return &cobra.Command{ + Use: "install", + Short: "Install the greywall LaunchDaemon (requires root)", + Long: `Install greywall as a macOS LaunchDaemon. This command: + 1. Creates a system user (_greywall) for sandboxed process isolation + 2. Copies the greywall binary to /usr/local/bin/greywall + 3. Extracts and installs the tun2socks binary + 4. Installs a LaunchDaemon plist for automatic startup + 5. Loads and starts the daemon + +Requires root privileges: sudo greywall daemon install`, + RunE: func(cmd *cobra.Command, args []string) error { + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to determine executable path: %w", err) + } + exePath, err = filepath.EvalSymlinks(exePath) + if err != nil { + return fmt.Errorf("failed to resolve executable path: %w", err) + } + + // Extract embedded tun2socks binary to a temp file. + tun2socksPath, err := sandbox.ExtractTun2Socks() + if err != nil { + return fmt.Errorf("failed to extract tun2socks: %w", err) + } + defer os.Remove(tun2socksPath) //nolint:errcheck // temp file cleanup + + if err := daemon.Install(exePath, tun2socksPath, debug); err != nil { + return err + } + + fmt.Println() + fmt.Println("To check status: greywall daemon status") + fmt.Println("To uninstall: sudo greywall daemon uninstall") + return nil + }, + } +} + +// newDaemonUninstallCmd creates the "daemon uninstall" subcommand. +func newDaemonUninstallCmd() *cobra.Command { + var force bool + + cmd := &cobra.Command{ + Use: "uninstall", + Short: "Uninstall the greywall LaunchDaemon (requires root)", + Long: `Uninstall the greywall LaunchDaemon. This command: + 1. Stops and unloads the daemon + 2. Removes the LaunchDaemon plist + 3. Removes installed files + 4. Removes the _greywall system user and group + +Requires root privileges: sudo greywall daemon uninstall`, + RunE: func(cmd *cobra.Command, args []string) error { + if !force { + fmt.Println("The following will be removed:") + fmt.Printf(" - LaunchDaemon plist: %s\n", daemon.LaunchDaemonPlistPath) + fmt.Printf(" - Binary: %s\n", daemon.InstallBinaryPath) + fmt.Printf(" - Lib directory: %s\n", daemon.InstallLibDir) + fmt.Printf(" - Socket: %s\n", daemon.DefaultSocketPath) + fmt.Printf(" - Sudoers file: %s\n", daemon.SudoersFilePath) + fmt.Printf(" - System user/group: %s\n", daemon.SandboxUserName) + fmt.Println() + fmt.Print("Proceed with uninstall? [y/N] ") + + reader := bufio.NewReader(os.Stdin) + answer, _ := reader.ReadString('\n') + answer = strings.TrimSpace(strings.ToLower(answer)) + if answer != "y" && answer != "yes" { + fmt.Println("Uninstall cancelled.") + return nil + } + } + + if err := daemon.Uninstall(debug); err != nil { + return err + } + + fmt.Println() + fmt.Println("The greywall daemon has been uninstalled.") + return nil + }, + } + + cmd.Flags().BoolVarP(&force, "force", "f", false, "Skip confirmation prompt") + return cmd +} + +// newDaemonRunCmd creates the "daemon run" subcommand. This is invoked by +// the LaunchDaemon plist and should not normally be called manually. +func newDaemonRunCmd() *cobra.Command { + return &cobra.Command{ + Use: "run", + Short: "Run the daemon process (called by LaunchDaemon)", + Hidden: true, // Not intended for direct user invocation. + RunE: runDaemon, + } +} + +// newDaemonStatusCmd creates the "daemon status" subcommand. +func newDaemonStatusCmd() *cobra.Command { + return &cobra.Command{ + Use: "status", + Short: "Show the daemon status", + Long: `Check whether the greywall daemon is installed and running. Does not require root.`, + RunE: func(cmd *cobra.Command, args []string) error { + installed := daemon.IsInstalled() + running := daemon.IsRunning() + + fmt.Printf("Greywall daemon status:\n") + fmt.Printf(" Installed: %s\n", boolStatus(installed)) + fmt.Printf(" Running: %s\n", boolStatus(running)) + fmt.Printf(" Plist: %s\n", daemon.LaunchDaemonPlistPath) + fmt.Printf(" Binary: %s\n", daemon.InstallBinaryPath) + fmt.Printf(" User: %s\n", daemon.SandboxUserName) + fmt.Printf(" Group: %s (pf routing)\n", daemon.SandboxGroupName) + fmt.Printf(" Sudoers: %s\n", daemon.SudoersFilePath) + fmt.Printf(" Socket: %s\n", daemon.DefaultSocketPath) + + if !installed { + fmt.Println() + fmt.Println("The daemon is not installed. Run: sudo greywall daemon install") + } else if !running { + fmt.Println() + fmt.Println("The daemon is installed but not running.") + fmt.Printf("Check logs: cat /var/log/greywall.log\n") + fmt.Printf("Start it: sudo launchctl load %s\n", daemon.LaunchDaemonPlistPath) + } + + return nil + }, + } +} + +// runDaemon is the main entry point for the daemon process. It starts the +// Unix socket server and blocks until a termination signal is received. +// CLI clients connect to the server to request sessions (which create +// utun tunnels, DNS relays, and pf rules on demand). +func runDaemon(cmd *cobra.Command, args []string) error { + tun2socksPath := filepath.Join(daemon.InstallLibDir, "tun2socks-darwin-"+runtime.GOARCH) + if _, err := os.Stat(tun2socksPath); err != nil { + return fmt.Errorf("tun2socks binary not found at %s (run 'sudo greywall daemon install' first)", tun2socksPath) + } + + daemon.Logf("Starting daemon (tun2socks=%s, socket=%s)", tun2socksPath, daemon.DefaultSocketPath) + + srv := daemon.NewServer(daemon.DefaultSocketPath, tun2socksPath, debug) + if err := srv.Start(); err != nil { + return fmt.Errorf("failed to start daemon server: %w", err) + } + + daemon.Logf("Daemon started, listening on %s", daemon.DefaultSocketPath) + + // Wait for termination signal. + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + sig := <-sigCh + daemon.Logf("Received signal %s, shutting down", sig) + + if err := srv.Stop(); err != nil { + daemon.Logf("Shutdown error: %v", err) + return err + } + + daemon.Logf("Daemon stopped") + return nil +} + +// boolStatus returns a human-readable string for a boolean status value. +func boolStatus(b bool) string { + if b { + return "yes" + } + return "no" +} diff --git a/internal/daemon/client.go b/internal/daemon/client.go new file mode 100644 index 0000000..fa12c78 --- /dev/null +++ b/internal/daemon/client.go @@ -0,0 +1,144 @@ +package daemon + +import ( + "encoding/json" + "fmt" + "net" + "os" + "time" +) + +const ( + // clientDialTimeout is the maximum time to wait when connecting to the daemon. + clientDialTimeout = 5 * time.Second + + // clientReadTimeout is the maximum time to wait for a response from the daemon. + clientReadTimeout = 30 * time.Second +) + +// Client communicates with the greywall daemon over a Unix socket using +// newline-delimited JSON. +type Client struct { + socketPath string + debug bool +} + +// NewClient creates a new daemon client that connects to the given Unix socket path. +func NewClient(socketPath string, debug bool) *Client { + return &Client{ + socketPath: socketPath, + debug: debug, + } +} + +// CreateSession asks the daemon to create a new sandbox session with the given +// proxy URL and optional DNS address. Returns the session info on success. +func (c *Client) CreateSession(proxyURL, dnsAddr string) (*Response, error) { + req := Request{ + Action: "create_session", + ProxyURL: proxyURL, + DNSAddr: dnsAddr, + } + + resp, err := c.sendRequest(req) + if err != nil { + return nil, fmt.Errorf("create session request failed: %w", err) + } + + if !resp.OK { + return resp, fmt.Errorf("create session failed: %s", resp.Error) + } + + return resp, nil +} + +// DestroySession asks the daemon to tear down the session with the given ID. +func (c *Client) DestroySession(sessionID string) error { + req := Request{ + Action: "destroy_session", + SessionID: sessionID, + } + + resp, err := c.sendRequest(req) + if err != nil { + return fmt.Errorf("destroy session request failed: %w", err) + } + + if !resp.OK { + return fmt.Errorf("destroy session failed: %s", resp.Error) + } + + return nil +} + +// Status queries the daemon for its current status. +func (c *Client) Status() (*Response, error) { + req := Request{ + Action: "status", + } + + resp, err := c.sendRequest(req) + if err != nil { + return nil, fmt.Errorf("status request failed: %w", err) + } + + if !resp.OK { + return resp, fmt.Errorf("status request failed: %s", resp.Error) + } + + return resp, nil +} + +// IsRunning checks whether the daemon is reachable by attempting to connect +// to the Unix socket. Returns true if the connection succeeds. +func (c *Client) IsRunning() bool { + conn, err := net.DialTimeout("unix", c.socketPath, clientDialTimeout) + if err != nil { + return false + } + _ = conn.Close() + return true +} + +// sendRequest connects to the daemon Unix socket, sends a JSON-encoded request, +// and reads back a JSON-encoded response. +func (c *Client) sendRequest(req Request) (*Response, error) { + c.logDebug("Connecting to daemon at %s", c.socketPath) + + conn, err := net.DialTimeout("unix", c.socketPath, clientDialTimeout) + if err != nil { + return nil, fmt.Errorf("failed to connect to daemon at %s: %w", c.socketPath, err) + } + defer conn.Close() //nolint:errcheck // best-effort close on request completion + + // Set a read deadline for the response. + if err := conn.SetReadDeadline(time.Now().Add(clientReadTimeout)); err != nil { + return nil, fmt.Errorf("failed to set read deadline: %w", err) + } + + // Send the request as newline-delimited JSON. + encoder := json.NewEncoder(conn) + if err := encoder.Encode(req); err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + c.logDebug("Sent request: action=%s", req.Action) + + // Read the response. + decoder := json.NewDecoder(conn) + var resp Response + if err := decoder.Decode(&resp); err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + c.logDebug("Received response: ok=%v", resp.OK) + + return &resp, nil +} + +// logDebug writes a debug message to stderr with the [greywall:daemon] prefix. +func (c *Client) logDebug(format string, args ...interface{}) { + if c.debug { + fmt.Fprintf(os.Stderr, "[greywall:daemon] "+format+"\n", args...) + } +} diff --git a/internal/daemon/dns.go b/internal/daemon/dns.go new file mode 100644 index 0000000..646ea9f --- /dev/null +++ b/internal/daemon/dns.go @@ -0,0 +1,186 @@ +//go:build darwin || linux + +package daemon + +import ( + "fmt" + "net" + "os" + "sync" + "time" +) + +const ( + // maxDNSPacketSize is the maximum UDP packet size we accept. + // DNS can theoretically be up to 65535 bytes, but practically much smaller. + maxDNSPacketSize = 4096 + + // upstreamTimeout is the time we wait for a response from the upstream DNS server. + upstreamTimeout = 5 * time.Second +) + +// DNSRelay is a UDP DNS relay that forwards DNS queries from sandboxed processes +// to a configured upstream DNS server. It operates as a simple packet relay without +// parsing DNS protocol contents. +type DNSRelay struct { + udpConn *net.UDPConn + targetAddr string // upstream DNS server address (host:port) + listenAddr string // address we're listening on + wg sync.WaitGroup + done chan struct{} + debug bool +} + +// NewDNSRelay creates a new DNS relay that listens on listenAddr and forwards +// queries to dnsAddr. The listenAddr will typically be "127.0.0.2:53" (loopback alias). +// The dnsAddr must be in "host:port" format (e.g. "1.1.1.1:53"). +func NewDNSRelay(listenAddr, dnsAddr string, debug bool) (*DNSRelay, error) { + // Validate the upstream DNS address is parseable as host:port. + targetHost, targetPort, err := net.SplitHostPort(dnsAddr) + if err != nil { + return nil, fmt.Errorf("invalid DNS address %q: %w", dnsAddr, err) + } + if targetHost == "" { + return nil, fmt.Errorf("invalid DNS address %q: empty host", dnsAddr) + } + if targetPort == "" { + return nil, fmt.Errorf("invalid DNS address %q: empty port", dnsAddr) + } + + // Resolve and bind the listen address. + udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) + if err != nil { + return nil, fmt.Errorf("failed to resolve listen address %q: %w", listenAddr, err) + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, fmt.Errorf("failed to bind UDP socket on %q: %w", listenAddr, err) + } + + return &DNSRelay{ + udpConn: conn, + targetAddr: dnsAddr, + listenAddr: conn.LocalAddr().String(), + done: make(chan struct{}), + debug: debug, + }, nil +} + +// ListenAddr returns the actual address the relay is listening on. +// This is useful when port 0 was used to get an ephemeral port. +func (d *DNSRelay) ListenAddr() string { + return d.listenAddr +} + +// Start begins the DNS relay loop. It reads incoming UDP packets from the +// listening socket and spawns a goroutine per query to forward it to the +// upstream DNS server and relay the response back. +func (d *DNSRelay) Start() error { + if d.debug { + fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Listening on %s, forwarding to %s\n", d.listenAddr, d.targetAddr) + } + + d.wg.Add(1) + go d.readLoop() + + return nil +} + +// Stop shuts down the DNS relay. It signals the read loop to stop, closes the +// listening socket, and waits for all in-flight queries to complete. +func (d *DNSRelay) Stop() { + close(d.done) + _ = d.udpConn.Close() + d.wg.Wait() + + if d.debug { + fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Stopped\n") + } +} + +// readLoop is the main loop that reads incoming DNS queries from the listening socket. +func (d *DNSRelay) readLoop() { + defer d.wg.Done() + + buf := make([]byte, maxDNSPacketSize) + for { + n, clientAddr, err := d.udpConn.ReadFromUDP(buf) + if err != nil { + select { + case <-d.done: + // Shutting down, expected error from closed socket. + return + default: + fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Read error: %v\n", err) + continue + } + } + + if n == 0 { + continue + } + + // Copy the packet data so the buffer can be reused immediately. + query := make([]byte, n) + copy(query, buf[:n]) + + d.wg.Add(1) + go d.handleQuery(query, clientAddr) + } +} + +// handleQuery forwards a single DNS query to the upstream server and relays +// the response back to the original client. It creates a dedicated UDP connection +// to the upstream server to avoid multiplexing complexity. +func (d *DNSRelay) handleQuery(query []byte, clientAddr *net.UDPAddr) { + defer d.wg.Done() + + if d.debug { + fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Query from %s (%d bytes)\n", clientAddr, len(query)) + } + + // Create a dedicated UDP connection to the upstream DNS server. + upstreamConn, err := net.Dial("udp", d.targetAddr) + if err != nil { + fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Failed to connect to upstream %s: %v\n", d.targetAddr, err) + return + } + defer upstreamConn.Close() //nolint:errcheck // best-effort cleanup of per-query UDP connection + + // Send the query to the upstream server. + if _, err := upstreamConn.Write(query); err != nil { + fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Failed to send query to upstream: %v\n", err) + return + } + + // Wait for the response with a timeout. + if err := upstreamConn.SetReadDeadline(time.Now().Add(upstreamTimeout)); err != nil { + fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Failed to set read deadline: %v\n", err) + return + } + + resp := make([]byte, maxDNSPacketSize) + n, err := upstreamConn.Read(resp) + if err != nil { + if d.debug { + fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Upstream response error: %v\n", err) + } + return + } + + // Send the response back to the original client. + if _, err := d.udpConn.WriteToUDP(resp[:n], clientAddr); err != nil { + // The listening socket may have been closed during shutdown. + select { + case <-d.done: + return + default: + fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Failed to send response to %s: %v\n", clientAddr, err) + } + } + + if d.debug { + fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Response to %s (%d bytes)\n", clientAddr, n) + } +} diff --git a/internal/daemon/dns_test.go b/internal/daemon/dns_test.go new file mode 100644 index 0000000..d52da7b --- /dev/null +++ b/internal/daemon/dns_test.go @@ -0,0 +1,296 @@ +//go:build darwin || linux + +package daemon + +import ( + "bytes" + "net" + "sync" + "testing" + "time" +) + +// startMockDNSServer starts a UDP server that echoes back whatever it receives, +// prefixed with "RESP:" to distinguish responses from queries. +// Returns the server's address and a cleanup function. +func startMockDNSServer(t *testing.T) (string, func()) { + t.Helper() + + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to resolve address: %v", err) + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatalf("Failed to start mock DNS server: %v", err) + } + + done := make(chan struct{}) + go func() { + buf := make([]byte, maxDNSPacketSize) + for { + n, remoteAddr, err := conn.ReadFromUDP(buf) + if err != nil { + select { + case <-done: + return + default: + continue + } + } + // Echo back with "RESP:" prefix. + resp := append([]byte("RESP:"), buf[:n]...) + _, _ = conn.WriteToUDP(resp, remoteAddr) // best-effort in test + } + }() + + cleanup := func() { + close(done) + _ = conn.Close() + } + + return conn.LocalAddr().String(), cleanup +} + +// startSilentDNSServer starts a UDP server that accepts connections but never +// responds, simulating an upstream timeout. +func startSilentDNSServer(t *testing.T) (string, func()) { + t.Helper() + + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to resolve address: %v", err) + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatalf("Failed to start silent DNS server: %v", err) + } + + cleanup := func() { + _ = conn.Close() + } + + return conn.LocalAddr().String(), cleanup +} + +func TestDNSRelay_ForwardPacket(t *testing.T) { + // Start a mock upstream DNS server. + upstreamAddr, cleanupUpstream := startMockDNSServer(t) + defer cleanupUpstream() + + // Create and start the relay. + relay, err := NewDNSRelay("127.0.0.1:0", upstreamAddr, true) + if err != nil { + t.Fatalf("Failed to create DNS relay: %v", err) + } + + if err := relay.Start(); err != nil { + t.Fatalf("Failed to start DNS relay: %v", err) + } + defer relay.Stop() + + // Send a query through the relay. + clientConn, err := net.Dial("udp", relay.ListenAddr()) + if err != nil { + t.Fatalf("Failed to connect to relay: %v", err) + } + defer clientConn.Close() //nolint:errcheck // test cleanup + + query := []byte("test-dns-query") + if _, err := clientConn.Write(query); err != nil { + t.Fatalf("Failed to send query: %v", err) + } + + // Read the response. + if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("Failed to set read deadline: %v", err) + } + + buf := make([]byte, maxDNSPacketSize) + n, err := clientConn.Read(buf) + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + + expected := append([]byte("RESP:"), query...) + if !bytes.Equal(buf[:n], expected) { + t.Errorf("Unexpected response: got %q, want %q", buf[:n], expected) + } +} + +func TestDNSRelay_UpstreamTimeout(t *testing.T) { + // Start a silent upstream server that never responds. + upstreamAddr, cleanupUpstream := startSilentDNSServer(t) + defer cleanupUpstream() + + // Create and start the relay. + relay, err := NewDNSRelay("127.0.0.1:0", upstreamAddr, false) + if err != nil { + t.Fatalf("Failed to create DNS relay: %v", err) + } + + if err := relay.Start(); err != nil { + t.Fatalf("Failed to start DNS relay: %v", err) + } + defer relay.Stop() + + // Send a query through the relay. + clientConn, err := net.Dial("udp", relay.ListenAddr()) + if err != nil { + t.Fatalf("Failed to connect to relay: %v", err) + } + defer clientConn.Close() //nolint:errcheck // test cleanup + + query := []byte("test-dns-timeout") + if _, err := clientConn.Write(query); err != nil { + t.Fatalf("Failed to send query: %v", err) + } + + // The relay should not send back a response because upstream timed out. + // Set a short deadline on the client side; we expect no data. + if err := clientConn.SetReadDeadline(time.Now().Add(6 * time.Second)); err != nil { + t.Fatalf("Failed to set read deadline: %v", err) + } + + buf := make([]byte, maxDNSPacketSize) + _, err = clientConn.Read(buf) + if err == nil { + t.Fatal("Expected timeout error reading from relay, but got a response") + } + + // Verify it was a timeout error. + netErr, ok := err.(net.Error) + if !ok || !netErr.Timeout() { + t.Fatalf("Expected timeout error, got: %v", err) + } +} + +func TestDNSRelay_ConcurrentQueries(t *testing.T) { + // Start a mock upstream DNS server. + upstreamAddr, cleanupUpstream := startMockDNSServer(t) + defer cleanupUpstream() + + // Create and start the relay. + relay, err := NewDNSRelay("127.0.0.1:0", upstreamAddr, false) + if err != nil { + t.Fatalf("Failed to create DNS relay: %v", err) + } + + if err := relay.Start(); err != nil { + t.Fatalf("Failed to start DNS relay: %v", err) + } + defer relay.Stop() + + const numQueries = 20 + var wg sync.WaitGroup + errors := make(chan error, numQueries) + + for i := range numQueries { + wg.Add(1) + go func(id int) { + defer wg.Done() + + clientConn, err := net.Dial("udp", relay.ListenAddr()) + if err != nil { + errors <- err + return + } + defer clientConn.Close() //nolint:errcheck // test cleanup + + query := []byte("concurrent-query-" + string(rune('A'+id))) //nolint:gosec // test uses small range 0-19, no overflow + if _, err := clientConn.Write(query); err != nil { + errors <- err + return + } + + if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + errors <- err + return + } + + buf := make([]byte, maxDNSPacketSize) + n, err := clientConn.Read(buf) + if err != nil { + errors <- err + return + } + + expected := append([]byte("RESP:"), query...) + if !bytes.Equal(buf[:n], expected) { + errors <- &unexpectedResponseError{got: buf[:n], want: expected} + } + }(i) + } + + wg.Wait() + close(errors) + + for err := range errors { + t.Errorf("Concurrent query error: %v", err) + } +} + +func TestDNSRelay_ListenAddr(t *testing.T) { + // Use port 0 to get an ephemeral port. + relay, err := NewDNSRelay("127.0.0.1:0", "1.1.1.1:53", false) + if err != nil { + t.Fatalf("Failed to create DNS relay: %v", err) + } + defer relay.Stop() + + addr := relay.ListenAddr() + if addr == "" { + t.Fatal("ListenAddr returned empty string") + } + + host, port, err := net.SplitHostPort(addr) + if err != nil { + t.Fatalf("ListenAddr returned invalid address %q: %v", addr, err) + } + if host != "127.0.0.1" { + t.Errorf("Expected host 127.0.0.1, got %q", host) + } + if port == "0" { + t.Error("Expected assigned port, got 0") + } +} + +func TestNewDNSRelay_InvalidDNSAddr(t *testing.T) { + tests := []struct { + name string + dnsAddr string + }{ + {"missing port", "1.1.1.1"}, + {"empty string", ""}, + {"empty host", ":53"}, + {"empty port", "1.1.1.1:"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewDNSRelay("127.0.0.1:0", tt.dnsAddr, false) + if err == nil { + t.Errorf("Expected error for DNS address %q, got nil", tt.dnsAddr) + } + }) + } +} + +func TestNewDNSRelay_InvalidListenAddr(t *testing.T) { + _, err := NewDNSRelay("invalid-addr", "1.1.1.1:53", false) + if err == nil { + t.Error("Expected error for invalid listen address, got nil") + } +} + +// unexpectedResponseError is used to report mismatched responses in concurrent tests. +type unexpectedResponseError struct { + got []byte + want []byte +} + +func (e *unexpectedResponseError) Error() string { + return "unexpected response: got " + string(e.got) + ", want " + string(e.want) +} diff --git a/internal/daemon/launchd.go b/internal/daemon/launchd.go new file mode 100644 index 0000000..6d25386 --- /dev/null +++ b/internal/daemon/launchd.go @@ -0,0 +1,554 @@ +//go:build darwin + +package daemon + +import ( + "fmt" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" +) + +const ( + LaunchDaemonLabel = "co.greyhaven.greywall" + LaunchDaemonPlistPath = "/Library/LaunchDaemons/co.greyhaven.greywall.plist" + InstallBinaryPath = "/usr/local/bin/greywall" + InstallLibDir = "/usr/local/lib/greywall" + SandboxUserName = "_greywall" + SandboxUserUID = "399" // System user range on macOS + SandboxGroupName = "_greywall" // Group used for pf routing (same name as user) + SudoersFilePath = "/etc/sudoers.d/greywall" + DefaultSocketPath = "/var/run/greywall.sock" +) + +// Install performs the full LaunchDaemon installation flow: +// 1. Verify running as root +// 2. Create system user _greywall +// 3. Create /usr/local/lib/greywall/ directory and copy tun2socks +// 4. Copy the current binary to /usr/local/bin/greywall +// 5. Generate and write the LaunchDaemon plist +// 6. Set proper permissions, load the daemon, and verify it starts +func Install(currentBinaryPath, tun2socksPath string, debug bool) error { + if os.Getuid() != 0 { + return fmt.Errorf("daemon install must be run as root (use sudo)") + } + + // Step 1: Create system user and group. + if err := createSandboxUser(debug); err != nil { + return fmt.Errorf("failed to create sandbox user: %w", err) + } + + // Step 1b: Install sudoers rule for group-based sandbox-exec. + if err := installSudoersRule(debug); err != nil { + return fmt.Errorf("failed to install sudoers rule: %w", err) + } + + // Step 1c: Add invoking user to _greywall group. + addInvokingUserToGroup(debug) + + // Step 2: Create lib directory and copy tun2socks. + logDebug(debug, "Creating directory %s", InstallLibDir) + if err := os.MkdirAll(InstallLibDir, 0o755); err != nil { //nolint:gosec // system lib directory needs 0755 for daemon access + return fmt.Errorf("failed to create %s: %w", InstallLibDir, err) + } + tun2socksDst := filepath.Join(InstallLibDir, "tun2socks-darwin-"+runtime.GOARCH) + logDebug(debug, "Copying tun2socks to %s", tun2socksDst) + if err := copyFile(tun2socksPath, tun2socksDst, 0o755); err != nil { + return fmt.Errorf("failed to install tun2socks: %w", err) + } + + // Step 3: Copy binary to install path. + if err := os.MkdirAll(filepath.Dir(InstallBinaryPath), 0o755); err != nil { //nolint:gosec // /usr/local/bin needs 0755 + return fmt.Errorf("failed to create %s: %w", filepath.Dir(InstallBinaryPath), err) + } + logDebug(debug, "Copying binary from %s to %s", currentBinaryPath, InstallBinaryPath) + if err := copyFile(currentBinaryPath, InstallBinaryPath, 0o755); err != nil { + return fmt.Errorf("failed to install binary: %w", err) + } + + // Step 4: Generate and write plist. + plist := generatePlist() + logDebug(debug, "Writing plist to %s", LaunchDaemonPlistPath) + if err := os.WriteFile(LaunchDaemonPlistPath, []byte(plist), 0o644); err != nil { //nolint:gosec // LaunchDaemon plist requires 0644 per macOS convention + return fmt.Errorf("failed to write plist: %w", err) + } + + // Step 5: Set ownership to root:wheel. + logDebug(debug, "Setting ownership on %s to root:wheel", LaunchDaemonPlistPath) + if err := runCmd(debug, "chown", "root:wheel", LaunchDaemonPlistPath); err != nil { + return fmt.Errorf("failed to set plist ownership: %w", err) + } + + // Step 6: Load the daemon. + logDebug(debug, "Loading LaunchDaemon") + if err := runCmd(debug, "launchctl", "load", LaunchDaemonPlistPath); err != nil { + return fmt.Errorf("failed to load daemon: %w", err) + } + + // Step 7: Verify the daemon actually started. + running := false + for range 10 { + time.Sleep(500 * time.Millisecond) + if IsRunning() { + running = true + break + } + } + + Logf("Daemon installed successfully.") + Logf(" Plist: %s", LaunchDaemonPlistPath) + Logf(" Binary: %s", InstallBinaryPath) + Logf(" Tun2socks: %s", tun2socksDst) + actualUID := readDsclAttr(SandboxUserName, "UniqueID", true) + actualGID := readDsclAttr(SandboxGroupName, "PrimaryGroupID", false) + Logf(" User: %s (UID %s)", SandboxUserName, actualUID) + Logf(" Group: %s (GID %s, pf routing)", SandboxGroupName, actualGID) + Logf(" Sudoers: %s", SudoersFilePath) + Logf(" Log: /var/log/greywall.log") + + if !running { + Logf(" Status: NOT RUNNING (check /var/log/greywall.log)") + return fmt.Errorf("daemon was loaded but failed to start; check /var/log/greywall.log") + } + + Logf(" Status: running") + return nil +} + +// Uninstall performs the full LaunchDaemon uninstallation flow. It attempts +// every cleanup step even if individual steps fail, collecting errors along +// the way. +func Uninstall(debug bool) error { + if os.Getuid() != 0 { + return fmt.Errorf("daemon uninstall must be run as root (use sudo)") + } + + var errs []string + + // Step 1: Unload daemon (best effort). + logDebug(debug, "Unloading LaunchDaemon") + if err := runCmd(debug, "launchctl", "unload", LaunchDaemonPlistPath); err != nil { + errs = append(errs, fmt.Sprintf("unload daemon: %v", err)) + } + + // Step 2: Remove plist file. + logDebug(debug, "Removing plist %s", LaunchDaemonPlistPath) + if err := os.Remove(LaunchDaemonPlistPath); err != nil && !os.IsNotExist(err) { + errs = append(errs, fmt.Sprintf("remove plist: %v", err)) + } + + // Step 3: Remove lib directory. + logDebug(debug, "Removing directory %s", InstallLibDir) + if err := os.RemoveAll(InstallLibDir); err != nil { + errs = append(errs, fmt.Sprintf("remove lib dir: %v", err)) + } + + // Step 4: Remove installed binary, but only if it differs from the + // currently running executable. + currentExe, exeErr := os.Executable() + if exeErr != nil { + currentExe = "" + } + resolvedCurrent, _ := filepath.EvalSymlinks(currentExe) + resolvedInstall, _ := filepath.EvalSymlinks(InstallBinaryPath) + if resolvedCurrent != resolvedInstall { + logDebug(debug, "Removing binary %s", InstallBinaryPath) + if err := os.Remove(InstallBinaryPath); err != nil && !os.IsNotExist(err) { + errs = append(errs, fmt.Sprintf("remove binary: %v", err)) + } + } else { + logDebug(debug, "Skipping binary removal (currently running from %s)", InstallBinaryPath) + } + + // Step 5: Remove system user and group. + if err := removeSandboxUser(debug); err != nil { + errs = append(errs, fmt.Sprintf("remove sandbox user: %v", err)) + } + + // Step 6: Remove socket file if it exists. + logDebug(debug, "Removing socket %s", DefaultSocketPath) + if err := os.Remove(DefaultSocketPath); err != nil && !os.IsNotExist(err) { + errs = append(errs, fmt.Sprintf("remove socket: %v", err)) + } + + // Step 6b: Remove sudoers file. + logDebug(debug, "Removing sudoers file %s", SudoersFilePath) + if err := os.Remove(SudoersFilePath); err != nil && !os.IsNotExist(err) { + errs = append(errs, fmt.Sprintf("remove sudoers file: %v", err)) + } + + // Step 7: Remove pf anchor lines from /etc/pf.conf. + if err := removeAnchorFromPFConf(debug); err != nil { + errs = append(errs, fmt.Sprintf("remove pf anchor: %v", err)) + } + + if len(errs) > 0 { + Logf("Uninstall completed with warnings:") + for _, e := range errs { + Logf(" - %s", e) + } + return nil // partial cleanup is not a fatal error + } + + Logf("Daemon uninstalled successfully.") + return nil +} + +// generatePlist returns the LaunchDaemon plist XML content. +func generatePlist() string { + return ` + + + + Label + ` + LaunchDaemonLabel + ` + ProgramArguments + + ` + InstallBinaryPath + ` + daemon + run + + RunAtLoad + KeepAlive + StandardOutPath + /var/log/greywall.log + StandardErrorPath + /var/log/greywall.log + + +` +} + +// IsInstalled returns true if the LaunchDaemon plist file exists. +func IsInstalled() bool { + _, err := os.Stat(LaunchDaemonPlistPath) + return err == nil +} + +// IsRunning returns true if the daemon is currently running. It first tries +// connecting to the Unix socket (works without root), then falls back to +// launchctl print which can inspect the system domain without root. +func IsRunning() bool { + // Primary check: try to connect to the daemon socket. This proves the + // daemon is actually running and accepting connections. + conn, err := net.DialTimeout("unix", DefaultSocketPath, 2*time.Second) + if err == nil { + _ = conn.Close() + return true + } + + // Fallback: launchctl print system/