Add `NetworkConfig.AllowedDomains` and `DeniedDomains` fields for controlling outbound connections by hostname. Deny rules are checked first (deny wins). When `AllowedDomains` is set, only matching domains are permitted; when only `DeniedDomains` is set, all domains except the denied ones are allowed. Implement `FilteringProxy`, which wraps the gost HTTP proxy and enforces the domain rules via an `AllowConnect` callback. Skip the GreyHaven proxy/DNS defaults.
194 lines
4.4 KiB
Go
194 lines
4.4 KiB
Go
package sandbox
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"gitea.app.monadical.io/monadical/greywall/internal/config"
|
|
)
|
|
|
|
func TestFilteringProxy_AllowedDomain(t *testing.T) {
|
|
nc := &config.NetworkConfig{
|
|
AllowedDomains: []string{"httpbin.org"},
|
|
}
|
|
|
|
fp, err := NewFilteringProxy(nc, false)
|
|
if err != nil {
|
|
t.Fatalf("NewFilteringProxy() error = %v", err)
|
|
}
|
|
defer fp.Shutdown()
|
|
|
|
if fp.Port() == "" {
|
|
t.Fatal("expected non-empty port")
|
|
}
|
|
if fp.Addr() == "" {
|
|
t.Fatal("expected non-empty addr")
|
|
}
|
|
}
|
|
|
|
func TestFilteringProxy_DeniedDomain_HTTP(t *testing.T) {
|
|
nc := &config.NetworkConfig{
|
|
AllowedDomains: []string{"allowed.example.com"},
|
|
}
|
|
|
|
fp, err := NewFilteringProxy(nc, false)
|
|
if err != nil {
|
|
t.Fatalf("NewFilteringProxy() error = %v", err)
|
|
}
|
|
defer fp.Shutdown()
|
|
|
|
// Make a plain HTTP request to a denied domain through the proxy
|
|
proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", fp.Addr()))
|
|
client := &http.Client{
|
|
Transport: &http.Transport{
|
|
Proxy: http.ProxyURL(proxyURL),
|
|
},
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
resp, err := client.Get("http://denied.example.com/test")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Errorf("expected 403 Forbidden, got %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
}
|
|
|
|
func TestFilteringProxy_DeniedDomain_CONNECT(t *testing.T) {
|
|
nc := &config.NetworkConfig{
|
|
AllowedDomains: []string{"allowed.example.com"},
|
|
}
|
|
|
|
fp, err := NewFilteringProxy(nc, false)
|
|
if err != nil {
|
|
t.Fatalf("NewFilteringProxy() error = %v", err)
|
|
}
|
|
defer fp.Shutdown()
|
|
|
|
// Make a CONNECT request to a denied domain
|
|
proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", fp.Addr()))
|
|
client := &http.Client{
|
|
Transport: &http.Transport{
|
|
Proxy: http.ProxyURL(proxyURL),
|
|
},
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
// HTTPS triggers CONNECT method through the proxy
|
|
_, err = client.Get("https://denied.example.com/test")
|
|
if err == nil {
|
|
t.Error("expected error for denied CONNECT, got nil")
|
|
}
|
|
// The error should indicate the proxy rejected the connection (403)
|
|
}
|
|
|
|
func TestFilteringProxy_DenyList_Only(t *testing.T) {
|
|
nc := &config.NetworkConfig{
|
|
DeniedDomains: []string{"evil.com"},
|
|
}
|
|
|
|
fp, err := NewFilteringProxy(nc, false)
|
|
if err != nil {
|
|
t.Fatalf("NewFilteringProxy() error = %v", err)
|
|
}
|
|
defer fp.Shutdown()
|
|
|
|
proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", fp.Addr()))
|
|
client := &http.Client{
|
|
Transport: &http.Transport{
|
|
Proxy: http.ProxyURL(proxyURL),
|
|
},
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
// Denied domain should be blocked
|
|
resp, err := client.Get("http://evil.com/test")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Errorf("expected 403 for denied domain, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestFilteringProxy_WildcardAllow(t *testing.T) {
|
|
nc := &config.NetworkConfig{
|
|
AllowedDomains: []string{"*"},
|
|
DeniedDomains: []string{"evil.com"},
|
|
}
|
|
|
|
fp, err := NewFilteringProxy(nc, false)
|
|
if err != nil {
|
|
t.Fatalf("NewFilteringProxy() error = %v", err)
|
|
}
|
|
defer fp.Shutdown()
|
|
|
|
proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", fp.Addr()))
|
|
client := &http.Client{
|
|
Transport: &http.Transport{
|
|
Proxy: http.ProxyURL(proxyURL),
|
|
},
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
// Denied domain should still be blocked even with wildcard allow
|
|
resp, err := client.Get("http://evil.com/test")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Errorf("expected 403 for denied domain with wildcard allow, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestFilteringProxy_Shutdown(t *testing.T) {
|
|
nc := &config.NetworkConfig{
|
|
AllowedDomains: []string{"example.com"},
|
|
}
|
|
|
|
fp, err := NewFilteringProxy(nc, false)
|
|
if err != nil {
|
|
t.Fatalf("NewFilteringProxy() error = %v", err)
|
|
}
|
|
|
|
// Shutdown should not panic
|
|
fp.Shutdown()
|
|
|
|
// Double shutdown should not panic
|
|
fp.Shutdown()
|
|
}
|
|
|
|
func TestExtractHost(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
want string
|
|
}{
|
|
{"example.com:443", "example.com"},
|
|
{"example.com:80", "example.com"},
|
|
{"example.com", "example.com"},
|
|
{"127.0.0.1:8080", "127.0.0.1"},
|
|
{"[::1]:443", "::1"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.input, func(t *testing.T) {
|
|
got := extractHost(tt.input)
|
|
if got != tt.want {
|
|
t.Errorf("extractHost(%q) = %q, want %q", tt.input, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|