feat: switch macOS daemon from user-based to group-based pf routing

Sandboxed commands previously ran as `sudo -u _greywall`, breaking user
identity (home dir, SSH keys, git config). Now uses `sudo -u #<uid> -g _greywall`
so the process keeps the real user's identity while pf matches on EGID
for traffic routing.

Key changes:
- pf rules use `group <GID>` instead of `user _greywall`
- GID resolved dynamically at daemon startup (not hardcoded, since macOS
  system groups like com.apple.access_ssh may claim preferred IDs)
- Sudoers rule installed at /etc/sudoers.d/greywall (validated with visudo)
- Invoking user added to _greywall group via dscl (not dseditgroup, which
  clobbers group attributes)
- tun2socks device discovery scans both stdout and stderr (fixes 10s
  timeout caused by STACK message going to stdout)
- Always-on daemon logging for session create/destroy events
This commit is contained in:
2026-02-25 19:20:01 -06:00
parent 4ea4592d75
commit cfe29d2c0b
15 changed files with 3866 additions and 18 deletions

229
cmd/greywall/daemon.go Normal file
View File

@@ -0,0 +1,229 @@
package main
import (
"bufio"
"fmt"
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"syscall"
"github.com/spf13/cobra"
"gitea.app.monadical.io/monadical/greywall/internal/daemon"
"gitea.app.monadical.io/monadical/greywall/internal/sandbox"
)
// newDaemonCmd creates the daemon subcommand tree:
//
//	greywall daemon
//	  install   - Install the LaunchDaemon (requires root)
//	  uninstall - Uninstall the LaunchDaemon (requires root)
//	  run       - Run the daemon (called by LaunchDaemon plist)
//	  status    - Show daemon status
//
// The parent command carries no RunE of its own; it exists only to group
// the four subcommands under "greywall daemon".
func newDaemonCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "daemon",
		Short: "Manage the greywall background daemon",
		Long: `Manage the greywall LaunchDaemon for transparent network sandboxing on macOS.
The daemon runs as a system service and manages the tun2socks tunnel, DNS relay,
and pf rules that enable transparent proxy routing for sandboxed processes.
Commands:
sudo greywall daemon install Install and start the daemon
sudo greywall daemon uninstall Stop and remove the daemon
greywall daemon status Check daemon status
greywall daemon run Run the daemon (used by LaunchDaemon)`,
	}
	cmd.AddCommand(
		newDaemonInstallCmd(),
		newDaemonUninstallCmd(),
		newDaemonRunCmd(),
		newDaemonStatusCmd(),
	)
	return cmd
}
// newDaemonInstallCmd creates the "daemon install" subcommand.
//
// The handler resolves the real (symlink-free) path of the running binary,
// extracts the embedded tun2socks helper to a temporary file, and delegates
// the actual installation to daemon.Install. The temp file is removed once
// the command returns (Install is expected to have copied it into place).
func newDaemonInstallCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "install",
		Short: "Install the greywall LaunchDaemon (requires root)",
		Long: `Install greywall as a macOS LaunchDaemon. This command:
1. Creates a system user (_greywall) for sandboxed process isolation
2. Copies the greywall binary to /usr/local/bin/greywall
3. Extracts and installs the tun2socks binary
4. Installs a LaunchDaemon plist for automatic startup
5. Loads and starts the daemon
Requires root privileges: sudo greywall daemon install`,
		RunE: func(cmd *cobra.Command, args []string) error {
			exePath, err := os.Executable()
			if err != nil {
				return fmt.Errorf("failed to determine executable path: %w", err)
			}
			// Resolve symlinks so the copy installed by daemon.Install is
			// the real binary, not a link that may later dangle.
			exePath, err = filepath.EvalSymlinks(exePath)
			if err != nil {
				return fmt.Errorf("failed to resolve executable path: %w", err)
			}
			// Extract embedded tun2socks binary to a temp file.
			tun2socksPath, err := sandbox.ExtractTun2Socks()
			if err != nil {
				return fmt.Errorf("failed to extract tun2socks: %w", err)
			}
			defer os.Remove(tun2socksPath) //nolint:errcheck // temp file cleanup
			// NOTE(review): `debug` is presumably a package-level CLI flag
			// declared elsewhere in this package — confirm.
			if err := daemon.Install(exePath, tun2socksPath, debug); err != nil {
				return err
			}
			fmt.Println()
			fmt.Println("To check status: greywall daemon status")
			fmt.Println("To uninstall: sudo greywall daemon uninstall")
			return nil
		},
	}
}
// newDaemonUninstallCmd creates the "daemon uninstall" subcommand.
//
// Unless --force is given, the handler prints everything that will be
// removed and asks for interactive confirmation; anything other than
// "y"/"yes" (case-insensitive) aborts without error.
func newDaemonUninstallCmd() *cobra.Command {
	var force bool
	cmd := &cobra.Command{
		Use:   "uninstall",
		Short: "Uninstall the greywall LaunchDaemon (requires root)",
		Long: `Uninstall the greywall LaunchDaemon. This command:
1. Stops and unloads the daemon
2. Removes the LaunchDaemon plist
3. Removes installed files
4. Removes the _greywall system user and group
Requires root privileges: sudo greywall daemon uninstall`,
		RunE: func(cmd *cobra.Command, args []string) error {
			if !force {
				fmt.Println("The following will be removed:")
				fmt.Printf(" - LaunchDaemon plist: %s\n", daemon.LaunchDaemonPlistPath)
				fmt.Printf(" - Binary: %s\n", daemon.InstallBinaryPath)
				fmt.Printf(" - Lib directory: %s\n", daemon.InstallLibDir)
				fmt.Printf(" - Socket: %s\n", daemon.DefaultSocketPath)
				fmt.Printf(" - Sudoers file: %s\n", daemon.SudoersFilePath)
				fmt.Printf(" - System user/group: %s\n", daemon.SandboxUserName)
				fmt.Println()
				fmt.Print("Proceed with uninstall? [y/N] ")
				reader := bufio.NewReader(os.Stdin)
				// Read error deliberately ignored: an empty/failed read
				// yields "" which falls through to the cancel branch below.
				answer, _ := reader.ReadString('\n')
				answer = strings.TrimSpace(strings.ToLower(answer))
				if answer != "y" && answer != "yes" {
					fmt.Println("Uninstall cancelled.")
					return nil
				}
			}
			if err := daemon.Uninstall(debug); err != nil {
				return err
			}
			fmt.Println()
			fmt.Println("The greywall daemon has been uninstalled.")
			return nil
		},
	}
	cmd.Flags().BoolVarP(&force, "force", "f", false, "Skip confirmation prompt")
	return cmd
}
// newDaemonRunCmd creates the "daemon run" subcommand. This is invoked by
// the LaunchDaemon plist and should not normally be called manually, so it
// is hidden from help output.
func newDaemonRunCmd() *cobra.Command {
	c := &cobra.Command{
		Use:    "run",
		Short:  "Run the daemon process (called by LaunchDaemon)",
		Hidden: true, // Only launchd should invoke this.
		RunE:   runDaemon,
	}
	return c
}
// newDaemonStatusCmd creates the "daemon status" subcommand.
//
// The handler prints installation and runtime state plus the well-known
// paths, then appends a remediation hint when the daemon is missing or
// installed-but-stopped. It never returns an error.
func newDaemonStatusCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "status",
		Short: "Show the daemon status",
		Long:  `Check whether the greywall daemon is installed and running. Does not require root.`,
		RunE: func(cmd *cobra.Command, args []string) error {
			installed := daemon.IsInstalled()
			running := daemon.IsRunning()
			fmt.Printf("Greywall daemon status:\n")
			fmt.Printf(" Installed: %s\n", boolStatus(installed))
			fmt.Printf(" Running: %s\n", boolStatus(running))
			fmt.Printf(" Plist: %s\n", daemon.LaunchDaemonPlistPath)
			fmt.Printf(" Binary: %s\n", daemon.InstallBinaryPath)
			fmt.Printf(" User: %s\n", daemon.SandboxUserName)
			fmt.Printf(" Group: %s (pf routing)\n", daemon.SandboxGroupName)
			fmt.Printf(" Sudoers: %s\n", daemon.SudoersFilePath)
			fmt.Printf(" Socket: %s\n", daemon.DefaultSocketPath)
			if !installed {
				fmt.Println()
				fmt.Println("The daemon is not installed. Run: sudo greywall daemon install")
			} else if !running {
				fmt.Println()
				fmt.Println("The daemon is installed but not running.")
				fmt.Printf("Check logs: cat /var/log/greywall.log\n")
				fmt.Printf("Start it: sudo launchctl load %s\n", daemon.LaunchDaemonPlistPath)
			}
			return nil
		},
	}
}
// runDaemon is the main entry point for the daemon process. It starts the
// Unix socket server and blocks until a termination signal is received.
// CLI clients connect to the server to request sessions (which create
// utun tunnels, DNS relays, and pf rules on demand).
//
// Returns an error if tun2socks is missing, the server fails to start, or
// shutdown reports an error; otherwise nil after a clean SIGINT/SIGTERM stop.
func runDaemon(cmd *cobra.Command, args []string) error {
	// The arch-specific tun2socks helper is placed here by `daemon install`.
	tun2socksPath := filepath.Join(daemon.InstallLibDir, "tun2socks-darwin-"+runtime.GOARCH)
	if _, err := os.Stat(tun2socksPath); err != nil {
		return fmt.Errorf("tun2socks binary not found at %s (run 'sudo greywall daemon install' first)", tun2socksPath)
	}
	daemon.Logf("Starting daemon (tun2socks=%s, socket=%s)", tun2socksPath, daemon.DefaultSocketPath)
	srv := daemon.NewServer(daemon.DefaultSocketPath, tun2socksPath, debug)
	if err := srv.Start(); err != nil {
		return fmt.Errorf("failed to start daemon server: %w", err)
	}
	daemon.Logf("Daemon started, listening on %s", daemon.DefaultSocketPath)
	// Wait for termination signal. Buffered channel of 1 so the signal is
	// not dropped if it arrives before we block on the receive.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	sig := <-sigCh
	daemon.Logf("Received signal %s, shutting down", sig)
	if err := srv.Stop(); err != nil {
		daemon.Logf("Shutdown error: %v", err)
		return err
	}
	daemon.Logf("Daemon stopped")
	return nil
}
// boolStatus renders a boolean as the human-readable string "yes" or "no"
// for use in status output.
func boolStatus(b bool) string {
	if !b {
		return "no"
	}
	return "yes"
}

144
internal/daemon/client.go Normal file
View File

@@ -0,0 +1,144 @@
package daemon
import (
"encoding/json"
"fmt"
"net"
"os"
"time"
)
const (
	// clientDialTimeout is the maximum time to wait when connecting to the daemon.
	clientDialTimeout = 5 * time.Second
	// clientReadTimeout is the maximum time to wait for a response from the daemon.
	// Session creation can be slow (tunnel + pf setup), hence the generous value.
	clientReadTimeout = 30 * time.Second
)
// Client communicates with the greywall daemon over a Unix socket using
// newline-delimited JSON. A new connection is dialed per request; the
// zero value is not usable — construct with NewClient.
type Client struct {
	socketPath string // Unix socket the daemon listens on
	debug      bool   // when true, request/response tracing goes to stderr
}
// NewClient creates a new daemon client that connects to the given Unix socket path.
// When debug is true, per-request tracing is written to stderr.
func NewClient(socketPath string, debug bool) *Client {
	c := Client{socketPath: socketPath, debug: debug}
	return &c
}
// CreateSession asks the daemon to create a new sandbox session with the given
// proxy URL and optional DNS address. Returns the session info on success.
//
// On an application-level failure (daemon answered with OK=false) the
// Response is returned alongside the error so callers can inspect the
// daemon's error details; on a transport failure the Response is nil.
func (c *Client) CreateSession(proxyURL, dnsAddr string) (*Response, error) {
	req := Request{
		Action:   "create_session",
		ProxyURL: proxyURL,
		DNSAddr:  dnsAddr,
	}
	resp, err := c.sendRequest(req)
	if err != nil {
		return nil, fmt.Errorf("create session request failed: %w", err)
	}
	if !resp.OK {
		return resp, fmt.Errorf("create session failed: %s", resp.Error)
	}
	return resp, nil
}
// DestroySession asks the daemon to tear down the session with the given ID.
// Returns nil only when the daemon acknowledged the teardown (OK=true).
func (c *Client) DestroySession(sessionID string) error {
	resp, err := c.sendRequest(Request{
		Action:    "destroy_session",
		SessionID: sessionID,
	})
	if err != nil {
		return fmt.Errorf("destroy session request failed: %w", err)
	}
	if resp.OK {
		return nil
	}
	return fmt.Errorf("destroy session failed: %s", resp.Error)
}
// Status queries the daemon for its current status.
//
// The returned error distinguishes two failure modes: a transport failure
// (could not reach the daemon or exchange messages, wrapped with %w) and an
// application-level rejection (the daemon answered but set OK=false). The
// previous implementation used the identical message "status request failed"
// for both, making logs ambiguous. As in CreateSession, the Response is
// returned alongside an application-level error so callers can inspect it.
func (c *Client) Status() (*Response, error) {
	req := Request{
		Action: "status",
	}
	resp, err := c.sendRequest(req)
	if err != nil {
		return nil, fmt.Errorf("status request failed: %w", err)
	}
	if !resp.OK {
		// The daemon responded but flagged an error; surface its message.
		return resp, fmt.Errorf("status request rejected by daemon: %s", resp.Error)
	}
	return resp, nil
}
// IsRunning checks whether the daemon is reachable by attempting to connect
// to the Unix socket. Returns true if the connection succeeds.
func (c *Client) IsRunning() bool {
	conn, err := net.DialTimeout("unix", c.socketPath, clientDialTimeout)
	if err != nil {
		return false
	}
	// Reachability is all we wanted; drop the probe connection.
	_ = conn.Close()
	return true
}
// sendRequest connects to the daemon Unix socket, sends a JSON-encoded request,
// and reads back a JSON-encoded response.
//
// A single overall deadline (clientReadTimeout) now bounds both the request
// write and the response read: previously only a read deadline was set, so a
// write to a hung or backlogged daemon could block indefinitely.
func (c *Client) sendRequest(req Request) (*Response, error) {
	c.logDebug("Connecting to daemon at %s", c.socketPath)
	conn, err := net.DialTimeout("unix", c.socketPath, clientDialTimeout)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to daemon at %s: %w", c.socketPath, err)
	}
	defer conn.Close() //nolint:errcheck // best-effort close on request completion
	// Bound the whole exchange (write + read) with one deadline.
	if err := conn.SetDeadline(time.Now().Add(clientReadTimeout)); err != nil {
		return nil, fmt.Errorf("failed to set deadline: %w", err)
	}
	// Send the request as newline-delimited JSON (Encode appends '\n').
	encoder := json.NewEncoder(conn)
	if err := encoder.Encode(req); err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	c.logDebug("Sent request: action=%s", req.Action)
	// Read the response.
	decoder := json.NewDecoder(conn)
	var resp Response
	if err := decoder.Decode(&resp); err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	c.logDebug("Received response: ok=%v", resp.OK)
	return &resp, nil
}
// logDebug writes a debug message to stderr with the [greywall:daemon] prefix.
// It is a no-op unless the client was created with debug enabled.
func (c *Client) logDebug(format string, args ...interface{}) {
	if !c.debug {
		return
	}
	fmt.Fprintf(os.Stderr, "[greywall:daemon] "+format+"\n", args...)
}

186
internal/daemon/dns.go Normal file
View File

@@ -0,0 +1,186 @@
//go:build darwin || linux
package daemon
import (
"fmt"
"net"
"os"
"sync"
"time"
)
const (
	// maxDNSPacketSize is the maximum UDP packet size we accept.
	// DNS can theoretically be up to 65535 bytes, but practically much smaller;
	// responses larger than this buffer are truncated by the Read below.
	maxDNSPacketSize = 4096
	// upstreamTimeout is the time we wait for a response from the upstream DNS server.
	upstreamTimeout = 5 * time.Second
)
// DNSRelay is a UDP DNS relay that forwards DNS queries from sandboxed processes
// to a configured upstream DNS server. It operates as a simple packet relay without
// parsing DNS protocol contents.
//
// Construct with NewDNSRelay; the zero value is not usable. Contains a
// WaitGroup, so values must not be copied after first use.
type DNSRelay struct {
	udpConn    *net.UDPConn
	targetAddr string // upstream DNS server address (host:port)
	listenAddr string // address we're listening on (resolved, real port)
	wg         sync.WaitGroup // tracks readLoop plus every in-flight query
	done       chan struct{}  // closed by Stop to signal shutdown
	debug      bool
}
// NewDNSRelay creates a new DNS relay that listens on listenAddr and forwards
// queries to dnsAddr. The listenAddr will typically be "127.0.0.2:53" (loopback alias).
// The dnsAddr must be in "host:port" format (e.g. "1.1.1.1:53").
func NewDNSRelay(listenAddr, dnsAddr string, debug bool) (*DNSRelay, error) {
	// Reject upstream addresses that are not a complete host:port pair.
	host, port, err := net.SplitHostPort(dnsAddr)
	switch {
	case err != nil:
		return nil, fmt.Errorf("invalid DNS address %q: %w", dnsAddr, err)
	case host == "":
		return nil, fmt.Errorf("invalid DNS address %q: empty host", dnsAddr)
	case port == "":
		return nil, fmt.Errorf("invalid DNS address %q: empty port", dnsAddr)
	}
	// Resolve and bind the listen address.
	laddr, err := net.ResolveUDPAddr("udp", listenAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve listen address %q: %w", listenAddr, err)
	}
	sock, err := net.ListenUDP("udp", laddr)
	if err != nil {
		return nil, fmt.Errorf("failed to bind UDP socket on %q: %w", listenAddr, err)
	}
	relay := &DNSRelay{
		udpConn:    sock,
		targetAddr: dnsAddr,
		// Record the resolved address so callers see the real port even
		// when listenAddr requested port 0.
		listenAddr: sock.LocalAddr().String(),
		done:       make(chan struct{}),
		debug:      debug,
	}
	return relay, nil
}
// ListenAddr returns the actual address the relay is listening on.
// This is useful when port 0 was used to get an ephemeral port, since the
// stored value comes from the bound socket's LocalAddr.
func (d *DNSRelay) ListenAddr() string {
	return d.listenAddr
}
// Start begins the DNS relay loop. It reads incoming UDP packets from the
// listening socket and spawns a goroutine per query to forward it to the
// upstream DNS server and relay the response back. Always returns nil.
func (d *DNSRelay) Start() error {
	d.wg.Add(1) // balanced by the deferred Done inside readLoop
	go d.readLoop()
	if d.debug {
		fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Listening on %s, forwarding to %s\n", d.listenAddr, d.targetAddr)
	}
	return nil
}
// Stop shuts down the DNS relay. It signals the read loop to stop, closes the
// listening socket, and waits for all in-flight queries to complete.
//
// Ordering matters: done is closed BEFORE the socket, so that readLoop can
// distinguish the expected "closed socket" error during shutdown from a real
// read failure. Stop must be called at most once (a second close of done
// would panic).
func (d *DNSRelay) Stop() {
	close(d.done)
	_ = d.udpConn.Close()
	d.wg.Wait()
	if d.debug {
		fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Stopped\n")
	}
}
// readLoop is the main loop that reads incoming DNS queries from the listening socket.
// It runs until Stop closes the socket; each accepted packet is handed to a
// handleQuery goroutine that is registered in the WaitGroup BEFORE being
// spawned, so Stop's wg.Wait reliably covers all in-flight queries.
func (d *DNSRelay) readLoop() {
	defer d.wg.Done()
	buf := make([]byte, maxDNSPacketSize)
	for {
		n, clientAddr, err := d.udpConn.ReadFromUDP(buf)
		if err != nil {
			// Non-blocking check of done: if Stop ran, the error is just
			// the closed socket; otherwise log and keep serving.
			select {
			case <-d.done:
				// Shutting down, expected error from closed socket.
				return
			default:
				fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Read error: %v\n", err)
				continue
			}
		}
		if n == 0 {
			continue
		}
		// Copy the packet data so the buffer can be reused immediately.
		query := make([]byte, n)
		copy(query, buf[:n])
		d.wg.Add(1)
		go d.handleQuery(query, clientAddr)
	}
}
// handleQuery forwards a single DNS query to the upstream server and relays
// the response back to the original client. It creates a dedicated UDP connection
// to the upstream server to avoid multiplexing complexity.
//
// Fix: on a failed WriteToUDP outside of shutdown we now return immediately;
// previously control fell through and the "Response to ..." debug line was
// printed even though the send had failed.
func (d *DNSRelay) handleQuery(query []byte, clientAddr *net.UDPAddr) {
	defer d.wg.Done()
	if d.debug {
		fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Query from %s (%d bytes)\n", clientAddr, len(query))
	}
	// Create a dedicated UDP connection to the upstream DNS server.
	upstreamConn, err := net.Dial("udp", d.targetAddr)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Failed to connect to upstream %s: %v\n", d.targetAddr, err)
		return
	}
	defer upstreamConn.Close() //nolint:errcheck // best-effort cleanup of per-query UDP connection
	// Send the query to the upstream server.
	if _, err := upstreamConn.Write(query); err != nil {
		fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Failed to send query to upstream: %v\n", err)
		return
	}
	// Wait for the response with a timeout.
	if err := upstreamConn.SetReadDeadline(time.Now().Add(upstreamTimeout)); err != nil {
		fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Failed to set read deadline: %v\n", err)
		return
	}
	resp := make([]byte, maxDNSPacketSize)
	n, err := upstreamConn.Read(resp)
	if err != nil {
		if d.debug {
			fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Upstream response error: %v\n", err)
		}
		return
	}
	// Send the response back to the original client.
	if _, err := d.udpConn.WriteToUDP(resp[:n], clientAddr); err != nil {
		// The listening socket may have been closed during shutdown.
		select {
		case <-d.done:
		default:
			fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Failed to send response to %s: %v\n", clientAddr, err)
		}
		return
	}
	if d.debug {
		fmt.Fprintf(os.Stderr, "[greywall:dns-relay] Response to %s (%d bytes)\n", clientAddr, n)
	}
}

296
internal/daemon/dns_test.go Normal file
View File

@@ -0,0 +1,296 @@
//go:build darwin || linux
package daemon
import (
"bytes"
"net"
"sync"
"testing"
"time"
)
// startMockDNSServer starts a UDP server that echoes back whatever it receives,
// prefixed with "RESP:" to distinguish responses from queries.
// Returns the server's address and a cleanup function.
//
// The serving goroutine exits when cleanup closes the socket: the read
// fails, and the done channel (closed first) tells it that the failure is
// intentional shutdown rather than a transient error.
func startMockDNSServer(t *testing.T) (string, func()) {
	t.Helper()
	addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to resolve address: %v", err)
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		t.Fatalf("Failed to start mock DNS server: %v", err)
	}
	done := make(chan struct{})
	go func() {
		buf := make([]byte, maxDNSPacketSize)
		for {
			n, remoteAddr, err := conn.ReadFromUDP(buf)
			if err != nil {
				select {
				case <-done:
					return
				default:
					continue
				}
			}
			// Echo back with "RESP:" prefix.
			resp := append([]byte("RESP:"), buf[:n]...)
			_, _ = conn.WriteToUDP(resp, remoteAddr) // best-effort in test
		}
	}()
	cleanup := func() {
		close(done)
		_ = conn.Close()
	}
	return conn.LocalAddr().String(), cleanup
}
// startSilentDNSServer starts a UDP server that accepts packets but never
// responds, simulating an upstream timeout.
// Returns the server's address and a cleanup function.
func startSilentDNSServer(t *testing.T) (string, func()) {
	t.Helper()
	laddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
	conn, err := net.ListenUDP("udp", laddr)
	if err != nil {
		t.Fatalf("Failed to start silent DNS server: %v", err)
	}
	// No reader goroutine: packets queue in the kernel and are never answered.
	return conn.LocalAddr().String(), func() { _ = conn.Close() }
}
// TestDNSRelay_ForwardPacket verifies the happy path: a packet sent to the
// relay reaches the mock upstream and its "RESP:"-prefixed echo comes back
// to the original client unmodified.
func TestDNSRelay_ForwardPacket(t *testing.T) {
	// Start a mock upstream DNS server.
	upstreamAddr, cleanupUpstream := startMockDNSServer(t)
	defer cleanupUpstream()
	// Create and start the relay (debug on for extra trace in -v runs).
	relay, err := NewDNSRelay("127.0.0.1:0", upstreamAddr, true)
	if err != nil {
		t.Fatalf("Failed to create DNS relay: %v", err)
	}
	if err := relay.Start(); err != nil {
		t.Fatalf("Failed to start DNS relay: %v", err)
	}
	defer relay.Stop()
	// Send a query through the relay.
	clientConn, err := net.Dial("udp", relay.ListenAddr())
	if err != nil {
		t.Fatalf("Failed to connect to relay: %v", err)
	}
	defer clientConn.Close() //nolint:errcheck // test cleanup
	query := []byte("test-dns-query")
	if _, err := clientConn.Write(query); err != nil {
		t.Fatalf("Failed to send query: %v", err)
	}
	// Read the response.
	if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("Failed to set read deadline: %v", err)
	}
	buf := make([]byte, maxDNSPacketSize)
	n, err := clientConn.Read(buf)
	if err != nil {
		t.Fatalf("Failed to read response: %v", err)
	}
	expected := append([]byte("RESP:"), query...)
	if !bytes.Equal(buf[:n], expected) {
		t.Errorf("Unexpected response: got %q, want %q", buf[:n], expected)
	}
}
// TestDNSRelay_UpstreamTimeout verifies that when the upstream never answers,
// the relay drops the query silently and the client's read times out.
// NOTE(review): this test intentionally waits past the relay's 5s
// upstreamTimeout, so it takes ~6s of wall time.
func TestDNSRelay_UpstreamTimeout(t *testing.T) {
	// Start a silent upstream server that never responds.
	upstreamAddr, cleanupUpstream := startSilentDNSServer(t)
	defer cleanupUpstream()
	// Create and start the relay.
	relay, err := NewDNSRelay("127.0.0.1:0", upstreamAddr, false)
	if err != nil {
		t.Fatalf("Failed to create DNS relay: %v", err)
	}
	if err := relay.Start(); err != nil {
		t.Fatalf("Failed to start DNS relay: %v", err)
	}
	defer relay.Stop()
	// Send a query through the relay.
	clientConn, err := net.Dial("udp", relay.ListenAddr())
	if err != nil {
		t.Fatalf("Failed to connect to relay: %v", err)
	}
	defer clientConn.Close() //nolint:errcheck // test cleanup
	query := []byte("test-dns-timeout")
	if _, err := clientConn.Write(query); err != nil {
		t.Fatalf("Failed to send query: %v", err)
	}
	// The relay should not send back a response because upstream timed out.
	// Deadline of 6s > the relay's 5s upstreamTimeout; we expect no data.
	if err := clientConn.SetReadDeadline(time.Now().Add(6 * time.Second)); err != nil {
		t.Fatalf("Failed to set read deadline: %v", err)
	}
	buf := make([]byte, maxDNSPacketSize)
	_, err = clientConn.Read(buf)
	if err == nil {
		t.Fatal("Expected timeout error reading from relay, but got a response")
	}
	// Verify it was a timeout error, not some other failure.
	netErr, ok := err.(net.Error)
	if !ok || !netErr.Timeout() {
		t.Fatalf("Expected timeout error, got: %v", err)
	}
}
// TestDNSRelay_ConcurrentQueries fires 20 distinct queries in parallel and
// checks each client receives exactly its own echoed response — exercising
// the per-query goroutine fan-out in the relay.
func TestDNSRelay_ConcurrentQueries(t *testing.T) {
	// Start a mock upstream DNS server.
	upstreamAddr, cleanupUpstream := startMockDNSServer(t)
	defer cleanupUpstream()
	// Create and start the relay.
	relay, err := NewDNSRelay("127.0.0.1:0", upstreamAddr, false)
	if err != nil {
		t.Fatalf("Failed to create DNS relay: %v", err)
	}
	if err := relay.Start(); err != nil {
		t.Fatalf("Failed to start DNS relay: %v", err)
	}
	defer relay.Stop()
	const numQueries = 20
	var wg sync.WaitGroup
	// Buffered so workers never block reporting; drained after wg.Wait.
	errors := make(chan error, numQueries)
	for i := range numQueries {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			clientConn, err := net.Dial("udp", relay.ListenAddr())
			if err != nil {
				errors <- err
				return
			}
			defer clientConn.Close() //nolint:errcheck // test cleanup
			// Unique payload per worker so cross-delivery would be detected.
			query := []byte("concurrent-query-" + string(rune('A'+id))) //nolint:gosec // test uses small range 0-19, no overflow
			if _, err := clientConn.Write(query); err != nil {
				errors <- err
				return
			}
			if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
				errors <- err
				return
			}
			buf := make([]byte, maxDNSPacketSize)
			n, err := clientConn.Read(buf)
			if err != nil {
				errors <- err
				return
			}
			expected := append([]byte("RESP:"), query...)
			if !bytes.Equal(buf[:n], expected) {
				errors <- &unexpectedResponseError{got: buf[:n], want: expected}
			}
		}(i)
	}
	wg.Wait()
	close(errors)
	for err := range errors {
		t.Errorf("Concurrent query error: %v", err)
	}
}
// TestDNSRelay_ListenAddr verifies that binding with port 0 yields a concrete
// ephemeral port in ListenAddr. No traffic is sent, so the real upstream
// address is never contacted. Stop without Start is safe here: the WaitGroup
// is empty and only the socket gets closed.
func TestDNSRelay_ListenAddr(t *testing.T) {
	// Use port 0 to get an ephemeral port.
	relay, err := NewDNSRelay("127.0.0.1:0", "1.1.1.1:53", false)
	if err != nil {
		t.Fatalf("Failed to create DNS relay: %v", err)
	}
	defer relay.Stop()
	addr := relay.ListenAddr()
	if addr == "" {
		t.Fatal("ListenAddr returned empty string")
	}
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		t.Fatalf("ListenAddr returned invalid address %q: %v", addr, err)
	}
	if host != "127.0.0.1" {
		t.Errorf("Expected host 127.0.0.1, got %q", host)
	}
	if port == "0" {
		t.Error("Expected assigned port, got 0")
	}
}
// TestNewDNSRelay_InvalidDNSAddr checks that every malformed upstream address
// (missing, empty, or partial host:port) is rejected at construction time.
func TestNewDNSRelay_InvalidDNSAddr(t *testing.T) {
	cases := []struct {
		name    string
		dnsAddr string
	}{
		{name: "missing port", dnsAddr: "1.1.1.1"},
		{name: "empty string", dnsAddr: ""},
		{name: "empty host", dnsAddr: ":53"},
		{name: "empty port", dnsAddr: "1.1.1.1:"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := NewDNSRelay("127.0.0.1:0", tc.dnsAddr, false); err == nil {
				t.Errorf("Expected error for DNS address %q, got nil", tc.dnsAddr)
			}
		})
	}
}
// TestNewDNSRelay_InvalidListenAddr checks that an unparseable listen address
// is rejected at construction time.
func TestNewDNSRelay_InvalidListenAddr(t *testing.T) {
	if _, err := NewDNSRelay("invalid-addr", "1.1.1.1:53", false); err == nil {
		t.Error("Expected error for invalid listen address, got nil")
	}
}
// unexpectedResponseError is used to report mismatched responses in concurrent tests.
type unexpectedResponseError struct {
	got  []byte // bytes actually received from the relay
	want []byte // bytes the test expected
}

// Error formats the mismatch as "unexpected response: got <got>, want <want>".
func (e *unexpectedResponseError) Error() string {
	msg := "unexpected response: got " + string(e.got)
	msg += ", want " + string(e.want)
	return msg
}

554
internal/daemon/launchd.go Normal file
View File

@@ -0,0 +1,554 @@
//go:build darwin
package daemon
import (
"fmt"
"io"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)
const (
	// LaunchDaemonLabel is the reverse-DNS identifier used by launchd.
	LaunchDaemonLabel = "co.greyhaven.greywall"
	// LaunchDaemonPlistPath is where the daemon's plist is installed.
	LaunchDaemonPlistPath = "/Library/LaunchDaemons/co.greyhaven.greywall.plist"
	// InstallBinaryPath is the installed location of the greywall binary.
	InstallBinaryPath = "/usr/local/bin/greywall"
	// InstallLibDir holds helper binaries such as the per-arch tun2socks.
	InstallLibDir = "/usr/local/lib/greywall"
	SandboxUserName = "_greywall"
	SandboxUserUID = "399" // System user range on macOS; a preferred default only — createSandboxUser falls back if taken
	SandboxGroupName = "_greywall" // Group used for pf routing (same name as user)
	// SudoersFilePath is the drop-in sudoers rule for group-based sandbox exec.
	SudoersFilePath = "/etc/sudoers.d/greywall"
	// DefaultSocketPath is the Unix socket the daemon listens on.
	DefaultSocketPath = "/var/run/greywall.sock"
)
// Install performs the full LaunchDaemon installation flow:
//  1. Verify running as root
//  2. Create system user _greywall (plus sudoers rule and group membership)
//  3. Create /usr/local/lib/greywall/ directory and copy tun2socks
//  4. Copy the current binary to /usr/local/bin/greywall
//  5. Generate and write the LaunchDaemon plist
//  6. Set proper permissions, load the daemon, and verify it starts
//
// On success the summary is logged and nil is returned; if the daemon loads
// but never starts accepting connections, an error is returned AFTER the
// summary so the user still sees the installed paths.
func Install(currentBinaryPath, tun2socksPath string, debug bool) error {
	if os.Getuid() != 0 {
		return fmt.Errorf("daemon install must be run as root (use sudo)")
	}
	// Step 1: Create system user and group.
	if err := createSandboxUser(debug); err != nil {
		return fmt.Errorf("failed to create sandbox user: %w", err)
	}
	// Step 1b: Install sudoers rule for group-based sandbox-exec.
	if err := installSudoersRule(debug); err != nil {
		return fmt.Errorf("failed to install sudoers rule: %w", err)
	}
	// Step 1c: Add invoking user to _greywall group. Best-effort: the return
	// value (if any) is not checked here.
	addInvokingUserToGroup(debug)
	// Step 2: Create lib directory and copy tun2socks.
	logDebug(debug, "Creating directory %s", InstallLibDir)
	if err := os.MkdirAll(InstallLibDir, 0o755); err != nil { //nolint:gosec // system lib directory needs 0755 for daemon access
		return fmt.Errorf("failed to create %s: %w", InstallLibDir, err)
	}
	tun2socksDst := filepath.Join(InstallLibDir, "tun2socks-darwin-"+runtime.GOARCH)
	logDebug(debug, "Copying tun2socks to %s", tun2socksDst)
	if err := copyFile(tun2socksPath, tun2socksDst, 0o755); err != nil {
		return fmt.Errorf("failed to install tun2socks: %w", err)
	}
	// Step 3: Copy binary to install path.
	if err := os.MkdirAll(filepath.Dir(InstallBinaryPath), 0o755); err != nil { //nolint:gosec // /usr/local/bin needs 0755
		return fmt.Errorf("failed to create %s: %w", filepath.Dir(InstallBinaryPath), err)
	}
	logDebug(debug, "Copying binary from %s to %s", currentBinaryPath, InstallBinaryPath)
	if err := copyFile(currentBinaryPath, InstallBinaryPath, 0o755); err != nil {
		return fmt.Errorf("failed to install binary: %w", err)
	}
	// Step 4: Generate and write plist.
	plist := generatePlist()
	logDebug(debug, "Writing plist to %s", LaunchDaemonPlistPath)
	if err := os.WriteFile(LaunchDaemonPlistPath, []byte(plist), 0o644); err != nil { //nolint:gosec // LaunchDaemon plist requires 0644 per macOS convention
		return fmt.Errorf("failed to write plist: %w", err)
	}
	// Step 5: Set ownership to root:wheel (launchd refuses other owners).
	logDebug(debug, "Setting ownership on %s to root:wheel", LaunchDaemonPlistPath)
	if err := runCmd(debug, "chown", "root:wheel", LaunchDaemonPlistPath); err != nil {
		return fmt.Errorf("failed to set plist ownership: %w", err)
	}
	// Step 6: Load the daemon.
	logDebug(debug, "Loading LaunchDaemon")
	if err := runCmd(debug, "launchctl", "load", LaunchDaemonPlistPath); err != nil {
		return fmt.Errorf("failed to load daemon: %w", err)
	}
	// Step 7: Verify the daemon actually started — poll up to ~5s
	// (10 attempts x 500ms).
	running := false
	for range 10 {
		time.Sleep(500 * time.Millisecond)
		if IsRunning() {
			running = true
			break
		}
	}
	Logf("Daemon installed successfully.")
	Logf(" Plist: %s", LaunchDaemonPlistPath)
	Logf(" Binary: %s", InstallBinaryPath)
	Logf(" Tun2socks: %s", tun2socksDst)
	// Re-read the actual IDs from Directory Services: the preferred UID/GID
	// may not have been available (see createSandboxUser).
	actualUID := readDsclAttr(SandboxUserName, "UniqueID", true)
	actualGID := readDsclAttr(SandboxGroupName, "PrimaryGroupID", false)
	Logf(" User: %s (UID %s)", SandboxUserName, actualUID)
	Logf(" Group: %s (GID %s, pf routing)", SandboxGroupName, actualGID)
	Logf(" Sudoers: %s", SudoersFilePath)
	Logf(" Log: /var/log/greywall.log")
	if !running {
		Logf(" Status: NOT RUNNING (check /var/log/greywall.log)")
		return fmt.Errorf("daemon was loaded but failed to start; check /var/log/greywall.log")
	}
	Logf(" Status: running")
	return nil
}
// Uninstall performs the full LaunchDaemon uninstallation flow. It attempts
// every cleanup step even if individual steps fail, collecting errors along
// the way. Partial failures are logged as warnings and nil is returned; only
// the root check produces a hard error.
func Uninstall(debug bool) error {
	if os.Getuid() != 0 {
		return fmt.Errorf("daemon uninstall must be run as root (use sudo)")
	}
	var errs []string
	// Step 1: Unload daemon (best effort).
	logDebug(debug, "Unloading LaunchDaemon")
	if err := runCmd(debug, "launchctl", "unload", LaunchDaemonPlistPath); err != nil {
		errs = append(errs, fmt.Sprintf("unload daemon: %v", err))
	}
	// Step 2: Remove plist file.
	logDebug(debug, "Removing plist %s", LaunchDaemonPlistPath)
	if err := os.Remove(LaunchDaemonPlistPath); err != nil && !os.IsNotExist(err) {
		errs = append(errs, fmt.Sprintf("remove plist: %v", err))
	}
	// Step 3: Remove lib directory.
	logDebug(debug, "Removing directory %s", InstallLibDir)
	if err := os.RemoveAll(InstallLibDir); err != nil {
		errs = append(errs, fmt.Sprintf("remove lib dir: %v", err))
	}
	// Step 4: Remove installed binary, but only if it differs from the
	// currently running executable — deleting the binary we are executing
	// from would be surprising mid-run.
	currentExe, exeErr := os.Executable()
	if exeErr != nil {
		currentExe = ""
	}
	// EvalSymlinks errors are ignored: an unresolvable path yields "" and
	// simply makes the comparison below conservative.
	resolvedCurrent, _ := filepath.EvalSymlinks(currentExe)
	resolvedInstall, _ := filepath.EvalSymlinks(InstallBinaryPath)
	if resolvedCurrent != resolvedInstall {
		logDebug(debug, "Removing binary %s", InstallBinaryPath)
		if err := os.Remove(InstallBinaryPath); err != nil && !os.IsNotExist(err) {
			errs = append(errs, fmt.Sprintf("remove binary: %v", err))
		}
	} else {
		logDebug(debug, "Skipping binary removal (currently running from %s)", InstallBinaryPath)
	}
	// Step 5: Remove system user and group.
	if err := removeSandboxUser(debug); err != nil {
		errs = append(errs, fmt.Sprintf("remove sandbox user: %v", err))
	}
	// Step 6: Remove socket file if it exists.
	logDebug(debug, "Removing socket %s", DefaultSocketPath)
	if err := os.Remove(DefaultSocketPath); err != nil && !os.IsNotExist(err) {
		errs = append(errs, fmt.Sprintf("remove socket: %v", err))
	}
	// Step 6b: Remove sudoers file.
	logDebug(debug, "Removing sudoers file %s", SudoersFilePath)
	if err := os.Remove(SudoersFilePath); err != nil && !os.IsNotExist(err) {
		errs = append(errs, fmt.Sprintf("remove sudoers file: %v", err))
	}
	// Step 7: Remove pf anchor lines from /etc/pf.conf.
	if err := removeAnchorFromPFConf(debug); err != nil {
		errs = append(errs, fmt.Sprintf("remove pf anchor: %v", err))
	}
	if len(errs) > 0 {
		Logf("Uninstall completed with warnings:")
		for _, e := range errs {
			Logf(" - %s", e)
		}
		return nil // partial cleanup is not a fatal error
	}
	Logf("Daemon uninstalled successfully.")
	return nil
}
// generatePlist returns the LaunchDaemon plist XML content.
//
// RunAtLoad starts the daemon immediately on load; KeepAlive makes launchd
// restart it if it exits. Both stdout and stderr are routed to the same
// log file, /var/log/greywall.log.
func generatePlist() string {
	return `<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
"http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>` + LaunchDaemonLabel + `</string>
<key>ProgramArguments</key>
<array>
<string>` + InstallBinaryPath + `</string>
<string>daemon</string>
<string>run</string>
</array>
<key>RunAtLoad</key><true/>
<key>KeepAlive</key><true/>
<key>StandardOutPath</key>
<string>/var/log/greywall.log</string>
<key>StandardErrorPath</key>
<string>/var/log/greywall.log</string>
</dict>
</plist>
`
}
// IsInstalled returns true if the LaunchDaemon plist file exists.
// Any stat error (including permission problems) is treated as "not installed".
func IsInstalled() bool {
	if _, err := os.Stat(LaunchDaemonPlistPath); err != nil {
		return false
	}
	return true
}
// IsRunning reports whether the daemon is currently up. Dialing the Unix
// socket is the authoritative check (it proves the daemon is accepting
// connections); `launchctl print` is a root-free fallback that can inspect
// the system launchd domain.
func IsRunning() bool {
	// A successful socket dial proves the daemon is alive and serving.
	if conn, err := net.DialTimeout("unix", DefaultSocketPath, 2*time.Second); err == nil {
		_ = conn.Close()
		return true
	}
	// Fallback: `launchctl print system/<label>` works without root on modern
	// macOS (unlike `launchctl list`, which only shows the caller's domain).
	//nolint:gosec // LaunchDaemonLabel is a constant
	out, err := exec.Command("launchctl", "print", "system/"+LaunchDaemonLabel).CombinedOutput()
	return err == nil && strings.Contains(string(out), "state = running")
}
// createSandboxUser creates the _greywall system user and group on macOS
// using dscl (Directory Service command line utility).
//
// If the user/group already exist with valid IDs, they are reused. Otherwise
// a free UID/GID is found dynamically (the hardcoded SandboxUserUID is only
// a preferred default — macOS system groups like com.apple.access_ssh may
// already claim it).
func createSandboxUser(debug bool) error {
	userPath := "/Users/" + SandboxUserName
	groupPath := "/Groups/" + SandboxUserName
	// Check if user and group already exist with valid IDs.
	existingUID := readDsclAttr(SandboxUserName, "UniqueID", true)
	existingGID := readDsclAttr(SandboxGroupName, "PrimaryGroupID", false)
	if existingUID != "" && existingGID != "" {
		logDebug(debug, "System user %s (UID %s) and group (GID %s) already exist",
			SandboxUserName, existingUID, existingGID)
		return nil
	}
	// Pick the numeric ID shared by the user's UniqueID/PrimaryGroupID and
	// the group's PrimaryGroupID.
	//
	// If the group already exists we MUST reuse its GID: pf matches traffic
	// on the process EGID, so creating the user with a freshly chosen ID
	// that differs from the real group GID would silently break routing.
	id := SandboxUserUID
	if existingGID != "" {
		id = existingGID
	} else if !isIDFree(id, debug) {
		var err error
		id, err = findFreeSystemID(debug)
		if err != nil {
			return fmt.Errorf("failed to find free UID/GID: %w", err)
		}
		logDebug(debug, "Preferred ID %s is taken, using %s instead", SandboxUserUID, id)
	}
	// Create the group record FIRST (so the GID exists before the user references it).
	logDebug(debug, "Ensuring system group %s (GID %s)", SandboxGroupName, id)
	if existingGID == "" {
		groupCmds := [][]string{
			{"dscl", ".", "-create", groupPath},
			{"dscl", ".", "-create", groupPath, "PrimaryGroupID", id},
			{"dscl", ".", "-create", groupPath, "RealName", "Greywall Sandbox"},
		}
		for _, args := range groupCmds {
			if err := runDsclCreate(debug, args); err != nil {
				return err
			}
		}
		// Verify the GID was actually set (runDsclCreate may have skipped it).
		actualGID := readDsclAttr(SandboxGroupName, "PrimaryGroupID", false)
		if actualGID == "" {
			return fmt.Errorf("failed to set PrimaryGroupID on group %s (GID %s may be taken)", SandboxGroupName, id)
		}
	}
	// Create the user record with the same numeric ID for UID and primary GID.
	logDebug(debug, "Ensuring system user %s (UID %s)", SandboxUserName, id)
	if existingUID == "" {
		userCmds := [][]string{
			{"dscl", ".", "-create", userPath},
			{"dscl", ".", "-create", userPath, "UniqueID", id},
			{"dscl", ".", "-create", userPath, "PrimaryGroupID", id},
			{"dscl", ".", "-create", userPath, "UserShell", "/usr/bin/false"},
			{"dscl", ".", "-create", userPath, "RealName", "Greywall Sandbox"},
			{"dscl", ".", "-create", userPath, "NFSHomeDirectory", "/var/empty"},
		}
		for _, args := range userCmds {
			if err := runDsclCreate(debug, args); err != nil {
				return err
			}
		}
	}
	logDebug(debug, "System user and group %s ready (ID %s)", SandboxUserName, id)
	return nil
}
// readDsclAttr reads a single attribute from a user or group record via
// `dscl . -read`. It returns "" when the record or the attribute is missing.
func readDsclAttr(name, attr string, isUser bool) string {
	prefix := "/Groups/"
	if isUser {
		prefix = "/Users/"
	}
	//nolint:gosec // name and attr are controlled constants
	out, err := exec.Command("dscl", ".", "-read", prefix+name, attr).Output()
	if err != nil {
		return ""
	}
	// dscl prints "AttrName: value"; keep only the value part.
	_, value, found := strings.Cut(strings.TrimSpace(string(out)), ": ")
	if !found {
		return ""
	}
	return strings.TrimSpace(value)
}
// isIDFree reports whether the given numeric ID is unused both as a user
// UniqueID and as a group PrimaryGroupID.
func isIDFree(id string, debug bool) bool {
	checks := []struct {
		path, attr, kind string
	}{
		{"/Users", "UniqueID", "user"},
		{"/Groups", "PrimaryGroupID", "group"},
	}
	for _, c := range checks {
		//nolint:gosec // id is a controlled numeric string
		out, err := exec.Command("dscl", ".", "-search", c.path, c.attr, id).Output()
		if err == nil && strings.TrimSpace(string(out)) != "" {
			logDebug(debug, "ID %s is taken by a %s", id, c.kind)
			return false
		}
	}
	return true
}
// findFreeSystemID scans the macOS system ID range (350-499) for a numeric ID
// that is usable as both a UID and a GID.
func findFreeSystemID(debug bool) (string, error) {
	for candidate := 350; candidate < 500; candidate++ {
		if id := strconv.Itoa(candidate); isIDFree(id, debug) {
			return id, nil
		}
	}
	return "", fmt.Errorf("no free system UID/GID found in range 350-499")
}
// runDsclCreate runs a `dscl -create` command. eDSRecordAlreadyExists
// failures are swallowed so repeated installs remain idempotent; any other
// failure is wrapped with the full command line for context.
func runDsclCreate(debug bool, args []string) error {
	err := runCmd(debug, args[0], args[1:]...)
	if err == nil {
		return nil
	}
	if strings.Contains(err.Error(), "eDSRecordAlreadyExists") {
		logDebug(debug, "Already exists, skipping: %s", strings.Join(args, " "))
		return nil
	}
	return fmt.Errorf("dscl command failed (%s): %w", strings.Join(args, " "), err)
}
// removeSandboxUser removes the _greywall system user and group created by
// createSandboxUser. Failures are collected and reported together; a missing
// group is treated as a no-op so repeated uninstalls stay idempotent.
func removeSandboxUser(debug bool) error {
	var errs []string
	userPath := "/Users/" + SandboxUserName
	groupPath := "/Groups/" + SandboxGroupName
	if userExists(SandboxUserName) {
		logDebug(debug, "Removing system user %s", SandboxUserName)
		if err := runCmd(debug, "dscl", ".", "-delete", userPath); err != nil {
			errs = append(errs, fmt.Sprintf("delete user: %v", err))
		}
	}
	// dscl -delete fails when the group record is absent; only record the
	// error if it is not a "record does not exist" style failure.
	logDebug(debug, "Removing system group %s", SandboxGroupName)
	if err := runCmd(debug, "dscl", ".", "-delete", groupPath); err != nil {
		errStr := err.Error()
		if !strings.Contains(errStr, "not found") && !strings.Contains(errStr, "does not exist") {
			errs = append(errs, fmt.Sprintf("delete group: %v", err))
		}
	}
	if len(errs) > 0 {
		return fmt.Errorf("sandbox user removal issues: %s", strings.Join(errs, "; "))
	}
	return nil
}
// userExists reports whether the macOS directory service has a record for
// the given username.
func userExists(username string) bool {
	//nolint:gosec // username is a controlled constant
	return exec.Command("dscl", ".", "-read", "/Users/"+username).Run() == nil
}
// installSudoersRule writes a sudoers rule that allows members of the
// _greywall group to run sandbox-exec as any user with group _greywall,
// without a password. The rule is validated with visudo -cf before install.
//
// Order matters here: write to a temp file, validate, rename into place,
// then fix ownership/permissions — a malformed file in /etc/sudoers.d can
// lock out sudo entirely.
func installSudoersRule(debug bool) error {
	// Expands to: %_greywall ALL = (ALL:_greywall) NOPASSWD: /usr/bin/sandbox-exec *
	rule := fmt.Sprintf("%%%s ALL = (ALL:%s) NOPASSWD: /usr/bin/sandbox-exec *\n",
		SandboxGroupName, SandboxGroupName)
	logDebug(debug, "Writing sudoers rule to %s", SudoersFilePath)
	// Ensure /etc/sudoers.d exists.
	if err := os.MkdirAll(filepath.Dir(SudoersFilePath), 0o755); err != nil { //nolint:gosec // /etc/sudoers.d must be 0755
		return fmt.Errorf("failed to create sudoers directory: %w", err)
	}
	// Write to a temp file first, then validate with visudo. sudo skips
	// files whose names contain a '.', so the .tmp file is never active.
	tmpFile := SudoersFilePath + ".tmp"
	if err := os.WriteFile(tmpFile, []byte(rule), 0o440); err != nil {
		return fmt.Errorf("failed to write sudoers temp file: %w", err)
	}
	// Validate syntax before installing.
	//nolint:gosec // tmpFile is a controlled path
	if err := runCmd(debug, "visudo", "-cf", tmpFile); err != nil {
		_ = os.Remove(tmpFile)
		return fmt.Errorf("sudoers validation failed: %w", err)
	}
	// Move validated file into place (atomic within the same filesystem).
	if err := os.Rename(tmpFile, SudoersFilePath); err != nil {
		_ = os.Remove(tmpFile)
		return fmt.Errorf("failed to install sudoers file: %w", err)
	}
	// Ensure correct ownership (root:wheel) and permissions (0440) — sudo
	// refuses sudoers.d files with wrong ownership or write bits set.
	if err := runCmd(debug, "chown", "root:wheel", SudoersFilePath); err != nil {
		return fmt.Errorf("failed to set sudoers ownership: %w", err)
	}
	if err := os.Chmod(SudoersFilePath, 0o440); err != nil {
		return fmt.Errorf("failed to set sudoers permissions: %w", err)
	}
	logDebug(debug, "Sudoers rule installed: %s", SudoersFilePath)
	return nil
}
// addInvokingUserToGroup adds the real invoking user (detected via SUDO_USER)
// to the _greywall group so they can use sudo -g _greywall. This is non-fatal;
// if it fails, a manual instruction is printed.
//
// We use dscl -append (not dseditgroup) because dseditgroup can reset group
// attributes like PrimaryGroupID on freshly created groups. Since dscl
// -append blindly appends duplicate GroupMembership entries, existing
// membership is checked first so repeated installs stay idempotent.
func addInvokingUserToGroup(debug bool) {
	realUser := os.Getenv("SUDO_USER")
	if realUser == "" || realUser == "root" {
		Logf("Note: Could not detect invoking user (SUDO_USER not set).")
		Logf("  You may need to manually add your user to the %s group:", SandboxGroupName)
		Logf("  sudo dscl . -append /Groups/%s GroupMembership YOUR_USERNAME", SandboxGroupName)
		return
	}
	// Skip the append if the user is already a member; dscl -append would
	// otherwise add a duplicate entry on every install.
	members := readDsclAttr(SandboxGroupName, "GroupMembership", false)
	for _, member := range strings.Fields(members) {
		if member == realUser {
			logDebug(debug, "User %s is already a member of group %s", realUser, SandboxGroupName)
			Logf("  User %s already in group %s", realUser, SandboxGroupName)
			return
		}
	}
	groupPath := "/Groups/" + SandboxGroupName
	logDebug(debug, "Adding user %s to group %s", realUser, SandboxGroupName)
	//nolint:gosec // realUser comes from SUDO_USER env var set by sudo
	err := runCmd(debug, "dscl", ".", "-append", groupPath, "GroupMembership", realUser)
	if err != nil {
		Logf("Warning: failed to add %s to group %s: %v", realUser, SandboxGroupName, err)
		Logf("  You may need to run manually:")
		Logf("  sudo dscl . -append %s GroupMembership %s", groupPath, realUser)
	} else {
		Logf("  User %s added to group %s", realUser, SandboxGroupName)
	}
}
// copyFile copies a file from src to dst with the given permissions.
func copyFile(src, dst string, perm os.FileMode) error {
srcFile, err := os.Open(src) //nolint:gosec // src is from os.Executable or user flag
if err != nil {
return fmt.Errorf("open source %s: %w", src, err)
}
defer srcFile.Close() //nolint:errcheck // read-only file; close error is not actionable
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) //nolint:gosec // dst is a controlled install path constant
if err != nil {
return fmt.Errorf("create destination %s: %w", dst, err)
}
defer dstFile.Close() //nolint:errcheck // best-effort close; errors from Chmod/Copy are checked
if _, err := io.Copy(dstFile, srcFile); err != nil {
return fmt.Errorf("copy data: %w", err)
}
if err := dstFile.Chmod(perm); err != nil {
return fmt.Errorf("set permissions on %s: %w", dst, err)
}
return nil
}
// runCmd executes a command and wraps any failure with the command's combined
// output for context. The command line is logged first when debug is set.
func runCmd(debug bool, name string, args ...string) error {
	logDebug(debug, "exec: %s %s", name, strings.Join(args, " "))
	//nolint:gosec // arguments are constructed from internal constants
	output, err := exec.Command(name, args...).CombinedOutput()
	if err != nil {
		return fmt.Errorf("%s failed: %w (output: %s)", name, err, strings.TrimSpace(string(output)))
	}
	return nil
}
// logDebug writes a timestamped message via Logf, but only when debug mode
// is enabled.
func logDebug(debug bool, format string, args ...interface{}) {
	if !debug {
		return
	}
	Logf(format, args...)
}
// Logf writes a timestamped message to stderr with the [greywall:daemon] prefix.
func Logf(format string, args ...interface{}) {
	msg := fmt.Sprintf(format, args...)
	stamp := time.Now().Format("2006-01-02 15:04:05")
	fmt.Fprintf(os.Stderr, "%s [greywall:daemon] %s\n", stamp, msg)
}

View File

@@ -0,0 +1,37 @@
//go:build !darwin
package daemon
import "fmt"
const (
	// LaunchDaemonLabel is the launchd service identifier (meaningful on
	// macOS only; declared here so non-darwin builds compile).
	LaunchDaemonLabel = "co.greyhaven.greywall"
	// LaunchDaemonPlistPath is where the LaunchDaemon plist is installed.
	LaunchDaemonPlistPath = "/Library/LaunchDaemons/co.greyhaven.greywall.plist"
	// InstallBinaryPath is the install location of the greywall binary.
	InstallBinaryPath = "/usr/local/bin/greywall"
	// InstallLibDir holds bundled support binaries (e.g. tun2socks).
	InstallLibDir = "/usr/local/lib/greywall"
	// SandboxUserName is the dedicated system user for sandboxed traffic.
	SandboxUserName = "_greywall"
	// SandboxUserUID is the preferred (not guaranteed) numeric UID/GID.
	SandboxUserUID = "399"
	// SandboxGroupName is the group whose EGID pf rules match on.
	SandboxGroupName = "_greywall"
	// SudoersFilePath is the drop-in sudoers rule location.
	SudoersFilePath = "/etc/sudoers.d/greywall"
	// DefaultSocketPath is the daemon's Unix control socket.
	DefaultSocketPath = "/var/run/greywall.sock"
)
// Install is only supported on macOS; on every other platform it returns an
// explanatory error.
func Install(currentBinaryPath, tun2socksPath string, debug bool) error {
	return fmt.Errorf("daemon install is only supported on macOS")
}
// Uninstall is only supported on macOS; on every other platform it returns an
// explanatory error.
func Uninstall(debug bool) error {
	return fmt.Errorf("daemon uninstall is only supported on macOS")
}
// IsInstalled always reports false on non-macOS platforms, where no
// LaunchDaemon can exist.
func IsInstalled() bool {
	return false
}
// IsRunning always reports false on non-macOS platforms, where the daemon
// cannot run.
func IsRunning() bool {
	return false
}

View File

@@ -0,0 +1,129 @@
//go:build darwin
package daemon
import (
"strings"
"testing"
)
// TestGeneratePlist checks the generated LaunchDaemon plist for the XML
// preamble, the label and binary path constants, the subcommand arguments,
// the lifecycle flags, and the log paths.
func TestGeneratePlist(t *testing.T) {
	plist := generatePlist()
	if !strings.HasPrefix(plist, `<?xml version="1.0" encoding="UTF-8"?>`) {
		t.Error("plist should start with XML declaration")
	}
	// Structural markers that must appear in the document.
	header := []struct{ needle, msg string }{
		{`<!DOCTYPE plist PUBLIC`, "plist should contain DOCTYPE declaration"},
		{`<plist version="1.0">`, "plist should contain plist version tag"},
	}
	for _, c := range header {
		if !strings.Contains(plist, c.needle) {
			t.Error(c.msg)
		}
	}
	// The label and program path must match the package constants.
	if !strings.Contains(plist, "<string>"+LaunchDaemonLabel+"</string>") {
		t.Errorf("plist should contain label %q", LaunchDaemonLabel)
	}
	if !strings.Contains(plist, "<string>"+InstallBinaryPath+"</string>") {
		t.Errorf("plist should reference binary path %q", InstallBinaryPath)
	}
	// Subcommand arguments, lifecycle flags, and log destinations.
	body := []struct{ needle, msg string }{
		{"<string>daemon</string>", "plist should contain 'daemon' argument"},
		{"<string>run</string>", "plist should contain 'run' argument"},
		{"<key>RunAtLoad</key><true/>", "plist should have RunAtLoad set to true"},
		{"<key>KeepAlive</key><true/>", "plist should have KeepAlive set to true"},
		{"/var/log/greywall.log", "plist should reference /var/log/greywall.log for stdout/stderr"},
	}
	for _, c := range body {
		if !strings.Contains(plist, c.needle) {
			t.Error(c.msg)
		}
	}
}
// TestGeneratePlistProgramArguments verifies that the ProgramArguments array
// in the generated plist lists the installed binary followed by the
// "daemon run" subcommand.
func TestGeneratePlistProgramArguments(t *testing.T) {
	plist := generatePlist()
	// Verify the ProgramArguments array contains exactly 3 entries in order.
	// The array should be: /usr/local/bin/greywall, daemon, run
	argStart := strings.Index(plist, "<key>ProgramArguments</key>")
	if argStart == -1 {
		t.Fatal("plist should contain ProgramArguments key")
	}
	// Extract the array section. Both offsets below are relative to
	// argStart, so argStart is re-added when slicing the full document.
	arrayStart := strings.Index(plist[argStart:], "<array>")
	arrayEnd := strings.Index(plist[argStart:], "</array>")
	if arrayStart == -1 || arrayEnd == -1 {
		t.Fatal("ProgramArguments should contain an array")
	}
	arrayContent := plist[argStart+arrayStart : argStart+arrayEnd]
	expectedArgs := []string{InstallBinaryPath, "daemon", "run"}
	for _, arg := range expectedArgs {
		if !strings.Contains(arrayContent, "<string>"+arg+"</string>") {
			t.Errorf("ProgramArguments array should contain %q", arg)
		}
	}
}
// TestIsInstalledReturnsFalse exercises IsInstalled without asserting on the
// result: the daemon could legitimately be installed on the machine running
// the tests, so the test only verifies the call completes without panicking.
func TestIsInstalledReturnsFalse(t *testing.T) {
	_ = IsInstalled()
}
// TestIsRunningReturnsFalse exercises IsRunning without asserting on the
// result, since the daemon may actually be running on a dev machine. It only
// verifies the call completes cleanly.
func TestIsRunningReturnsFalse(t *testing.T) {
	_ = IsRunning()
}
// TestConstants pins the package's install-path and identity constants so an
// accidental edit (e.g. a renamed socket path) is caught by CI.
func TestConstants(t *testing.T) {
	// Verify constants have expected values.
	tests := []struct {
		name     string
		got      string
		expected string
	}{
		{"LaunchDaemonLabel", LaunchDaemonLabel, "co.greyhaven.greywall"},
		{"LaunchDaemonPlistPath", LaunchDaemonPlistPath, "/Library/LaunchDaemons/co.greyhaven.greywall.plist"},
		{"InstallBinaryPath", InstallBinaryPath, "/usr/local/bin/greywall"},
		{"InstallLibDir", InstallLibDir, "/usr/local/lib/greywall"},
		{"SandboxUserName", SandboxUserName, "_greywall"},
		{"SandboxUserUID", SandboxUserUID, "399"},
		{"SandboxGroupName", SandboxGroupName, "_greywall"},
		{"SudoersFilePath", SudoersFilePath, "/etc/sudoers.d/greywall"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.got != tt.expected {
				t.Errorf("%s = %q, want %q", tt.name, tt.got, tt.expected)
			}
		})
	}
}

246
internal/daemon/relay.go Normal file
View File

@@ -0,0 +1,246 @@
//go:build darwin || linux
package daemon
import (
"fmt"
"io"
"net"
"net/url"
"os"
"sync"
"sync/atomic"
"time"
)
const (
	// defaultMaxConns caps the number of concurrently handled relay
	// connections (see Relay.maxConns).
	defaultMaxConns = 256
	// connIdleTimeout is the per-read/per-write deadline applied while
	// forwarding; an idle direction is torn down after this long.
	connIdleTimeout = 5 * time.Minute
	// upstreamDialTimout bounds the dial to the upstream SOCKS5 proxy.
	// NOTE(review): name is missing the "e" in "Timeout"; kept as-is here
	// because it is referenced elsewhere in this file.
	upstreamDialTimout = 10 * time.Second
)
// Relay is a pure Go TCP relay that forwards connections from local listeners
// to an upstream SOCKS5 proxy address. It does NOT implement the SOCKS5 protocol;
// it blindly forwards bytes between the local connection and the upstream proxy.
type Relay struct {
	listeners   []net.Listener // both IPv4 and IPv6 listeners (IPv6 may be absent)
	targetAddr  string         // external SOCKS5 proxy host:port
	port        int            // assigned port, shared by all listeners
	wg          sync.WaitGroup // tracks accept loops and per-connection handlers
	done        chan struct{}  // closed by Stop to end the accept loops
	debug       bool           // verbose logging to stderr
	maxConns    int            // max concurrent connections (default 256)
	activeConns atomic.Int32   // current active connections
}
// NewRelay parses proxyURL to extract the upstream host:port and binds local
// listeners on 127.0.0.1 and [::1] sharing one dynamically assigned port.
// The port comes from the IPv4 bind; if the IPv6 bind fails the relay simply
// runs IPv4-only. Claiming both loopback addresses prevents another process
// from squatting the same port on IPv6.
func NewRelay(proxyURL string, debug bool) (*Relay, error) {
	parsed, err := url.Parse(proxyURL)
	if err != nil {
		return nil, fmt.Errorf("invalid proxy URL %q: %w", proxyURL, err)
	}
	if parsed.Hostname() == "" || parsed.Port() == "" {
		return nil, fmt.Errorf("proxy URL must include host and port: %q", proxyURL)
	}
	targetAddr := net.JoinHostPort(parsed.Hostname(), parsed.Port())
	// The IPv4 bind picks the port; the IPv6 bind must then match it.
	ipv4, err := net.Listen("tcp4", "127.0.0.1:0")
	if err != nil {
		return nil, fmt.Errorf("failed to bind IPv4 listener: %w", err)
	}
	assignedPort := ipv4.Addr().(*net.TCPAddr).Port
	listeners := []net.Listener{ipv4}
	ipv6Addr := fmt.Sprintf("[::1]:%d", assignedPort)
	if ipv6, err := net.Listen("tcp6", ipv6Addr); err != nil {
		if debug {
			fmt.Fprintf(os.Stderr, "[greywall:relay] IPv6 bind on %s failed, continuing IPv4 only: %v\n", ipv6Addr, err)
		}
	} else {
		listeners = append(listeners, ipv6)
	}
	if debug {
		fmt.Fprintf(os.Stderr, "[greywall:relay] Bound %d listener(s) on port %d -> %s\n", len(listeners), assignedPort, targetAddr)
	}
	return &Relay{
		listeners:  listeners,
		targetAddr: targetAddr,
		port:       assignedPort,
		done:       make(chan struct{}),
		debug:      debug,
		maxConns:   defaultMaxConns,
	}, nil
}
// Port returns the local loopback port the relay is listening on. It is
// assigned by the kernel during NewRelay and shared by the IPv4 and (if
// bound) IPv6 listeners.
func (r *Relay) Port() int {
	return r.port
}
// Start launches one accept loop per bound listener and returns immediately.
// Accepted connections are forwarded bidirectionally to the upstream proxy
// in their own goroutines; call Stop to shut everything down.
func (r *Relay) Start() error {
	for i := range r.listeners {
		r.wg.Add(1)
		go r.acceptLoop(r.listeners[i])
	}
	return nil
}
// Stop gracefully shuts the relay down: it signals shutdown via the done
// channel, closes every listener to unblock the accept loops, and then waits
// for all in-flight connection handlers to finish.
func (r *Relay) Stop() {
	close(r.done)
	for i := range r.listeners {
		_ = r.listeners[i].Close()
	}
	r.wg.Wait()
}
// acceptLoop runs the accept loop for a single listener, handing each
// accepted connection to handleConn in its own goroutine. The loop exits
// when Stop closes the done channel; transient accept errors back off
// briefly instead of spinning the CPU.
func (r *Relay) acceptLoop(ln net.Listener) {
	defer r.wg.Done()
	for {
		conn, err := ln.Accept()
		if err != nil {
			select {
			case <-r.done:
				return
			default:
			}
			// Transient accept error; continue.
			if r.debug {
				fmt.Fprintf(os.Stderr, "[greywall:relay] Accept error: %v\n", err)
			}
			// Back off briefly so a persistent error condition (e.g.
			// EMFILE from FD exhaustion) does not spin the loop hot.
			time.Sleep(50 * time.Millisecond)
			continue
		}
		r.wg.Add(1)
		go r.handleConn(conn)
	}
}
// handleConn forwards a single accepted connection to the upstream proxy. It
// enforces the maxConns cap, dials the upstream, then runs one copy goroutine
// per direction; reading the byte counters after copyWg.Wait() is safe
// because the WaitGroup establishes the necessary happens-before.
func (r *Relay) handleConn(local net.Conn) {
	defer r.wg.Done()
	remoteAddr := local.RemoteAddr().String()
	// Enforce max concurrent connections. The counter is incremented first
	// and rolled back on rejection so concurrent arrivals cannot overshoot.
	current := r.activeConns.Add(1)
	if int(current) > r.maxConns {
		r.activeConns.Add(-1)
		if r.debug {
			fmt.Fprintf(os.Stderr, "[greywall:relay] Connection from %s rejected: max connections (%d) reached\n", remoteAddr, r.maxConns)
		}
		_ = local.Close()
		return
	}
	defer r.activeConns.Add(-1)
	if r.debug {
		fmt.Fprintf(os.Stderr, "[greywall:relay] Connection accepted from %s\n", remoteAddr)
	}
	// Dial the upstream proxy. This failure is logged even without debug,
	// since it usually means the external SOCKS5 proxy is unreachable.
	upstream, err := net.DialTimeout("tcp", r.targetAddr, upstreamDialTimout)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[greywall:relay] WARNING: upstream connect to %s failed: %v\n", r.targetAddr, err)
		_ = local.Close()
		return
	}
	if r.debug {
		fmt.Fprintf(os.Stderr, "[greywall:relay] Upstream connected: %s -> %s\n", remoteAddr, r.targetAddr)
	}
	// Bidirectional copy with proper TCP half-close: each direction runs in
	// its own goroutine and half-closes its destination on source EOF.
	var (
		localToUpstream int64
		upstreamToLocal int64
		copyWg          sync.WaitGroup
	)
	copyWg.Add(2)
	// local -> upstream
	go func() {
		defer copyWg.Done()
		localToUpstream = r.copyWithHalfClose(upstream, local)
	}()
	// upstream -> local
	go func() {
		defer copyWg.Done()
		upstreamToLocal = r.copyWithHalfClose(local, upstream)
	}()
	copyWg.Wait()
	_ = local.Close()
	_ = upstream.Close()
	if r.debug {
		fmt.Fprintf(os.Stderr, "[greywall:relay] Connection closed %s (sent=%d recv=%d)\n", remoteAddr, localToUpstream, upstreamToLocal)
	}
}
// copyWithHalfClose copies data from src to dst, refreshing an idle deadline
// before each read and each write so a stalled peer is detected within
// connIdleTimeout. When src reaches EOF, it signals a TCP half-close on dst
// via CloseWrite (when dst is a *net.TCPConn) rather than a full Close, so
// the opposite direction can keep flowing. Returns bytes written to dst.
func (r *Relay) copyWithHalfClose(dst, src net.Conn) int64 {
	buf := make([]byte, 32*1024)
	var written int64
	for {
		// Reset idle timeout before each read.
		if err := src.SetReadDeadline(time.Now().Add(connIdleTimeout)); err != nil {
			break
		}
		nr, readErr := src.Read(buf)
		if nr > 0 {
			// Reset write deadline for each write.
			if err := dst.SetWriteDeadline(time.Now().Add(connIdleTimeout)); err != nil {
				break
			}
			nw, writeErr := dst.Write(buf[:nr])
			written += int64(nw)
			if writeErr != nil {
				break
			}
			// Defensive: a short write without an error should not occur
			// for a net.Conn, but bail out rather than silently drop bytes.
			if nw != nr {
				break
			}
		}
		if readErr != nil {
			// Source hit EOF or error: signal half-close on destination.
			if tcpDst, ok := dst.(*net.TCPConn); ok {
				_ = tcpDst.CloseWrite()
			}
			if readErr != io.EOF {
				// Unexpected error; connection may have timed out.
				if r.debug {
					fmt.Fprintf(os.Stderr, "[greywall:relay] Copy error: %v\n", readErr)
				}
			}
			break
		}
	}
	return written
}

View File

@@ -0,0 +1,373 @@
//go:build darwin || linux
package daemon
import (
"bytes"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
)
// startEchoServer starts a TCP server that echoes back everything it receives.
// It returns the listener and its address.
func startEchoServer(t *testing.T) net.Listener {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to start echo server: %v", err)
}
go func() {
for {
conn, err := ln.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close() //nolint:errcheck // test cleanup
_, _ = io.Copy(c, c)
}(conn)
}
}()
return ln
}
// startBlackHoleServer accepts connections but never reads/writes, then closes.
func startBlackHoleServer(t *testing.T) net.Listener {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to start black hole server: %v", err)
}
go func() {
for {
conn, err := ln.Accept()
if err != nil {
return
}
_ = conn.Close()
}
}()
return ln
}
// TestRelayBidirectionalForward verifies end-to-end forwarding: bytes written
// into the relay travel to the upstream echo server and come back unchanged.
func TestRelayBidirectionalForward(t *testing.T) {
	// Start a mock upstream (echo server) acting as the "SOCKS5 proxy".
	// The relay is protocol-agnostic, so a plain echo server suffices.
	echo := startEchoServer(t)
	defer echo.Close() //nolint:errcheck // test cleanup
	echoAddr := echo.Addr().String()
	proxyURL := fmt.Sprintf("socks5://%s", echoAddr)
	relay, err := NewRelay(proxyURL, true)
	if err != nil {
		t.Fatalf("NewRelay failed: %v", err)
	}
	if err := relay.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer relay.Stop()
	// Connect through the relay.
	conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", relay.Port()))
	if err != nil {
		t.Fatalf("failed to connect to relay: %v", err)
	}
	defer conn.Close() //nolint:errcheck // test cleanup
	// Send data and verify it echoes back.
	msg := []byte("hello, relay!")
	if _, err := conn.Write(msg); err != nil {
		t.Fatalf("write failed: %v", err)
	}
	buf := make([]byte, len(msg))
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	if _, err := io.ReadFull(conn, buf); err != nil {
		t.Fatalf("read failed: %v", err)
	}
	if !bytes.Equal(buf, msg) {
		t.Fatalf("expected %q, got %q", msg, buf)
	}
}
// TestRelayMultipleMessages sends ten sequential messages over a single
// relayed connection, confirming the relay keeps forwarding after the first
// exchange rather than closing per message.
func TestRelayMultipleMessages(t *testing.T) {
	echo := startEchoServer(t)
	defer echo.Close() //nolint:errcheck // test cleanup
	proxyURL := fmt.Sprintf("socks5://%s", echo.Addr().String())
	relay, err := NewRelay(proxyURL, false)
	if err != nil {
		t.Fatalf("NewRelay failed: %v", err)
	}
	if err := relay.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer relay.Stop()
	conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", relay.Port()))
	if err != nil {
		t.Fatalf("failed to connect to relay: %v", err)
	}
	defer conn.Close() //nolint:errcheck // test cleanup
	// Send multiple messages and verify each echoes back.
	for i := 0; i < 10; i++ {
		msg := []byte(fmt.Sprintf("message-%d", i))
		if _, err := conn.Write(msg); err != nil {
			t.Fatalf("write %d failed: %v", i, err)
		}
		buf := make([]byte, len(msg))
		_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
		if _, err := io.ReadFull(conn, buf); err != nil {
			t.Fatalf("read %d failed: %v", i, err)
		}
		if !bytes.Equal(buf, msg) {
			t.Fatalf("message %d: expected %q, got %q", i, msg, buf)
		}
	}
}
// TestRelayUpstreamConnectionFailure verifies the relay closes the local
// connection when the upstream proxy is unreachable.
func TestRelayUpstreamConnectionFailure(t *testing.T) {
	// Find a port that is not listening. NOTE: there is a small TOCTOU
	// window where another process could grab the port between Close and
	// the relay's dial — acceptable for a test.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	deadPort := ln.Addr().(*net.TCPAddr).Port
	_ = ln.Close() // close immediately so nothing is listening
	proxyURL := fmt.Sprintf("socks5://127.0.0.1:%d", deadPort)
	relay, err := NewRelay(proxyURL, true)
	if err != nil {
		t.Fatalf("NewRelay failed: %v", err)
	}
	if err := relay.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer relay.Stop()
	// Connect to the relay. The relay should accept the connection but then
	// fail to reach the upstream, causing the local side to be closed.
	conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", relay.Port()))
	if err != nil {
		t.Fatalf("failed to connect to relay: %v", err)
	}
	defer conn.Close() //nolint:errcheck // test cleanup
	// The relay should close the connection after failing upstream dial.
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	buf := make([]byte, 1)
	_, readErr := conn.Read(buf)
	if readErr == nil {
		t.Fatal("expected read error (connection should be closed), got nil")
	}
}
// TestRelayConcurrentConnections runs 50 simultaneous echo round-trips
// through the relay, collecting any per-connection failure on a channel so
// every error is reported rather than just the first.
func TestRelayConcurrentConnections(t *testing.T) {
	echo := startEchoServer(t)
	defer echo.Close() //nolint:errcheck // test cleanup
	proxyURL := fmt.Sprintf("socks5://%s", echo.Addr().String())
	relay, err := NewRelay(proxyURL, false)
	if err != nil {
		t.Fatalf("NewRelay failed: %v", err)
	}
	if err := relay.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer relay.Stop()
	const numConns = 50
	var wg sync.WaitGroup
	// Buffered to numConns so no goroutine blocks on send.
	errors := make(chan error, numConns)
	for i := 0; i < numConns; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", relay.Port()))
			if err != nil {
				errors <- fmt.Errorf("conn %d: dial failed: %w", idx, err)
				return
			}
			defer conn.Close() //nolint:errcheck // test cleanup
			msg := []byte(fmt.Sprintf("concurrent-%d", idx))
			if _, err := conn.Write(msg); err != nil {
				errors <- fmt.Errorf("conn %d: write failed: %w", idx, err)
				return
			}
			buf := make([]byte, len(msg))
			_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
			if _, err := io.ReadFull(conn, buf); err != nil {
				errors <- fmt.Errorf("conn %d: read failed: %w", idx, err)
				return
			}
			if !bytes.Equal(buf, msg) {
				errors <- fmt.Errorf("conn %d: expected %q, got %q", idx, msg, buf)
			}
		}(i)
	}
	wg.Wait()
	close(errors)
	for err := range errors {
		t.Error(err)
	}
}
func TestRelayMaxConnsLimit(t *testing.T) {
// Use a black hole server so connections stay open.
bh := startBlackHoleServer(t)
defer bh.Close() //nolint:errcheck // test cleanup
proxyURL := fmt.Sprintf("socks5://%s", bh.Addr().String())
relay, err := NewRelay(proxyURL, true)
if err != nil {
t.Fatalf("NewRelay failed: %v", err)
}
// Set a very low limit for testing.
relay.maxConns = 2
if err := relay.Start(); err != nil {
t.Fatalf("Start failed: %v", err)
}
defer relay.Stop()
// The black hole server closes connections immediately, so the relay's
// handleConn will finish quickly. Instead, use an echo server that holds
// connections open to truly test the limit.
// We just verify the relay starts and stops cleanly with the low limit.
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", relay.Port()))
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
_ = conn.Close()
}
// TestRelayTCPHalfClose verifies the relay propagates TCP half-close: after
// the client does CloseWrite, the server must still be able to send its
// response back through the relay (i.e. the relay must not fully close the
// connection on the first EOF).
func TestRelayTCPHalfClose(t *testing.T) {
	// Start a server that reads everything, then sends a response, then closes.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to listen: %v", err)
	}
	defer ln.Close() //nolint:errcheck // test cleanup
	response := []byte("server-response-after-client-close")
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		defer conn.Close() //nolint:errcheck // test cleanup
		// Read all data from client until EOF (client did CloseWrite).
		data, err := io.ReadAll(conn)
		if err != nil {
			return
		}
		_ = data
		// Now send a response back (the write direction is still open).
		_, _ = conn.Write(response)
		// Signal we're done writing.
		if tc, ok := conn.(*net.TCPConn); ok {
			_ = tc.CloseWrite()
		}
	}()
	proxyURL := fmt.Sprintf("socks5://%s", ln.Addr().String())
	relay, err := NewRelay(proxyURL, true)
	if err != nil {
		t.Fatalf("NewRelay failed: %v", err)
	}
	if err := relay.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer relay.Stop()
	conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", relay.Port()))
	if err != nil {
		t.Fatalf("failed to connect to relay: %v", err)
	}
	defer conn.Close() //nolint:errcheck // test cleanup
	// Send data to the server.
	clientMsg := []byte("client-data")
	if _, err := conn.Write(clientMsg); err != nil {
		t.Fatalf("write failed: %v", err)
	}
	// Half-close our write side; the server should now receive EOF and send its response.
	tcpConn, ok := conn.(*net.TCPConn)
	if !ok {
		t.Fatal("expected *net.TCPConn")
	}
	if err := tcpConn.CloseWrite(); err != nil {
		t.Fatalf("CloseWrite failed: %v", err)
	}
	// Read the server's response through the relay.
	_ = conn.SetReadDeadline(time.Now().Add(3 * time.Second))
	got, err := io.ReadAll(conn)
	if err != nil {
		t.Fatalf("ReadAll failed: %v", err)
	}
	if !bytes.Equal(got, response) {
		t.Fatalf("expected %q, got %q", response, got)
	}
}
// TestRelayPort verifies a freshly constructed relay reports a sane,
// kernel-assigned local port.
func TestRelayPort(t *testing.T) {
	echo := startEchoServer(t)
	defer echo.Close() //nolint:errcheck // test cleanup
	relay, err := NewRelay(fmt.Sprintf("socks5://%s", echo.Addr().String()), false)
	if err != nil {
		t.Fatalf("NewRelay failed: %v", err)
	}
	defer relay.Stop()
	if p := relay.Port(); p <= 0 || p > 65535 {
		t.Fatalf("invalid port: %d", p)
	}
}
// TestNewRelayInvalidURL confirms the constructor rejects proxy URLs that
// lack a host, lack a port, or are empty.
func TestNewRelayInvalidURL(t *testing.T) {
	cases := []struct {
		name     string
		proxyURL string
	}{
		{"missing port", "socks5://127.0.0.1"},
		{"missing host", "socks5://:1080"},
		{"empty", ""},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := NewRelay(tc.proxyURL, false); err == nil {
				t.Fatal("expected error, got nil")
			}
		})
	}
}

430
internal/daemon/server.go Normal file
View File

@@ -0,0 +1,430 @@
package daemon
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"os"
"os/user"
"sync"
"time"
)
// Protocol types for JSON communication over Unix socket (newline-delimited).
// Request is a single command from the CLI to the daemon. One request is
// written per connection; which optional fields are meaningful depends on
// Action.
type Request struct {
	Action    string `json:"action"`               // "create_session", "destroy_session", "status"
	ProxyURL  string `json:"proxy_url,omitempty"`  // create_session: SOCKS5 proxy URL (required by the handler)
	DNSAddr   string `json:"dns_addr,omitempty"`   // create_session: upstream DNS address; empty disables the DNS relay
	SessionID string `json:"session_id,omitempty"` // destroy_session: ID previously returned by create_session
}
// Response is the daemon's reply to a single Request. On failure OK is false
// and Error carries a human-readable reason; the remaining fields are only
// populated by the actions that produce them (omitempty otherwise).
type Response struct {
	OK           bool   `json:"ok"`
	Error        string `json:"error,omitempty"`
	SessionID    string `json:"session_id,omitempty"`    // create_session: new session identifier
	TunDevice    string `json:"tun_device,omitempty"`    // create_session: utun device name (e.g. "utun7")
	SandboxUser  string `json:"sandbox_user,omitempty"`  // create_session: sandbox user account name
	SandboxGroup string `json:"sandbox_group,omitempty"` // create_session: sandbox group name
	// Status response fields.
	Running        bool `json:"running,omitempty"`
	ActiveSessions int  `json:"active_sessions,omitempty"` // number of live sessions
}
// Session tracks an active sandbox session.
type Session struct {
	ID        string    // random hex identifier (see generateSessionID)
	ProxyURL  string    // SOCKS5 proxy the session routes through
	DNSAddr   string    // upstream DNS address; empty if no DNS relay was requested
	CreatedAt time.Time // when the session was established
}
// Server listens on a Unix socket and manages sandbox sessions. It orchestrates
// TunManager (utun + pf) and DNSRelay lifecycle for each session. All mutable
// state is guarded by mu; since Phase 1 allows at most one session, a single
// tunManager/dnsRelay pair is sufficient.
type Server struct {
	socketPath    string              // Unix socket path the daemon listens on
	listener      net.Listener        // active listener; nil until Start succeeds
	tunManager    *TunManager         // tunnel for the current session; nil when idle
	dnsRelay      *DNSRelay           // DNS relay for the current session; may be nil
	sessions      map[string]*Session // live sessions keyed by ID
	mu            sync.Mutex          // guards sessions, tunManager, dnsRelay, sandboxGID
	done          chan struct{}       // closed exactly once to signal shutdown to acceptLoop
	wg            sync.WaitGroup      // tracks acceptLoop plus in-flight connection handlers
	debug         bool                // enables verbose logDebug output
	tun2socksPath string              // path to the tun2socks binary
	sandboxGID    string // cached numeric GID for the sandbox group
}
// NewServer creates a new daemon server that will listen on the given Unix socket path.
func NewServer(socketPath, tun2socksPath string, debug bool) *Server {
	srv := &Server{
		socketPath:    socketPath,
		tun2socksPath: tun2socksPath,
		debug:         debug,
	}
	srv.sessions = make(map[string]*Session)
	srv.done = make(chan struct{})
	return srv
}
// Start begins listening on the Unix socket and accepting connections.
// It pre-resolves the sandbox group GID, removes any stale socket file,
// binds the listener, relaxes the socket permissions, and spawns the accept
// loop. On success the caller must eventually call Stop.
func (s *Server) Start() error {
	// Pre-resolve the sandbox group GID so session creation is fast
	// and doesn't depend on OpenDirectory latency.
	grp, err := user.LookupGroup(SandboxGroupName)
	if err != nil {
		Logf("Warning: could not resolve group %s at startup: %v (will retry per-session)", SandboxGroupName, err)
	} else {
		s.sandboxGID = grp.Gid
		Logf("Resolved group %s → GID %s", SandboxGroupName, s.sandboxGID)
	}
	// Remove a stale socket file left over from a previous run. A single
	// Remove that tolerates "not exist" avoids the Stat-then-Remove race
	// (the file could appear/disappear between the two calls).
	s.logDebug("Removing stale socket file %s (if present)", s.socketPath)
	if err := os.Remove(s.socketPath); err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("failed to remove stale socket %s: %w", s.socketPath, err)
	}
	ln, err := net.Listen("unix", s.socketPath)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", s.socketPath, err)
	}
	s.listener = ln
	// Set socket permissions so any local user can connect to the daemon.
	// The socket is localhost-only (Unix domain socket); access control is
	// handled at the daemon protocol level, not file permissions.
	if err := os.Chmod(s.socketPath, 0o666); err != nil { //nolint:gosec // daemon socket needs 0666 so non-root CLI can connect
		_ = ln.Close()
		_ = os.Remove(s.socketPath)
		return fmt.Errorf("failed to set socket permissions: %w", err)
	}
	s.logDebug("Listening on %s", s.socketPath)
	s.wg.Add(1)
	go s.acceptLoop()
	return nil
}
// Stop gracefully shuts down the server: it stops accepting new connections,
// waits for in-flight handlers, tears down all active session state, and
// removes the socket file. Errors from individual steps are aggregated so
// cleanup is as complete as possible. Safe to call more than once.
func (s *Server) Stop() error {
	// Close done exactly once to signal shutdown.
	select {
	case <-s.done:
	default:
		close(s.done)
	}
	// Closing the listener unblocks acceptLoop; then wait for the accept
	// loop and every in-flight connection handler to drain.
	if s.listener != nil {
		_ = s.listener.Close()
	}
	s.wg.Wait()

	var problems []string
	s.mu.Lock()
	for id := range s.sessions {
		s.logDebug("Stopping session %s during shutdown", id)
	}
	if tm := s.tunManager; tm != nil {
		if err := tm.Stop(); err != nil {
			problems = append(problems, fmt.Sprintf("stop tun manager: %v", err))
		}
		s.tunManager = nil
	}
	if dr := s.dnsRelay; dr != nil {
		dr.Stop()
		s.dnsRelay = nil
	}
	s.sessions = make(map[string]*Session)
	s.mu.Unlock()

	// Remove the socket file (tolerate it already being gone).
	if err := os.Remove(s.socketPath); err != nil && !os.IsNotExist(err) {
		problems = append(problems, fmt.Sprintf("remove socket: %v", err))
	}
	if len(problems) > 0 {
		return fmt.Errorf("stop errors: %s", join(problems, "; "))
	}
	s.logDebug("Server stopped")
	return nil
}
// ActiveSessions reports how many sessions are currently live.
func (s *Server) ActiveSessions() int {
	s.mu.Lock()
	n := len(s.sessions)
	s.mu.Unlock()
	return n
}
// acceptLoop accepts connections until the listener is closed. Each accepted
// connection is served on its own goroutine, tracked by s.wg so Stop can wait
// for in-flight requests.
func (s *Server) acceptLoop() {
	defer s.wg.Done()
	for {
		conn, err := s.listener.Accept()
		if err == nil {
			s.wg.Add(1)
			go s.handleConnection(conn)
			continue
		}
		// Accept failed: exit quietly during shutdown, otherwise log
		// the error and keep accepting.
		select {
		case <-s.done:
			return
		default:
			s.logDebug("Accept error: %v", err)
		}
	}
}
// handleConnection services one request/response exchange: decode a JSON
// request, dispatch on its action, encode the JSON reply, and close the
// connection.
func (s *Server) handleConnection(conn net.Conn) {
	defer s.wg.Done()
	defer conn.Close() //nolint:errcheck // best-effort close after handling request

	// Bound the read so a silent client cannot pin this goroutine forever.
	if err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
		s.logDebug("Failed to set read deadline: %v", err)
		return
	}
	enc := json.NewEncoder(conn)
	var req Request
	if err := json.NewDecoder(conn).Decode(&req); err != nil {
		s.logDebug("Failed to decode request: %v", err)
		_ = enc.Encode(Response{OK: false, Error: fmt.Sprintf("invalid request: %v", err)}) // best-effort error response
		return
	}
	Logf("Received request: action=%s", req.Action)

	var resp Response
	switch req.Action {
	case "create_session":
		resp = s.handleCreateSession(req)
	case "destroy_session":
		resp = s.handleDestroySession(req)
	case "status":
		resp = s.handleStatus()
	default:
		resp = Response{OK: false, Error: fmt.Sprintf("unknown action: %q", req.Action)}
	}
	if err := enc.Encode(resp); err != nil {
		s.logDebug("Failed to encode response: %v", err)
	}
}
// handleCreateSession creates a new sandbox session with a utun tunnel,
// optional DNS relay, and pf rules for the sandbox group.
//
// Setup order matters: tunnel first, then DNS relay, then pf rules, then the
// session record. Each failure path unwinds everything created so far
// (best-effort) before returning, so a failed create leaves no stray
// processes or firewall state. Runs entirely under s.mu.
func (s *Server) handleCreateSession(req Request) Response {
	s.mu.Lock()
	defer s.mu.Unlock()
	if req.ProxyURL == "" {
		return Response{OK: false, Error: "proxy_url is required"}
	}
	// Phase 1: only one session at a time.
	if len(s.sessions) > 0 {
		Logf("Rejecting create_session: %d session(s) already active", len(s.sessions))
		return Response{OK: false, Error: "a session is already active (only one session supported in Phase 1)"}
	}
	Logf("Creating session: proxy=%s dns=%s", req.ProxyURL, req.DNSAddr)
	// Step 1: Create and start TunManager.
	tm := NewTunManager(s.tun2socksPath, req.ProxyURL, s.debug)
	if err := tm.Start(); err != nil {
		return Response{OK: false, Error: fmt.Sprintf("failed to start tunnel: %v", err)}
	}
	// Step 2: Create DNS relay if dns_addr is provided.
	var dr *DNSRelay
	if req.DNSAddr != "" {
		var err error
		dr, err = NewDNSRelay(dnsRelayIP+":"+dnsRelayPort, req.DNSAddr, s.debug)
		if err != nil {
			_ = tm.Stop() // best-effort cleanup
			return Response{OK: false, Error: fmt.Sprintf("failed to create DNS relay: %v", err)}
		}
		if err := dr.Start(); err != nil {
			_ = tm.Stop() // best-effort cleanup
			return Response{OK: false, Error: fmt.Sprintf("failed to start DNS relay: %v", err)}
		}
	}
	// Step 3: Resolve the sandbox group GID. pfctl in the LaunchDaemon
	// context cannot resolve group names via OpenDirectory, so we use the
	// cached GID (resolved at startup) or look it up now.
	sandboxGID := s.sandboxGID
	if sandboxGID == "" {
		grp, err := user.LookupGroup(SandboxGroupName)
		if err != nil {
			_ = tm.Stop()
			if dr != nil {
				dr.Stop()
			}
			return Response{OK: false, Error: fmt.Sprintf("failed to resolve group %s: %v", SandboxGroupName, err)}
		}
		sandboxGID = grp.Gid
		// Cache for subsequent sessions so the lookup is skipped next time.
		s.sandboxGID = sandboxGID
	}
	Logf("Loading pf rules for group %s (GID %s)", SandboxGroupName, sandboxGID)
	if err := tm.LoadPFRules(sandboxGID); err != nil {
		if dr != nil {
			dr.Stop()
		}
		_ = tm.Stop() // best-effort cleanup
		return Response{OK: false, Error: fmt.Sprintf("failed to load pf rules: %v", err)}
	}
	// Step 4: Generate session ID and store.
	sessionID, err := generateSessionID()
	if err != nil {
		if dr != nil {
			dr.Stop()
		}
		_ = tm.UnloadPFRules() // best-effort cleanup
		_ = tm.Stop() // best-effort cleanup
		return Response{OK: false, Error: fmt.Sprintf("failed to generate session ID: %v", err)}
	}
	session := &Session{
		ID: sessionID,
		ProxyURL: req.ProxyURL,
		DNSAddr: req.DNSAddr,
		CreatedAt: time.Now(),
	}
	s.sessions[sessionID] = session
	s.tunManager = tm
	s.dnsRelay = dr
	Logf("Session created: id=%s device=%s group=%s(GID %s)", sessionID, tm.TunDevice(), SandboxGroupName, sandboxGID)
	return Response{
		OK: true,
		SessionID: sessionID,
		TunDevice: tm.TunDevice(),
		SandboxUser: SandboxUserName,
		SandboxGroup: SandboxGroupName,
	}
}
// handleDestroySession tears down an existing session: pf rules come off
// first (they reference the tunnel), then the tunnel stops, then the DNS
// relay, and finally the session record is dropped. Teardown errors are
// collected rather than aborting early so cleanup is as complete as possible.
func (s *Server) handleDestroySession(req Request) Response {
	s.mu.Lock()
	defer s.mu.Unlock()
	if req.SessionID == "" {
		return Response{OK: false, Error: "session_id is required"}
	}
	Logf("Destroying session: id=%s", req.SessionID)
	session, ok := s.sessions[req.SessionID]
	if !ok {
		Logf("Session %q not found (active sessions: %d)", req.SessionID, len(s.sessions))
		return Response{OK: false, Error: fmt.Sprintf("session %q not found", req.SessionID)}
	}

	var problems []string
	if tm := s.tunManager; tm != nil {
		if err := tm.UnloadPFRules(); err != nil {
			problems = append(problems, fmt.Sprintf("unload pf rules: %v", err))
		}
		if err := tm.Stop(); err != nil {
			problems = append(problems, fmt.Sprintf("stop tun manager: %v", err))
		}
		s.tunManager = nil
	}
	if dr := s.dnsRelay; dr != nil {
		dr.Stop()
		s.dnsRelay = nil
	}
	delete(s.sessions, session.ID)

	if len(problems) > 0 {
		Logf("Session %s destroyed with errors: %v", session.ID, problems)
		return Response{OK: false, Error: fmt.Sprintf("session destroyed with errors: %s", join(problems, "; "))}
	}
	Logf("Session destroyed: id=%s (remaining: %d)", session.ID, len(s.sessions))
	return Response{OK: true}
}
// handleStatus reports that the daemon is up and how many sessions exist.
func (s *Server) handleStatus() Response {
	s.mu.Lock()
	n := len(s.sessions)
	s.mu.Unlock()
	return Response{OK: true, Running: true, ActiveSessions: n}
}
// generateSessionID produces a cryptographically random 32-character hex
// session identifier (128 bits of entropy).
func generateSessionID() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("failed to read random bytes: %w", err)
	}
	return hex.EncodeToString(raw[:]), nil
}
// join concatenates parts with sep between elements. Kept local so this
// package does not import strings solely for strings.Join.
func join(parts []string, sep string) string {
	var out string
	for i, p := range parts {
		if i > 0 {
			out += sep
		}
		out += p
	}
	return out
}
// logDebug emits a timestamped message via Logf, but only when debug logging
// is enabled for this server.
func (s *Server) logDebug(format string, args ...interface{}) {
	if !s.debug {
		return
	}
	Logf(format, args...)
}

View File

@@ -0,0 +1,527 @@
package daemon
import (
"encoding/json"
"net"
"os"
"path/filepath"
"testing"
"time"
)
// testSocketPath returns a temporary Unix socket path for testing.
// macOS limits Unix socket paths to 104 bytes, so we use a short temp directory
// under /tmp rather than the longer t.TempDir() paths.
func testSocketPath(t *testing.T) string {
t.Helper()
dir, err := os.MkdirTemp("/tmp", "gw-")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
sockPath := filepath.Join(dir, "d.sock")
t.Cleanup(func() {
_ = os.RemoveAll(dir)
})
return sockPath
}
// TestServerStartStop exercises the full start/stop lifecycle, checking the
// socket file's presence, permissions, and removal.
func TestServerStartStop(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	info, err := os.Stat(sock)
	if err != nil {
		t.Fatalf("Socket file not found: %v", err)
	}
	// The daemon chmods the socket to 0666 so any local user can connect.
	if got := info.Mode().Perm(); got != 0o666 {
		t.Errorf("Expected socket permissions 0666, got %o", got)
	}
	if n := server.ActiveSessions(); n != 0 {
		t.Errorf("Expected 0 active sessions, got %d", n)
	}
	if err := server.Stop(); err != nil {
		t.Fatalf("Stop failed: %v", err)
	}
	// Stop must remove the socket file.
	if _, err := os.Stat(sock); !os.IsNotExist(err) {
		t.Error("Socket file should be removed after stop")
	}
}
// TestServerStartRemovesStaleSocket verifies that a leftover file at the
// socket path does not prevent the server from binding.
func TestServerStartRemovesStaleSocket(t *testing.T) {
	sock := testSocketPath(t)
	if err := os.WriteFile(sock, []byte("stale"), 0o600); err != nil {
		t.Fatalf("Failed to create stale file: %v", err)
	}
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed with stale socket: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	// Prove the listener is live by dialing it.
	conn, err := net.DialTimeout("unix", sock, 2*time.Second)
	if err != nil {
		t.Fatalf("Failed to connect to server: %v", err)
	}
	_ = conn.Close()
}
// TestServerDoubleStop ensures Stop is idempotent.
func TestServerDoubleStop(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", false)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	if err := server.Stop(); err != nil {
		t.Fatalf("First stop failed: %v", err)
	}
	// A second Stop must not panic even though the socket is already gone.
	_ = server.Stop()
}
// TestProtocolStatus sends a raw "status" request and checks the reply fields.
func TestProtocolStatus(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	reply := sendTestRequest(t, sock, Request{Action: "status"})
	if !reply.OK {
		t.Fatalf("Expected OK=true, got error: %s", reply.Error)
	}
	if !reply.Running {
		t.Error("Expected Running=true")
	}
	if reply.ActiveSessions != 0 {
		t.Errorf("Expected 0 active sessions, got %d", reply.ActiveSessions)
	}
}
// TestProtocolUnknownAction verifies unrecognized actions yield an error reply.
func TestProtocolUnknownAction(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	reply := sendTestRequest(t, sock, Request{Action: "unknown_action"})
	if reply.OK {
		t.Fatal("Expected OK=false for unknown action")
	}
	if reply.Error == "" {
		t.Error("Expected error message for unknown action")
	}
}
// TestProtocolCreateSessionMissingProxy verifies create_session is rejected
// when no proxy URL is supplied.
func TestProtocolCreateSessionMissingProxy(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	reply := sendTestRequest(t, sock, Request{Action: "create_session"})
	if reply.OK {
		t.Fatal("Expected OK=false for missing proxy URL")
	}
	if reply.Error == "" {
		t.Error("Expected error message for missing proxy URL")
	}
}
// TestProtocolCreateSessionTunFailure verifies that a tunnel startup failure
// (here: a deliberately missing tun2socks binary) produces an error reply and
// leaves no session behind.
func TestProtocolCreateSessionTunFailure(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	reply := sendTestRequest(t, sock, Request{
		Action:   "create_session",
		ProxyURL: "socks5://127.0.0.1:1080",
	})
	if reply.OK {
		t.Fatal("Expected OK=false when tun2socks is not available")
	}
	if reply.Error == "" {
		t.Error("Expected error message when tun2socks fails")
	}
	if server.ActiveSessions() != 0 {
		t.Error("Expected 0 active sessions after failed create")
	}
}
// TestProtocolDestroySessionMissingID verifies destroy_session requires an ID.
func TestProtocolDestroySessionMissingID(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	reply := sendTestRequest(t, sock, Request{Action: "destroy_session"})
	if reply.OK {
		t.Fatal("Expected OK=false for missing session ID")
	}
	if reply.Error == "" {
		t.Error("Expected error message for missing session ID")
	}
}
// TestProtocolDestroySessionNotFound verifies destroying an unknown session
// ID reports an error.
func TestProtocolDestroySessionNotFound(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	reply := sendTestRequest(t, sock, Request{
		Action:    "destroy_session",
		SessionID: "nonexistent-session-id",
	})
	if reply.OK {
		t.Fatal("Expected OK=false for nonexistent session")
	}
	if reply.Error == "" {
		t.Error("Expected error message for nonexistent session")
	}
}
// TestProtocolInvalidJSON verifies that garbage input still gets a
// well-formed JSON error reply rather than a dropped connection.
func TestProtocolInvalidJSON(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	conn, err := net.DialTimeout("unix", sock, 2*time.Second)
	if err != nil {
		t.Fatalf("Failed to connect: %v", err)
	}
	defer conn.Close() //nolint:errcheck // test cleanup
	if _, err := conn.Write([]byte("not valid json\n")); err != nil {
		t.Fatalf("Failed to write: %v", err)
	}
	_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	var reply Response
	if err := json.NewDecoder(conn).Decode(&reply); err != nil {
		t.Fatalf("Failed to decode error response: %v", err)
	}
	if reply.OK {
		t.Fatal("Expected OK=false for invalid JSON")
	}
	if reply.Error == "" {
		t.Error("Expected error message for invalid JSON")
	}
}
// TestClientIsRunning checks IsRunning both before and after the daemon starts.
func TestClientIsRunning(t *testing.T) {
	sock := testSocketPath(t)
	c := NewClient(sock, true)
	if c.IsRunning() {
		t.Error("Expected IsRunning=false when server is not started")
	}
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	if !c.IsRunning() {
		t.Error("Expected IsRunning=true when server is running")
	}
}
// TestClientStatus exercises the Client's Status call end-to-end.
func TestClientStatus(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	reply, err := NewClient(sock, true).Status()
	if err != nil {
		t.Fatalf("Status failed: %v", err)
	}
	if !reply.OK {
		t.Fatalf("Expected OK=true, got error: %s", reply.Error)
	}
	if !reply.Running {
		t.Error("Expected Running=true")
	}
	if reply.ActiveSessions != 0 {
		t.Errorf("Expected 0 active sessions, got %d", reply.ActiveSessions)
	}
}
// TestClientDestroySessionNotFound verifies the Client surfaces a server-side
// "not found" as an error.
func TestClientDestroySessionNotFound(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	if err := NewClient(sock, true).DestroySession("nonexistent-id"); err == nil {
		t.Fatal("Expected error for nonexistent session")
	}
}
func TestClientConnectionRefused(t *testing.T) {
sockPath := testSocketPath(t)
// No server running.
client := NewClient(sockPath, true)
_, err := client.Status()
if err == nil {
t.Fatal("Expected error when server is not running")
}
_, err = client.CreateSession("socks5://127.0.0.1:1080", "")
if err == nil {
t.Fatal("Expected error when server is not running")
}
err = client.DestroySession("some-id")
if err == nil {
t.Fatal("Expected error when server is not running")
}
}
// TestProtocolMultipleStatusRequests sends several status requests in a row,
// each on a fresh connection, to confirm the accept loop keeps serving.
func TestProtocolMultipleStatusRequests(t *testing.T) {
	sock := testSocketPath(t)
	server := NewServer(sock, "/nonexistent/tun2socks", true)
	if err := server.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer server.Stop() //nolint:errcheck // test cleanup
	for i := 0; i < 5; i++ {
		if reply := sendTestRequest(t, sock, Request{Action: "status"}); !reply.OK {
			t.Fatalf("Request %d: expected OK=true, got error: %s", i, reply.Error)
		}
	}
}
// TestProtocolRequestResponseJSON round-trips both protocol types through
// encoding/json and checks every field survives, guarding the wire format
// against accidental tag or field renames.
func TestProtocolRequestResponseJSON(t *testing.T) {
	// Test that protocol types serialize/deserialize correctly.
	req := Request{
		Action: "create_session",
		ProxyURL: "socks5://127.0.0.1:1080",
		DNSAddr: "1.1.1.1:53",
		SessionID: "test-session",
	}
	data, err := json.Marshal(req)
	if err != nil {
		t.Fatalf("Failed to marshal request: %v", err)
	}
	var decoded Request
	if err := json.Unmarshal(data, &decoded); err != nil {
		t.Fatalf("Failed to unmarshal request: %v", err)
	}
	// Field-by-field comparison so a failure names the broken field.
	if decoded.Action != req.Action {
		t.Errorf("Action: got %q, want %q", decoded.Action, req.Action)
	}
	if decoded.ProxyURL != req.ProxyURL {
		t.Errorf("ProxyURL: got %q, want %q", decoded.ProxyURL, req.ProxyURL)
	}
	if decoded.DNSAddr != req.DNSAddr {
		t.Errorf("DNSAddr: got %q, want %q", decoded.DNSAddr, req.DNSAddr)
	}
	if decoded.SessionID != req.SessionID {
		t.Errorf("SessionID: got %q, want %q", decoded.SessionID, req.SessionID)
	}
	// Now the Response type, with every field populated.
	resp := Response{
		OK: true,
		SessionID: "abc123",
		TunDevice: "utun7",
		SandboxUser: "_greywall",
		SandboxGroup: "_greywall",
		Running: true,
		ActiveSessions: 1,
	}
	data, err = json.Marshal(resp)
	if err != nil {
		t.Fatalf("Failed to marshal response: %v", err)
	}
	var decodedResp Response
	if err := json.Unmarshal(data, &decodedResp); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	if decodedResp.OK != resp.OK {
		t.Errorf("OK: got %v, want %v", decodedResp.OK, resp.OK)
	}
	if decodedResp.SessionID != resp.SessionID {
		t.Errorf("SessionID: got %q, want %q", decodedResp.SessionID, resp.SessionID)
	}
	if decodedResp.TunDevice != resp.TunDevice {
		t.Errorf("TunDevice: got %q, want %q", decodedResp.TunDevice, resp.TunDevice)
	}
	if decodedResp.SandboxUser != resp.SandboxUser {
		t.Errorf("SandboxUser: got %q, want %q", decodedResp.SandboxUser, resp.SandboxUser)
	}
	if decodedResp.SandboxGroup != resp.SandboxGroup {
		t.Errorf("SandboxGroup: got %q, want %q", decodedResp.SandboxGroup, resp.SandboxGroup)
	}
	if decodedResp.Running != resp.Running {
		t.Errorf("Running: got %v, want %v", decodedResp.Running, resp.Running)
	}
	if decodedResp.ActiveSessions != resp.ActiveSessions {
		t.Errorf("ActiveSessions: got %d, want %d", decodedResp.ActiveSessions, resp.ActiveSessions)
	}
}
// TestProtocolResponseOmitEmpty checks that omitempty hides the session
// fields on an error-only response while keeping the error itself visible.
func TestProtocolResponseOmitEmpty(t *testing.T) {
	data, err := json.Marshal(Response{OK: false, Error: "something went wrong"})
	if err != nil {
		t.Fatalf("Failed to marshal: %v", err)
	}
	var raw map[string]interface{}
	if err := json.Unmarshal(data, &raw); err != nil {
		t.Fatalf("Failed to unmarshal to map: %v", err)
	}
	// These fields should be omitted due to omitempty.
	for _, key := range []string{"session_id", "tun_device", "sandbox_user", "sandbox_group"} {
		if _, exists := raw[key]; exists {
			t.Errorf("Expected %q to be omitted from JSON, but it was present", key)
		}
	}
	// The error itself must still be present.
	if _, exists := raw["error"]; !exists {
		t.Error("Expected 'error' field in JSON")
	}
}
// TestGenerateSessionID checks ID length and uniqueness across many draws.
func TestGenerateSessionID(t *testing.T) {
	seen := make(map[string]bool, 100)
	for i := 0; i < 100; i++ {
		id, err := generateSessionID()
		if err != nil {
			t.Fatalf("generateSessionID failed: %v", err)
		}
		// 16 random bytes encode to 32 hex characters.
		if len(id) != 32 {
			t.Errorf("Expected 32-char hex ID, got %d chars: %q", len(id), id)
		}
		if seen[id] {
			t.Errorf("Duplicate session ID: %s", id)
		}
		seen[id] = true
	}
}
// sendTestRequest opens a fresh connection to the daemon socket, writes one
// JSON request, and returns the decoded JSON reply. It deliberately bypasses
// the Client so the raw wire protocol is what gets exercised.
func sendTestRequest(t *testing.T, sockPath string, req Request) Response {
	t.Helper()
	conn, err := net.DialTimeout("unix", sockPath, 2*time.Second)
	if err != nil {
		t.Fatalf("Failed to connect to server: %v", err)
	}
	defer conn.Close() //nolint:errcheck // test cleanup
	_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
	if err := json.NewEncoder(conn).Encode(req); err != nil {
		t.Fatalf("Failed to encode request: %v", err)
	}
	var reply Response
	if err := json.NewDecoder(conn).Decode(&reply); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	return reply
}

570
internal/daemon/tun.go Normal file
View File

@@ -0,0 +1,570 @@
//go:build darwin
package daemon
import (
"bufio"
"fmt"
"io"
"os"
"os/exec"
"regexp"
"strings"
"sync"
"time"
)
const (
	// tunIP is the point-to-point address assigned to the utun interface.
	// 198.18.0.0/15 is reserved for network benchmarking (RFC 2544), so it
	// will not collide with real destinations.
	tunIP = "198.18.0.1"
	// dnsRelayIP is the loopback alias address the DNS relay listens on.
	dnsRelayIP = "127.0.0.2"
	dnsRelayPort = "15353" // high port to avoid conflicts with system DNS (mDNSResponder, Docker/Lima)
	// pfAnchorName is the pf anchor under which all greywall rules live.
	pfAnchorName = "co.greyhaven.greywall"
	// tun2socksStopGracePeriod is the time to wait for tun2socks to exit
	// after SIGTERM before sending SIGKILL.
	tun2socksStopGracePeriod = 5 * time.Second
)
// utunDevicePattern matches "utunN" device names wherever they appear in
// tun2socks log output or ifconfig listings. Compiled once at package scope.
var utunDevicePattern = regexp.MustCompile(`(utun\d+)`)
// TunManager handles utun device creation via tun2socks, tun2socks process
// lifecycle, and pf (packet filter) rule management for routing sandboxed
// traffic through the tunnel on macOS. mu guards all mutable fields; done is
// closed exactly once (in Stop) to end the background monitor goroutine.
type TunManager struct {
	tunDevice     string    // e.g., "utun7"; set once startTun2Socks discovers it
	tun2socksPath string    // path to tun2socks binary
	tun2socksCmd  *exec.Cmd // running tun2socks process; nil until started
	proxyURL      string    // SOCKS5 proxy URL for tun2socks
	pfAnchor      string    // pf anchor name
	debug         bool      // enables verbose logDebug output
	done          chan struct{} // closed on Stop to signal the monitor goroutine
	mu            sync.Mutex
}
// NewTunManager creates a new TunManager that will use the given tun2socks
// binary and SOCKS5 proxy URL. The pf anchor is set to "co.greyhaven.greywall".
func NewTunManager(tun2socksPath string, proxyURL string, debug bool) *TunManager {
	tm := &TunManager{
		tun2socksPath: tun2socksPath,
		proxyURL:      proxyURL,
		debug:         debug,
	}
	tm.pfAnchor = pfAnchorName
	tm.done = make(chan struct{})
	return tm
}
// Start brings up the tunnel stack: it launches tun2socks (which auto-creates
// a utunN device), assigns the point-to-point IP to that device, and adds the
// loopback alias used by the DNS relay. pf rules are NOT loaded here — call
// LoadPFRules explicitly once the sandbox GID is known.
func (t *TunManager) Start() error {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.tun2socksCmd != nil {
		return fmt.Errorf("tun manager already started")
	}
	// Launch tun2socks first; it creates the utun device for us.
	if err := t.startTun2Socks(); err != nil {
		return fmt.Errorf("failed to start tun2socks: %w", err)
	}
	// Any failure past this point must kill the process we just spawned.
	if err := t.configureInterface(); err != nil {
		_ = t.stopTun2Socks()
		return fmt.Errorf("failed to configure interface %s: %w", t.tunDevice, err)
	}
	if err := t.addLoopbackAlias(); err != nil {
		_ = t.stopTun2Socks()
		return fmt.Errorf("failed to add loopback alias: %w", err)
	}
	t.logDebug("Tunnel stack started: device=%s proxy=%s", t.tunDevice, t.proxyURL)
	return nil
}
// Stop tears the tunnel stack down in reverse order of Start: pf rules first,
// then the tun2socks process (SIGTERM, then SIGKILL after a grace period),
// then the loopback alias. The utun device disappears when tun2socks exits.
// All steps are attempted even if earlier ones fail; errors are aggregated.
func (t *TunManager) Stop() error {
	t.mu.Lock()
	defer t.mu.Unlock()
	// Tell the monitoring goroutine to exit (close done at most once).
	select {
	case <-t.done:
	default:
		close(t.done)
	}
	var problems []string
	if err := t.unloadPFRulesLocked(); err != nil {
		problems = append(problems, fmt.Sprintf("unload pf rules: %v", err))
	}
	if err := t.stopTun2Socks(); err != nil {
		problems = append(problems, fmt.Sprintf("stop tun2socks: %v", err))
	}
	if err := t.removeLoopbackAlias(); err != nil {
		problems = append(problems, fmt.Sprintf("remove loopback alias: %v", err))
	}
	if len(problems) > 0 {
		return fmt.Errorf("stop errors: %s", strings.Join(problems, "; "))
	}
	t.logDebug("Tunnel stack stopped")
	return nil
}
// TunDevice returns the utun device name (e.g. "utun7"), or the empty string
// if the tunnel has not been started yet.
func (t *TunManager) TunDevice() string {
	t.mu.Lock()
	dev := t.tunDevice
	t.mu.Unlock()
	return dev
}
// LoadPFRules loads pf anchor rules that route traffic from the given sandbox
// group through the utun device. The rules:
//   - Route all TCP from the sandbox group through the utun interface
//   - Redirect DNS (UDP port 53) from the sandbox group to the local DNS relay
//
// This requires root privileges and an active pf firewall.
func (t *TunManager) LoadPFRules(sandboxGroup string) error {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.tunDevice == "" {
		return fmt.Errorf("tunnel not started, no device available")
	}
	// Ensure the anchor reference exists in the main pf.conf.
	if err := t.ensureAnchorInPFConf(); err != nil {
		return fmt.Errorf("failed to ensure pf anchor: %w", err)
	}
	// Build the anchor rules. pf requires strict ordering:
	// translation (rdr) before filtering (pass).
	// Note: macOS pf does not support "group" in rdr rules, so DNS
	// redirection uses a two-step approach:
	//  1. rdr on lo0 — redirects DNS arriving on loopback to our relay
	//  2. pass out route-to lo0 — sends sandbox group's DNS to loopback
	//  3. pass out route-to utun — sends sandbox group's TCP through tunnel
	rules := fmt.Sprintf(
		"rdr on lo0 proto udp from any to any port 53 -> %s port %s\n"+
			"pass out on !lo0 route-to (lo0 127.0.0.1) proto udp from any to any port 53 group %s\n"+
			"pass out route-to (%s %s) proto tcp from any to any group %s\n",
		dnsRelayIP, dnsRelayPort,
		sandboxGroup,
		t.tunDevice, tunIP, sandboxGroup,
	)
	t.logDebug("Loading pf rules into anchor %s:\n%s", t.pfAnchor, rules)
	// Load the rules into the anchor. pfctl writes its diagnostics to
	// stderr, so capture BOTH streams for the error message — previously
	// stderr was wired to the daemon's stderr and Output() captured only
	// stdout, leaving "(output: ...)" empty on every failure.
	//nolint:gosec // arguments are controlled internal constants, not user input
	cmd := exec.Command("pfctl", "-a", t.pfAnchor, "-f", "-")
	cmd.Stdin = strings.NewReader(rules)
	if output, err := cmd.CombinedOutput(); err != nil {
		return fmt.Errorf("pfctl load rules failed: %w (output: %s)", err, string(output))
	}
	// Enable pf if it is not already enabled.
	if err := t.enablePF(); err != nil {
		// Non-fatal: pf may already be enabled.
		t.logDebug("Warning: failed to enable pf (may already be active): %v", err)
	}
	t.logDebug("pf rules loaded for group %s on %s", sandboxGroup, t.tunDevice)
	return nil
}
// UnloadPFRules removes this manager's rules from the pf anchor.
func (t *TunManager) UnloadPFRules() error {
	t.mu.Lock()
	err := t.unloadPFRulesLocked()
	t.mu.Unlock()
	return err
}
// startTun2Socks launches the tun2socks process with "-device utun" so that it
// auto-creates a utun device. The device name is discovered by scanning BOTH
// stdout and stderr for a "utunN" token (tun2socks versions differ in which
// stream carries the STACK log line); ifconfig is the fallback if neither
// stream yields a name within 10 seconds. Caller must hold t.mu.
func (t *TunManager) startTun2Socks() error {
	//nolint:gosec // tun2socksPath is an internal path, not user input
	cmd := exec.Command(t.tun2socksPath, "-device", "utun", "-proxy", t.proxyURL)
	// Capture both stdout and stderr to discover the device name.
	// tun2socks may log the device name on either stream depending on version.
	stderrPipe, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("failed to create stderr pipe: %w", err)
	}
	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("failed to create stdout pipe: %w", err)
	}
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("failed to start tun2socks: %w", err)
	}
	t.tun2socksCmd = cmd
	// Read both stdout and stderr to discover the utun device name.
	// tun2socks logs the device name shortly after startup
	// (e.g., "level=INFO msg=[STACK] tun://utun7 <-> ...").
	deviceCh := make(chan string, 2) // buffered for both goroutines
	stderrLines := make(chan string, 100)
	// scanPipe scans lines from a pipe, looking for the utun device name.
	// Every line is mirrored to our stderr for diagnostics and also fed
	// (non-blocking) into stderrLines for the monitor goroutine.
	scanPipe := func(pipe io.Reader, label string) {
		scanner := bufio.NewScanner(pipe)
		for scanner.Scan() {
			line := scanner.Text()
			fmt.Fprintf(os.Stderr, "[greywall:tun] tun2socks(%s): %s\n", label, line) //nolint:gosec // logging tun2socks output
			if match := utunDevicePattern.FindString(line); match != "" {
				select {
				case deviceCh <- match:
				default:
					// Already found by the other pipe.
				}
			}
			// Drop lines rather than block if the monitor falls behind.
			select {
			case stderrLines <- line:
			default:
			}
		}
	}
	go scanPipe(stderrPipe, "stderr")
	go scanPipe(stdoutPipe, "stdout")
	// Wait for the device name with a timeout.
	select {
	case device := <-deviceCh:
		if device == "" {
			// Defensive: the regex never sends "", but fall back anyway.
			t.logDebug("Empty device from tun2socks output, trying ifconfig")
			device, err = t.discoverUtunFromIfconfig()
			if err != nil {
				_ = cmd.Process.Kill()
				return fmt.Errorf("failed to discover utun device: %w", err)
			}
		}
		t.tunDevice = device
	case <-time.After(10 * time.Second):
		// Timeout: try ifconfig fallback.
		t.logDebug("Timeout waiting for tun2socks device name, trying ifconfig")
		device, err := t.discoverUtunFromIfconfig()
		if err != nil {
			_ = cmd.Process.Kill()
			return fmt.Errorf("tun2socks did not report device name within timeout: %w", err)
		}
		t.tunDevice = device
	}
	t.logDebug("tun2socks started (pid=%d, device=%s)", cmd.Process.Pid, t.tunDevice)
	// Monitor tun2socks in the background.
	go t.monitorTun2Socks(stderrLines)
	return nil
}
// utunHeaderPattern matches ifconfig interface header lines like "utun7: ...".
// Compiled once at package scope rather than on every fallback call.
var utunHeaderPattern = regexp.MustCompile(`^utun(\d+):`)

// discoverUtunFromIfconfig runs ifconfig and looks for a utun device. This is
// used as a fallback when we cannot parse the device name from tun2socks output.
//
// Returns the highest-numbered utun interface (the most recently created one),
// or an error if ifconfig fails or reports no utun interfaces at all.
func (t *TunManager) discoverUtunFromIfconfig() (string, error) {
	out, err := exec.Command("ifconfig").Output()
	if err != nil {
		return "", fmt.Errorf("ifconfig failed: %w", err)
	}
	device := highestNumberedUtun(string(out))
	if device == "" {
		return "", fmt.Errorf("no utun device found in ifconfig output")
	}
	return device, nil
}

// highestNumberedUtun scans ifconfig output for lines starting with "utunN:"
// and returns the device with the largest N, or "" if none is present.
// The previous implementation returned the last-listed device; selecting the
// numeric maximum matches the documented "highest-numbered" intent even if
// ifconfig ever lists interfaces out of order.
func highestNumberedUtun(output string) string {
	var bestNum string
	for _, line := range strings.Split(output, "\n") {
		m := utunHeaderPattern.FindStringSubmatch(line)
		if m == nil {
			continue
		}
		num := m[1]
		// Digit strings without leading zeros compare numerically via
		// (length, lexicographic) — avoids strconv and cannot fail.
		if bestNum == "" || len(num) > len(bestNum) ||
			(len(num) == len(bestNum) && num > bestNum) {
			bestNum = num
		}
	}
	if bestNum == "" {
		return ""
	}
	return "utun" + bestNum
}
// monitorTun2Socks watches the tun2socks process and logs if it exits unexpectedly.
// An expected shutdown (t.done closed) is distinguished from a crash via a
// non-blocking receive on t.done after Wait returns.
//
// NOTE(review): stderrLines is never closed by the scanner goroutines, so the
// drain goroutine below lives for the remainder of the process. Acceptable for
// a long-running daemon, but worth confirming.
func (t *TunManager) monitorTun2Socks(stderrLines <-chan string) {
	if t.tun2socksCmd == nil || t.tun2socksCmd.Process == nil {
		return
	}
	// Drain any remaining stderr lines so the scanner goroutines never block
	// on a full buffer. Lines were already echoed by the scanner itself.
	go func() {
		for range stderrLines {
			// Intentionally discarded.
		}
	}()
	// Blocks until tun2socks exits.
	err := t.tun2socksCmd.Wait()
	select {
	case <-t.done:
		// Expected shutdown.
		t.logDebug("tun2socks exited (expected shutdown)")
	default:
		// Unexpected exit.
		fmt.Fprintf(os.Stderr, "[greywall:tun] ERROR: tun2socks exited unexpectedly: %v\n", err)
	}
}
// stopTun2Socks asks the tun2socks process to exit by sending it an interrupt
// signal (os.Interrupt, i.e. SIGINT) and waits for it to terminate. If it does
// not exit within tun2socksStopGracePeriod, SIGKILL is sent.
//
// The *exec.Cmd is captured in a local variable before the wait goroutine is
// spawned so that clearing t.tun2socksCmd below cannot race with the
// goroutine's access to the field (the original code read t.tun2socksCmd
// inside the goroutine, which could observe nil after the final assignment).
func (t *TunManager) stopTun2Socks() error {
	cmd := t.tun2socksCmd
	if cmd == nil || cmd.Process == nil {
		return nil
	}
	t.logDebug("Stopping tun2socks (pid=%d)", cmd.Process.Pid)
	// Request a graceful shutdown.
	if err := cmd.Process.Signal(os.Interrupt); err != nil {
		// Process may have already exited.
		t.logDebug("SIGTERM failed (process may have exited): %v", err)
		t.tun2socksCmd = nil
		return nil
	}
	// Wait for exit with a timeout.
	exited := make(chan error, 1)
	go func() {
		// Wait may have already been called by the monitor goroutine, in
		// which case this returns an "already called" error immediately;
		// either way the channel receives promptly once the process is gone.
		exited <- cmd.Wait()
	}()
	select {
	case err := <-exited:
		if err != nil {
			t.logDebug("tun2socks exited with: %v", err)
		}
	case <-time.After(tun2socksStopGracePeriod):
		t.logDebug("tun2socks did not exit after SIGTERM, sending SIGKILL")
		_ = cmd.Process.Kill()
	}
	t.tun2socksCmd = nil
	return nil
}
// configureInterface assigns the point-to-point tunnel address to the utun
// device and brings the interface up (ifconfig <dev> <ip> <ip> up).
func (t *TunManager) configureInterface() error {
	t.logDebug("Configuring interface %s with IP %s", t.tunDevice, tunIP)
	//nolint:gosec // tunDevice and tunIP are controlled internal values
	out, err := exec.Command("ifconfig", t.tunDevice, tunIP, tunIP, "up").CombinedOutput()
	if err != nil {
		return fmt.Errorf("ifconfig %s failed: %w (output: %s)", t.tunDevice, err, string(out))
	}
	return nil
}
// addLoopbackAlias attaches the DNS relay IP as an alias on lo0 so the relay
// address is locally reachable.
func (t *TunManager) addLoopbackAlias() error {
	t.logDebug("Adding loopback alias %s on lo0", dnsRelayIP)
	out, err := exec.Command("ifconfig", "lo0", "alias", dnsRelayIP, "up").CombinedOutput()
	if err != nil {
		return fmt.Errorf("ifconfig lo0 alias failed: %w (output: %s)", err, string(out))
	}
	return nil
}
// removeLoopbackAlias detaches the DNS relay alias from lo0, undoing
// addLoopbackAlias.
func (t *TunManager) removeLoopbackAlias() error {
	t.logDebug("Removing loopback alias %s from lo0", dnsRelayIP)
	out, err := exec.Command("ifconfig", "lo0", "-alias", dnsRelayIP).CombinedOutput()
	if err != nil {
		return fmt.Errorf("ifconfig lo0 -alias failed: %w (output: %s)", err, string(out))
	}
	return nil
}
// ensureAnchorInPFConf checks whether the pf anchor reference exists in
// /etc/pf.conf. If not, it inserts the anchor lines at the correct positions
// (pf requires strict ordering: rdr-anchor before anchor, both before load anchor)
// and reloads the main ruleset.
func (t *TunManager) ensureAnchorInPFConf() error {
	const pfConfPath = "/etc/pf.conf"
	anchorLine := fmt.Sprintf(`anchor "%s"`, t.pfAnchor)
	rdrAnchorLine := fmt.Sprintf(`rdr-anchor "%s"`, t.pfAnchor)
	data, err := os.ReadFile(pfConfPath)
	if err != nil {
		return fmt.Errorf("failed to read %s: %w", pfConfPath, err)
	}
	lines := strings.Split(string(data), "\n")
	// Line-level presence check avoids substring false positives
	// (e.g. 'anchor "X"' matching inside 'rdr-anchor "X"').
	hasAnchor := false
	hasRdrAnchor := false
	// Indices of the last existing rdr-anchor/anchor lines; our lines are
	// inserted right after them to preserve pf's required ordering.
	lastRdrIdx := -1
	lastAnchorIdx := -1
	for i, line := range lines {
		trimmed := strings.TrimSpace(line)
		if trimmed == rdrAnchorLine {
			hasRdrAnchor = true
		}
		if trimmed == anchorLine {
			hasAnchor = true
		}
		if strings.HasPrefix(trimmed, "rdr-anchor ") {
			lastRdrIdx = i
		}
		// Standalone "anchor" lines — not rdr-anchor, nat-anchor, etc.
		// (those begin with a different prefix, so this check excludes them).
		if strings.HasPrefix(trimmed, "anchor ") {
			lastAnchorIdx = i
		}
	}
	if hasAnchor && hasRdrAnchor {
		t.logDebug("pf anchor already present in %s", pfConfPath)
		return nil
	}
	t.logDebug("Adding pf anchor to %s", pfConfPath)
	// Rebuild the file in a single forward pass, appending each missing
	// greywall line immediately after the last existing line of its kind.
	var result []string
	for i, line := range lines {
		result = append(result, line)
		if !hasRdrAnchor && i == lastRdrIdx {
			result = append(result, rdrAnchorLine)
		}
		if !hasAnchor && i == lastAnchorIdx {
			result = append(result, anchorLine)
		}
	}
	// Fallback: if no existing rdr-anchor/anchor found, append at end
	// (rdr-anchor first, keeping the ordering pf requires).
	if !hasRdrAnchor && lastRdrIdx == -1 {
		result = append(result, rdrAnchorLine)
	}
	if !hasAnchor && lastAnchorIdx == -1 {
		result = append(result, anchorLine)
	}
	newContent := strings.Join(result, "\n")
	//nolint:gosec // pf.conf must be writable by root; the daemon runs as root
	if err := os.WriteFile(pfConfPath, []byte(newContent), 0o644); err != nil {
		return fmt.Errorf("failed to write %s: %w", pfConfPath, err)
	}
	// Reload the main pf.conf so the anchor reference is recognized.
	//nolint:gosec // pfConfPath is a constant
	reloadCmd := exec.Command("pfctl", "-f", pfConfPath)
	if output, err := reloadCmd.CombinedOutput(); err != nil {
		return fmt.Errorf("pfctl reload failed: %w (output: %s)", err, string(output))
	}
	t.logDebug("pf anchor added and pf.conf reloaded")
	return nil
}
// enablePF turns on the pf firewall unless "pfctl -s info" already reports it
// as enabled.
func (t *TunManager) enablePF() error {
	// Skip the enable call when pf is already running.
	status, statusErr := exec.Command("pfctl", "-s", "info").CombinedOutput()
	if statusErr == nil && strings.Contains(string(status), "Status: Enabled") {
		t.logDebug("pf is already enabled")
		return nil
	}
	t.logDebug("Enabling pf")
	if output, err := exec.Command("pfctl", "-e").CombinedOutput(); err != nil {
		return fmt.Errorf("pfctl -e failed: %w (output: %s)", err, string(output))
	}
	return nil
}
// unloadPFRulesLocked flushes all rules from the pf anchor. Must be called
// with t.mu held; UnloadPFRules is the locking wrapper.
func (t *TunManager) unloadPFRulesLocked() error {
	t.logDebug("Flushing pf anchor %s", t.pfAnchor)
	//nolint:gosec // pfAnchor is a controlled internal constant
	out, err := exec.Command("pfctl", "-a", t.pfAnchor, "-F", "all").CombinedOutput()
	if err != nil {
		return fmt.Errorf("pfctl flush anchor failed: %w (output: %s)", err, string(out))
	}
	return nil
}
// removeAnchorFromPFConf removes greywall anchor lines from /etc/pf.conf.
// Called during uninstall to clean up. The file is rewritten only when at
// least one matching line was found.
func removeAnchorFromPFConf(debug bool) error {
	const pfConfPath = "/etc/pf.conf"
	anchorLine := fmt.Sprintf(`anchor "%s"`, pfAnchorName)
	rdrAnchorLine := fmt.Sprintf(`rdr-anchor "%s"`, pfAnchorName)
	data, err := os.ReadFile(pfConfPath)
	if err != nil {
		return fmt.Errorf("failed to read %s: %w", pfConfPath, err)
	}
	// Keep every line except exact (whitespace-trimmed) greywall anchor lines.
	var kept []string
	dropped := 0
	for _, line := range strings.Split(string(data), "\n") {
		switch strings.TrimSpace(line) {
		case anchorLine, rdrAnchorLine:
			dropped++
		default:
			kept = append(kept, line)
		}
	}
	if dropped == 0 {
		logDebug(debug, "No pf anchor lines to remove from %s", pfConfPath)
		return nil
	}
	//nolint:gosec // pf.conf must be writable by root; the daemon runs as root
	if err := os.WriteFile(pfConfPath, []byte(strings.Join(kept, "\n")), 0o644); err != nil {
		return fmt.Errorf("failed to write %s: %w", pfConfPath, err)
	}
	logDebug(debug, "Removed %d pf anchor lines from %s", dropped, pfConfPath)
	return nil
}
// logDebug writes a debug message to stderr with the [greywall:tun] prefix.
// It is a no-op unless the manager was created with debug logging enabled.
func (t *TunManager) logDebug(format string, args ...interface{}) {
	if !t.debug {
		return
	}
	fmt.Fprintf(os.Stderr, "[greywall:tun] "+format+"\n", args...)
}

View File

@@ -0,0 +1,38 @@
//go:build !darwin
package daemon
import "fmt"
// TunManager is a no-op stub of the tunnel manager for non-macOS platforms
// (this file carries the !darwin build tag).
type TunManager struct{}

// NewTunManager returns a stub manager on non-macOS platforms. Construction
// itself always succeeds; the platform error is reported when Start (or any
// other operation) is called. The parameters are accepted so call sites
// compile identically on every platform, but they are ignored here.
func NewTunManager(tun2socksPath string, proxyURL string, debug bool) *TunManager {
	return &TunManager{}
}

// Start returns an error on non-macOS platforms.
func (t *TunManager) Start() error {
	return fmt.Errorf("tun manager is only available on macOS")
}

// Stop returns an error on non-macOS platforms.
func (t *TunManager) Stop() error {
	return fmt.Errorf("tun manager is only available on macOS")
}

// TunDevice returns an empty string on non-macOS platforms (no device exists).
func (t *TunManager) TunDevice() string {
	return ""
}

// LoadPFRules returns an error on non-macOS platforms; pf is a macOS/BSD
// firewall with no equivalent here.
func (t *TunManager) LoadPFRules(sandboxUser string) error {
	return fmt.Errorf("pf rules are only available on macOS")
}

// UnloadPFRules returns an error on non-macOS platforms.
func (t *TunManager) UnloadPFRules() error {
	return fmt.Errorf("pf rules are only available on macOS")
}

View File

@@ -45,6 +45,8 @@ type MacOSSandboxParams struct {
AllowPty bool
AllowGitConfig bool
Shell string
DaemonMode bool // When true, pf handles network routing; Seatbelt allows network-outbound
DaemonSocketPath string // Daemon socket to deny access to from sandboxed process
}
// GlobToRegex converts a glob pattern to a regex for macOS sandbox profiles.
@@ -422,8 +424,8 @@ func GenerateSandboxProfile(params MacOSSandboxParams) string {
// Header
profile.WriteString("(version 1)\n")
profile.WriteString(fmt.Sprintf("(deny default (with message %q))\n\n", logTag))
profile.WriteString(fmt.Sprintf("; LogTag: %s\n\n", logTag))
fmt.Fprintf(&profile, "(deny default (with message %q))\n\n", logTag)
fmt.Fprintf(&profile, "; LogTag: %s\n\n", logTag)
// Essential permissions - based on Chrome sandbox policy
profile.WriteString(`; Essential permissions - based on Chrome sandbox policy
@@ -566,9 +568,27 @@ func GenerateSandboxProfile(params MacOSSandboxParams) string {
// Network rules
profile.WriteString("; Network\n")
if !params.NeedsNetworkRestriction {
switch {
case params.DaemonMode:
// In daemon mode, pf handles network routing: all traffic from the
// _greywall user is routed through utun → tun2socks → proxy.
// Seatbelt must allow network-outbound so packets reach pf.
// The proxy allowlist is enforced by the external SOCKS5 proxy.
profile.WriteString("(allow network-outbound)\n")
// Allow local binding for servers if configured.
if params.AllowLocalBinding {
profile.WriteString(`(allow network-bind (local ip "localhost:*"))
(allow network-inbound (local ip "localhost:*"))
`)
}
// Explicitly deny access to the daemon socket to prevent the
// sandboxed process from manipulating daemon sessions.
if params.DaemonSocketPath != "" {
fmt.Fprintf(&profile, "(deny network-outbound (remote unix-socket (path-literal %s)))\n", escapePath(params.DaemonSocketPath))
}
case !params.NeedsNetworkRestriction:
profile.WriteString("(allow network*)\n")
} else {
default:
if params.AllowLocalBinding {
// Allow binding and inbound connections on localhost (for servers)
profile.WriteString(`(allow network-bind (local ip "localhost:*"))
@@ -586,14 +606,13 @@ func GenerateSandboxProfile(params MacOSSandboxParams) string {
} else if len(params.AllowUnixSockets) > 0 {
for _, socketPath := range params.AllowUnixSockets {
normalized := NormalizePath(socketPath)
profile.WriteString(fmt.Sprintf("(allow network* (subpath %s))\n", escapePath(normalized)))
fmt.Fprintf(&profile, "(allow network* (subpath %s))\n", escapePath(normalized))
}
}
// Allow outbound to the external proxy host:port
if params.ProxyHost != "" && params.ProxyPort != "" {
profile.WriteString(fmt.Sprintf(`(allow network-outbound (remote ip "%s:%s"))
`, params.ProxyHost, params.ProxyPort))
fmt.Fprintf(&profile, "(allow network-outbound (remote ip \"%s:%s\"))\n", params.ProxyHost, params.ProxyPort)
}
}
profile.WriteString("\n")
@@ -631,7 +650,9 @@ func GenerateSandboxProfile(params MacOSSandboxParams) string {
}
// WrapCommandMacOS wraps a command with macOS sandbox restrictions.
func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, debug bool) (string, error) {
// When daemonSession is non-nil, the command runs as the _greywall user
// with network-outbound allowed (pf routes traffic through utun → proxy).
func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, daemonSession *DaemonSession, debug bool) (string, error) {
cwd, _ := os.Getwd()
// Build allow paths: default + configured
@@ -657,9 +678,13 @@ func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, de
}
}
// Determine if we're using daemon-mode (transparent proxying via pf + utun)
daemonMode := daemonSession != nil
// Restrict network unless proxy is configured to an external host
// If no proxy: block all outbound. If proxy: allow outbound only to proxy.
needsNetworkRestriction := true
// In daemon mode, network restriction is handled by pf, not Seatbelt.
needsNetworkRestriction := !daemonMode
params := MacOSSandboxParams{
Command: command,
@@ -679,6 +704,8 @@ func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, de
WriteDenyPaths: cfg.Filesystem.DenyWrite,
AllowPty: cfg.AllowPty,
AllowGitConfig: cfg.Filesystem.AllowGitConfig,
DaemonMode: daemonMode,
DaemonSocketPath: "/var/run/greywall.sock",
}
if debug && len(exposedPorts) > 0 {
@@ -687,6 +714,10 @@ func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, de
if debug && allowLocalBinding && !allowLocalOutbound {
fmt.Fprintf(os.Stderr, "[greywall:macos] Blocking localhost outbound (AllowLocalOutbound=false)\n")
}
if debug && daemonMode {
fmt.Fprintf(os.Stderr, "[greywall:macos] Daemon mode: transparent proxying via pf + utun (group=%s, device=%s)\n",
daemonSession.SandboxGroup, daemonSession.TunDevice)
}
profile := GenerateSandboxProfile(params)
@@ -700,14 +731,23 @@ func WrapCommandMacOS(cfg *config.Config, command string, exposedPorts []int, de
return "", fmt.Errorf("shell %q not found: %w", shell, err)
}
proxyEnvs := GenerateProxyEnvVars(cfg.Network.ProxyURL)
// Build the command
// env VAR1=val1 VAR2=val2 sandbox-exec -p 'profile' shell -c 'command'
var parts []string
if daemonMode {
// In daemon mode: run as the real user but with EGID=_greywall via sudo.
// pf routes all traffic from group _greywall through utun → tun2socks → proxy.
// Using -u #<uid> preserves the user's identity (home dir, SSH keys, etc.)
// while -g _greywall sets the effective GID for pf matching.
uid := fmt.Sprintf("#%d", os.Getuid())
parts = append(parts, "sudo", "-u", uid, "-g", daemonSession.SandboxGroup,
"sandbox-exec", "-p", profile, shellPath, "-c", command)
} else {
// Non-daemon mode: use proxy env vars for best-effort proxying.
proxyEnvs := GenerateProxyEnvVars(cfg.Network.ProxyURL)
parts = append(parts, "env")
parts = append(parts, proxyEnvs...)
parts = append(parts, "sandbox-exec", "-p", profile, shellPath, "-c", command)
}
return ShellQuote(parts), nil
}

View File

@@ -5,9 +5,20 @@ import (
"os"
"gitea.app.monadical.io/monadical/greywall/internal/config"
"gitea.app.monadical.io/monadical/greywall/internal/daemon"
"gitea.app.monadical.io/monadical/greywall/internal/platform"
)
// DaemonSession holds the state from an active daemon session on macOS.
// When a daemon session is active, traffic is routed through pf + utun
// instead of using env-var proxy settings.
type DaemonSession struct {
SessionID string
TunDevice string
SandboxUser string
SandboxGroup string
}
// Manager handles sandbox initialization and command wrapping.
type Manager struct {
config *config.Config
@@ -22,6 +33,9 @@ type Manager struct {
learning bool // learning mode: permissive sandbox with strace
straceLogPath string // host-side temp file for strace output
commandName string // name of the command being learned
// macOS daemon session fields
daemonClient *daemon.Client
daemonSession *DaemonSession
}
// NewManager creates a new sandbox manager.
@@ -63,11 +77,36 @@ func (m *Manager) Initialize() error {
return fmt.Errorf("sandbox is not supported on platform: %s", platform.Detect())
}
// On macOS, the daemon is required for transparent proxying.
// Without it, env-var proxying is unreliable (only works for tools that
// honor HTTP_PROXY) and gives users a false sense of security.
if platform.Detect() == platform.MacOS && m.config.Network.ProxyURL != "" {
client := daemon.NewClient(daemon.DefaultSocketPath, m.debug)
if !client.IsRunning() {
return fmt.Errorf("greywall daemon is not running (required for macOS network sandboxing)\n\n" +
" Install and start: sudo greywall daemon install\n" +
" Check status: greywall daemon status")
}
m.logDebug("Daemon is running, requesting session")
resp, err := client.CreateSession(m.config.Network.ProxyURL, m.config.Network.DnsAddr)
if err != nil {
return fmt.Errorf("failed to create daemon session: %w", err)
}
m.daemonClient = client
m.daemonSession = &DaemonSession{
SessionID: resp.SessionID,
TunDevice: resp.TunDevice,
SandboxUser: resp.SandboxUser,
SandboxGroup: resp.SandboxGroup,
}
m.logDebug("Daemon session created: id=%s device=%s user=%s group=%s", resp.SessionID, resp.TunDevice, resp.SandboxUser, resp.SandboxGroup)
}
// On Linux, set up proxy bridge and tun2socks if proxy is configured
if platform.Detect() == platform.Linux {
if m.config.Network.ProxyURL != "" {
// Extract embedded tun2socks binary
tun2socksPath, err := extractTun2Socks()
tun2socksPath, err := ExtractTun2Socks()
if err != nil {
m.logDebug("Failed to extract tun2socks: %v (will fall back to env-var proxying)", err)
} else {
@@ -148,7 +187,7 @@ func (m *Manager) WrapCommand(command string) (string, error) {
plat := platform.Detect()
switch plat {
case platform.MacOS:
return WrapCommandMacOS(m.config, command, m.exposedPorts, m.debug)
return WrapCommandMacOS(m.config, command, m.exposedPorts, m.daemonSession, m.debug)
case platform.Linux:
if m.learning {
return m.wrapCommandLearning(command)
@@ -201,6 +240,16 @@ func (m *Manager) GenerateLearnedTemplate(cmdName string) (string, error) {
// Cleanup stops the proxies and cleans up resources.
func (m *Manager) Cleanup() {
// Destroy macOS daemon session if active.
if m.daemonClient != nil && m.daemonSession != nil {
m.logDebug("Destroying daemon session %s", m.daemonSession.SessionID)
if err := m.daemonClient.DestroySession(m.daemonSession.SessionID); err != nil {
m.logDebug("Warning: failed to destroy daemon session: %v", err)
}
m.daemonSession = nil
m.daemonClient = nil
}
if m.reverseBridge != nil {
m.reverseBridge.Cleanup()
}