package sandbox

import (
	"bufio"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"testing"
	"time"

	"gitea.app.monadical.io/monadical/greywall/internal/config"
)

// filteringProxyTestTimeout bounds every network operation in these tests so
// a misbehaving proxy fails the test instead of hanging it.
const filteringProxyTestTimeout = 5 * time.Second

// startFilteringProxyAddr creates a filtering proxy for nc, schedules its
// shutdown with t.Cleanup, and returns the address the proxy reports via Addr.
// The test is aborted if the proxy cannot be created.
func startFilteringProxyAddr(t *testing.T, nc *config.NetworkConfig) string {
	t.Helper()
	fp, err := NewFilteringProxy(nc, false)
	if err != nil {
		t.Fatalf("NewFilteringProxy() error = %v", err)
	}
	t.Cleanup(func() { fp.Shutdown() })
	return fp.Addr()
}

// newFilteringProxyClient returns an HTTP client that routes every request
// through the proxy at proxyAddr. Its idle keep-alive connections are closed
// during cleanup; because this cleanup is registered after the proxy's, it
// runs first (t.Cleanup is LIFO), so no client connection outlives the proxy.
func newFilteringProxyClient(t *testing.T, proxyAddr string) *http.Client {
	t.Helper()
	proxyURL, err := url.Parse("http://" + proxyAddr)
	if err != nil {
		t.Fatalf("parsing proxy address %q: %v", proxyAddr, err)
	}
	transport := &http.Transport{Proxy: http.ProxyURL(proxyURL)}
	t.Cleanup(transport.CloseIdleConnections)
	return &http.Client{Transport: transport, Timeout: filteringProxyTestTimeout}
}

// expectForbiddenGETViaProxy issues a plain-HTTP GET for target through client
// and fails the test unless the response status is 403 Forbidden. A transport
// error is fatal because a filtering proxy is expected to answer the request
// itself rather than drop it.
func expectForbiddenGETViaProxy(t *testing.T, client *http.Client, target string) {
	t.Helper()
	resp, err := client.Get(target)
	if err != nil {
		t.Fatalf("GET %s: unexpected error: %v", target, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusForbidden {
		// Best effort: the body is only used to make the failure easier to diagnose.
		body, _ := io.ReadAll(resp.Body)
		t.Errorf("GET %s: expected 403 Forbidden, got %d: %s", target, resp.StatusCode, body)
	}
}

// TestFilteringProxy_AllowedDomain checks that a proxy configured with an
// allow list can be constructed and reports a non-empty port and address.
// It sends no traffic, so it does not exercise forwarding to the allowed domain.
func TestFilteringProxy_AllowedDomain(t *testing.T) {
	nc := &config.NetworkConfig{
		AllowedDomains: []string{"httpbin.org"},
	}
	fp, err := NewFilteringProxy(nc, false)
	if err != nil {
		t.Fatalf("NewFilteringProxy() error = %v", err)
	}
	defer fp.Shutdown()

	if fp.Port() == "" {
		t.Fatal("expected non-empty port")
	}
	if fp.Addr() == "" {
		t.Fatal("expected non-empty addr")
	}
}

// TestFilteringProxy_DeniedDomain_HTTP checks that a plain-HTTP request for a
// host outside the allow list is answered with 403 Forbidden.
func TestFilteringProxy_DeniedDomain_HTTP(t *testing.T) {
	addr := startFilteringProxyAddr(t, &config.NetworkConfig{
		AllowedDomains: []string{"allowed.example.com"},
	})
	expectForbiddenGETViaProxy(t, newFilteringProxyClient(t, addr), "http://denied.example.com/test")
}

// TestFilteringProxy_DeniedDomain_CONNECT checks that a CONNECT tunnel request
// for a host outside the allow list is refused with 403 Forbidden.
//
// The CONNECT is written by hand instead of going through an HTTPS client:
// http.Transport reports a rejected CONNECT only as an opaque error, which is
// indistinguishable from unrelated failures (DNS, timeouts, TLS). Reading the
// proxy's response directly lets the test assert the actual status code.
func TestFilteringProxy_DeniedDomain_CONNECT(t *testing.T) {
	addr := startFilteringProxyAddr(t, &config.NetworkConfig{
		AllowedDomains: []string{"allowed.example.com"},
	})

	conn, err := net.DialTimeout("tcp", addr, filteringProxyTestTimeout)
	if err != nil {
		t.Fatalf("dialing proxy %s: %v", addr, err)
	}
	defer conn.Close()
	if err := conn.SetDeadline(time.Now().Add(filteringProxyTestTimeout)); err != nil {
		t.Fatalf("setting connection deadline: %v", err)
	}

	const target = "denied.example.com:443"
	if _, err := fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", target, target); err != nil {
		t.Fatalf("writing CONNECT request: %v", err)
	}

	// The request's method tells ReadResponse how to frame the reply to a CONNECT.
	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: http.MethodConnect})
	if err != nil {
		t.Fatalf("reading CONNECT response: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusForbidden {
		t.Errorf("expected 403 Forbidden for denied CONNECT, got %d", resp.StatusCode)
	}
}

// TestFilteringProxy_DenyList_Only checks that a host on the deny list is
// blocked when no allow list is configured.
func TestFilteringProxy_DenyList_Only(t *testing.T) {
	addr := startFilteringProxyAddr(t, &config.NetworkConfig{
		DeniedDomains: []string{"evil.com"},
	})
	expectForbiddenGETViaProxy(t, newFilteringProxyClient(t, addr), "http://evil.com/test")
}

// TestFilteringProxy_WildcardAllow checks that the deny list takes precedence
// over a wildcard ("*") allow list.
func TestFilteringProxy_WildcardAllow(t *testing.T) {
	addr := startFilteringProxyAddr(t, &config.NetworkConfig{
		AllowedDomains: []string{"*"},
		DeniedDomains:  []string{"evil.com"},
	})
	expectForbiddenGETViaProxy(t, newFilteringProxyClient(t, addr), "http://evil.com/test")
}

// TestFilteringProxy_Shutdown checks that Shutdown does not panic, including
// when it is called a second time.
func TestFilteringProxy_Shutdown(t *testing.T) {
	nc := &config.NetworkConfig{
		AllowedDomains: []string{"example.com"},
	}
	fp, err := NewFilteringProxy(nc, false)
	if err != nil {
		t.Fatalf("NewFilteringProxy() error = %v", err)
	}

	// Shutdown should not panic
	fp.Shutdown()
	// Double shutdown should not panic
	fp.Shutdown()
}

// TestExtractHost checks that extractHost strips an optional port from
// host:port strings, including bracketed IPv6 literals.
func TestExtractHost(t *testing.T) {
	tests := []struct {
		input string
		want  string
	}{
		{"example.com:443", "example.com"},
		{"example.com:80", "example.com"},
		{"example.com", "example.com"},
		{"127.0.0.1:8080", "127.0.0.1"},
		{"[::1]:443", "::1"},
	}

	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			got := extractHost(tt.input)
			if got != tt.want {
				t.Errorf("extractHost(%q) = %q, want %q", tt.input, got, tt.want)
			}
		})
	}
}