//go:build darwin || linux package daemon import ( "bytes" "net" "sync" "testing" "time" ) // startMockDNSServer starts a UDP server that echoes back whatever it receives, // prefixed with "RESP:" to distinguish responses from queries. // Returns the server's address and a cleanup function. func startMockDNSServer(t *testing.T) (string, func()) { t.Helper() addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") if err != nil { t.Fatalf("Failed to resolve address: %v", err) } conn, err := net.ListenUDP("udp", addr) if err != nil { t.Fatalf("Failed to start mock DNS server: %v", err) } done := make(chan struct{}) go func() { buf := make([]byte, maxDNSPacketSize) for { n, remoteAddr, err := conn.ReadFromUDP(buf) if err != nil { select { case <-done: return default: continue } } // Echo back with "RESP:" prefix. resp := append([]byte("RESP:"), buf[:n]...) _, _ = conn.WriteToUDP(resp, remoteAddr) // best-effort in test } }() cleanup := func() { close(done) _ = conn.Close() } return conn.LocalAddr().String(), cleanup } // startSilentDNSServer starts a UDP server that accepts connections but never // responds, simulating an upstream timeout. func startSilentDNSServer(t *testing.T) (string, func()) { t.Helper() addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") if err != nil { t.Fatalf("Failed to resolve address: %v", err) } conn, err := net.ListenUDP("udp", addr) if err != nil { t.Fatalf("Failed to start silent DNS server: %v", err) } cleanup := func() { _ = conn.Close() } return conn.LocalAddr().String(), cleanup } func TestDNSRelay_ForwardPacket(t *testing.T) { // Start a mock upstream DNS server. upstreamAddr, cleanupUpstream := startMockDNSServer(t) defer cleanupUpstream() // Create and start the relay. relay, err := NewDNSRelay("127.0.0.1:0", upstreamAddr, true) if err != nil { t.Fatalf("Failed to create DNS relay: %v", err) } if err := relay.Start(); err != nil { t.Fatalf("Failed to start DNS relay: %v", err) } defer relay.Stop() // Send a query through the relay. clientConn, err := net.Dial("udp", relay.ListenAddr()) if err != nil { t.Fatalf("Failed to connect to relay: %v", err) } defer clientConn.Close() //nolint:errcheck // test cleanup query := []byte("test-dns-query") if _, err := clientConn.Write(query); err != nil { t.Fatalf("Failed to send query: %v", err) } // Read the response. if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { t.Fatalf("Failed to set read deadline: %v", err) } buf := make([]byte, maxDNSPacketSize) n, err := clientConn.Read(buf) if err != nil { t.Fatalf("Failed to read response: %v", err) } expected := append([]byte("RESP:"), query...) if !bytes.Equal(buf[:n], expected) { t.Errorf("Unexpected response: got %q, want %q", buf[:n], expected) } } func TestDNSRelay_UpstreamTimeout(t *testing.T) { // Start a silent upstream server that never responds. upstreamAddr, cleanupUpstream := startSilentDNSServer(t) defer cleanupUpstream() // Create and start the relay. relay, err := NewDNSRelay("127.0.0.1:0", upstreamAddr, false) if err != nil { t.Fatalf("Failed to create DNS relay: %v", err) } if err := relay.Start(); err != nil { t.Fatalf("Failed to start DNS relay: %v", err) } defer relay.Stop() // Send a query through the relay. clientConn, err := net.Dial("udp", relay.ListenAddr()) if err != nil { t.Fatalf("Failed to connect to relay: %v", err) } defer clientConn.Close() //nolint:errcheck // test cleanup query := []byte("test-dns-timeout") if _, err := clientConn.Write(query); err != nil { t.Fatalf("Failed to send query: %v", err) } // The relay should not send back a response because upstream timed out. // Set a short deadline on the client side; we expect no data. if err := clientConn.SetReadDeadline(time.Now().Add(6 * time.Second)); err != nil { t.Fatalf("Failed to set read deadline: %v", err) } buf := make([]byte, maxDNSPacketSize) _, err = clientConn.Read(buf) if err == nil { t.Fatal("Expected timeout error reading from relay, but got a response") } // Verify it was a timeout error. netErr, ok := err.(net.Error) if !ok || !netErr.Timeout() { t.Fatalf("Expected timeout error, got: %v", err) } } func TestDNSRelay_ConcurrentQueries(t *testing.T) { // Start a mock upstream DNS server. upstreamAddr, cleanupUpstream := startMockDNSServer(t) defer cleanupUpstream() // Create and start the relay. relay, err := NewDNSRelay("127.0.0.1:0", upstreamAddr, false) if err != nil { t.Fatalf("Failed to create DNS relay: %v", err) } if err := relay.Start(); err != nil { t.Fatalf("Failed to start DNS relay: %v", err) } defer relay.Stop() const numQueries = 20 var wg sync.WaitGroup errors := make(chan error, numQueries) for i := range numQueries { wg.Add(1) go func(id int) { defer wg.Done() clientConn, err := net.Dial("udp", relay.ListenAddr()) if err != nil { errors <- err return } defer clientConn.Close() //nolint:errcheck // test cleanup query := []byte("concurrent-query-" + string(rune('A'+id))) //nolint:gosec // test uses small range 0-19, no overflow if _, err := clientConn.Write(query); err != nil { errors <- err return } if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { errors <- err return } buf := make([]byte, maxDNSPacketSize) n, err := clientConn.Read(buf) if err != nil { errors <- err return } expected := append([]byte("RESP:"), query...) if !bytes.Equal(buf[:n], expected) { errors <- &unexpectedResponseError{got: buf[:n], want: expected} } }(i) } wg.Wait() close(errors) for err := range errors { t.Errorf("Concurrent query error: %v", err) } } func TestDNSRelay_ListenAddr(t *testing.T) { // Use port 0 to get an ephemeral port. relay, err := NewDNSRelay("127.0.0.1:0", "1.1.1.1:53", false) if err != nil { t.Fatalf("Failed to create DNS relay: %v", err) } defer relay.Stop() addr := relay.ListenAddr() if addr == "" { t.Fatal("ListenAddr returned empty string") } host, port, err := net.SplitHostPort(addr) if err != nil { t.Fatalf("ListenAddr returned invalid address %q: %v", addr, err) } if host != "127.0.0.1" { t.Errorf("Expected host 127.0.0.1, got %q", host) } if port == "0" { t.Error("Expected assigned port, got 0") } } func TestNewDNSRelay_InvalidDNSAddr(t *testing.T) { tests := []struct { name string dnsAddr string }{ {"missing port", "1.1.1.1"}, {"empty string", ""}, {"empty host", ":53"}, {"empty port", "1.1.1.1:"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := NewDNSRelay("127.0.0.1:0", tt.dnsAddr, false) if err == nil { t.Errorf("Expected error for DNS address %q, got nil", tt.dnsAddr) } }) } } func TestNewDNSRelay_InvalidListenAddr(t *testing.T) { _, err := NewDNSRelay("invalid-addr", "1.1.1.1:53", false) if err == nil { t.Error("Expected error for invalid listen address, got nil") } } // unexpectedResponseError is used to report mismatched responses in concurrent tests. type unexpectedResponseError struct { got []byte want []byte } func (e *unexpectedResponseError) Error() string { return "unexpected response: got " + string(e.got) + ", want " + string(e.want) }