This repository has been archived on 2026-03-13. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
greywall/internal/sandbox/proxy.go
Jose B 6be1cf5620
Some checks failed
Build and test / Lint (pull_request) Failing after 1m3s
Build and test / Test (Linux) (pull_request) Failing after 39s
Build and test / Build (pull_request) Successful in 19s
feat: add domain-based outbound filtering with allowedDomains/deniedDomains
Add NetworkConfig.AllowedDomains and DeniedDomains fields for controlling
outbound connections by hostname. Deny rules are checked first (deny wins).
When AllowedDomains is set, only matching domains are permitted. When only
DeniedDomains is set, all domains except denied ones are allowed.

Implement a FilteringProxy that wraps the gost HTTP proxy with domain enforcement
via an AllowConnect callback. Skip the GreyHaven proxy/DNS defaults.
2026-02-17 11:52:43 -05:00

246 lines
5.9 KiB
Go

package sandbox
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"gitea.app.monadical.io/monadical/greywall/internal/config"
)
// FilteringProxy is an HTTP proxy that runs on the host and only lets
// traffic through to domains permitted by its network configuration. It
// serves both CONNECT tunnels and plain HTTP requests, and is the only
// outbound target the sandbox allows.
type FilteringProxy struct {
	// Domain policy and logging switch.
	network *config.NetworkConfig
	debug   bool

	// Serving state: the loopback listener and the HTTP server bound to it.
	listener net.Listener
	server   *http.Server

	// mu guards closed, which records whether Shutdown has already run.
	mu     sync.Mutex
	closed bool
}
// NewFilteringProxy binds a listener to an ephemeral loopback port
// ("127.0.0.1:0"), starts serving the domain-filtering proxy on it in a
// background goroutine, and returns the running proxy. Call Shutdown to stop
// it. Serve errors other than http.ErrServerClosed are reported via the
// debug log only.
func NewFilteringProxy(network *config.NetworkConfig, debug bool) (*FilteringProxy, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, fmt.Errorf("failed to listen: %w", err)
	}

	proxy := &FilteringProxy{
		network:  network,
		debug:    debug,
		listener: ln,
	}
	proxy.server = &http.Server{
		ReadHeaderTimeout: 30 * time.Second,
		Handler:           http.HandlerFunc(proxy.serveHTTP),
	}

	go func() {
		serveErr := proxy.server.Serve(ln)
		if serveErr == nil || serveErr == http.ErrServerClosed {
			// Clean exit, e.g. triggered by Shutdown.
			return
		}
		proxy.logDebug("Proxy server error: %v", serveErr)
	}()

	proxy.logDebug("Filtering proxy started on %s", ln.Addr().String())
	return proxy, nil
}
// serveHTTP is the proxy's single entry point and dispatches each request:
// CONNECT requests become TCP tunnels, GET /__greywall_dns is answered by the
// built-in DNS endpoint, and everything else is forwarded as plain HTTP.
func (fp *FilteringProxy) serveHTTP(w http.ResponseWriter, r *http.Request) {
	switch {
	case r.Method == http.MethodConnect:
		fp.handleConnect(w, r)
	case r.Method == http.MethodGet && r.URL.Path == "/__greywall_dns":
		fp.handleDNS(w, r)
	default:
		fp.handleHTTP(w, r)
	}
}
// handleDNS answers GET /__greywall_dns?host=NAME with the addresses NAME
// resolves to, encoded as JSON:
//
//	{"addresses":[{"address":"93.184.216.34","family":4}, ...]}
//
// It is used by the Node.js bootstrap to patch dns.lookup inside the sandbox.
// Errors are reported as {"error":"..."} with status 400 (missing host),
// 403 (domain denied by the network policy) or 404 (lookup failed).
func (fp *FilteringProxy) handleDNS(w http.ResponseWriter, r *http.Request) {
	// "example.com." names the same host as "example.com"; drop one trailing
	// dot so the FQDN spelling cannot slip past a rule for the bare name.
	host := strings.TrimSuffix(r.URL.Query().Get("host"), ".")
	if host == "" {
		writeDNSJSON(w, http.StatusBadRequest, map[string]string{"error": "missing host parameter"})
		return
	}
	if !fp.network.IsDomainAllowed(host) {
		fp.logDenied(host)
		writeDNSJSON(w, http.StatusForbidden, map[string]string{"error": "domain denied: " + host})
		return
	}

	// Tie the lookup to the request so it is abandoned if the client goes away.
	addrs, err := net.DefaultResolver.LookupHost(r.Context(), host)
	if err != nil {
		writeDNSJSON(w, http.StatusNotFound, map[string]string{"error": err.Error()})
		return
	}

	type addrEntry struct {
		Address string `json:"address"`
		Family  int    `json:"family"`
	}
	// Non-nil so that an empty result encodes as [] rather than null.
	entries := make([]addrEntry, 0, len(addrs))
	for _, addr := range addrs {
		family := 4
		if strings.Contains(addr, ":") { // only IPv6 literals contain colons
			family = 6
		}
		entries = append(entries, addrEntry{Address: addr, Family: family})
	}
	writeDNSJSON(w, http.StatusOK, map[string]interface{}{"addresses": entries})
}

// writeDNSJSON writes v as an application/json response with the given
// status code. Going through encoding/json, rather than string formatting,
// keeps the body valid JSON even when values contain quotes or control
// characters (the host comes straight from the client's query string).
func writeDNSJSON(w http.ResponseWriter, status int, v interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line is already sent, so an encode/write failure cannot be
	// reported to the client any more; ignore it deliberately.
	_ = json.NewEncoder(w).Encode(v)
}
// handleConnect serves an HTTP CONNECT request by opening a raw TCP tunnel
// to the requested authority, provided the network policy allows its host.
//
// Error responses: 400 for an empty target, 403 if the domain is denied,
// 500 if the ResponseWriter cannot be hijacked, and 502 if dialing upstream
// or hijacking the client connection fails. On success a bare
// "200 Connection Established" status line is written on the hijacked
// connection and bytes are relayed both ways until either side closes.
func (fp *FilteringProxy) handleConnect(w http.ResponseWriter, r *http.Request) {
	host := extractHost(r.Host)
	if host == "" {
		http.Error(w, "[greywall] missing CONNECT target", http.StatusBadRequest)
		return
	}
	if !fp.network.IsDomainAllowed(host) {
		fp.logDenied(host)
		http.Error(w, fmt.Sprintf("[greywall] domain denied: %s", host), http.StatusForbidden)
		return
	}

	// Check hijack support before dialing so an unsupported writer does not
	// cost an upstream connection.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "[greywall] hijacking not supported", http.StatusInternalServerError)
		return
	}

	target := connectTarget(r.Host)
	// The context only bounds the dial; once connected, its cancellation
	// (e.g. when this handler returns) does not affect destConn.
	dialer := net.Dialer{Timeout: 10 * time.Second}
	destConn, err := dialer.DialContext(r.Context(), "tcp", target)
	if err != nil {
		http.Error(w, fmt.Sprintf("[greywall] failed to connect to %s: %v", target, err), http.StatusBadGateway)
		return
	}

	clientConn, clientBuf, err := hijacker.Hijack()
	if err != nil {
		destConn.Close()
		// Nothing has been written yet; without an explicit error the server
		// would answer 200 and the client would assume the tunnel is open.
		http.Error(w, fmt.Sprintf("[greywall] failed to hijack connection: %v", err), http.StatusBadGateway)
		return
	}

	// The server may leave deadlines on a hijacked conn; tunnels can sit idle
	// for long periods, so clear them.
	_ = clientConn.SetDeadline(time.Time{})

	// Write the status line ourselves so no framing headers (which a 2xx
	// CONNECT reply must not carry) are added by the ResponseWriter.
	if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		clientConn.Close()
		destConn.Close()
		return
	}

	// Hijack's buffered reader may already hold bytes the client sent right
	// after the CONNECT header (e.g. an eager TLS ClientHello). Forward them
	// before switching to the raw connection, or they would be lost.
	if n := clientBuf.Reader.Buffered(); n > 0 {
		pending, _ := clientBuf.Reader.Peek(n) // cannot fail for n <= Buffered()
		if _, err := destConn.Write(pending); err != nil {
			clientConn.Close()
			destConn.Close()
			return
		}
	}

	// relay copies src to dst, then tears down both ends so the opposite
	// direction's copy also terminates.
	relay := func(dst, src net.Conn) {
		defer dst.Close()
		defer src.Close()
		// Errors are expected here when the other direction closes the
		// tunnel, so they are deliberately ignored.
		_, _ = io.Copy(dst, src)
	}
	go relay(destConn, clientConn)
	go relay(clientConn, destConn)
}

// connectTarget returns the address to dial for a CONNECT authority. An
// authority that already has a port is returned unchanged; otherwise the
// default HTTPS port 443 is added, with IPv6 literals (bracketed or bare)
// re-bracketed correctly, e.g. "[::1]" -> "[::1]:443".
func connectTarget(authority string) string {
	if _, _, err := net.SplitHostPort(authority); err == nil {
		return authority
	}
	host := strings.TrimSuffix(strings.TrimPrefix(authority, "["), "]")
	return net.JoinHostPort(host, "443")
}
// handleHTTP forwards a plain-HTTP proxy request (absolute-form, e.g.
// "GET http://example.com/ HTTP/1.1") to its origin when the network policy
// allows the host, and streams the origin's response back to the client.
//
// Error responses: 403 if the domain is denied, 502 if the upstream round
// trip fails. Otherwise the origin's status, end-to-end headers and body are
// relayed.
func (fp *FilteringProxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
	// For absolute-form requests net/http sets r.Host from the URL's host,
	// which is also what the transport dials.
	host := extractHost(r.Host)
	if !fp.network.IsDomainAllowed(host) {
		fp.logDenied(host)
		http.Error(w, fmt.Sprintf("[greywall] domain denied: %s", host), http.StatusForbidden)
		return
	}

	// Forward a copy instead of mutating the server's request, and strip
	// hop-by-hop headers (notably Proxy-Authorization) so they never reach
	// the origin.
	outReq := r.Clone(r.Context())
	outReq.RequestURI = "" // must be empty on client requests
	removeHopByHopHeaders(outReq.Header)

	resp, err := http.DefaultTransport.RoundTrip(outReq)
	if err != nil {
		http.Error(w, fmt.Sprintf("[greywall] failed to forward request: %v", err), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// Connection-level headers from the origin do not apply to our client.
	removeHopByHopHeaders(resp.Header)
	for key, values := range resp.Header {
		for _, value := range values {
			w.Header().Add(key, value)
		}
	}
	w.WriteHeader(resp.StatusCode)
	// The status is already sent; a mid-body copy error (e.g. the client
	// hanging up) cannot be reported, so it is ignored.
	_, _ = io.Copy(w, resp.Body)
}

// hopByHopHeaders lists headers that describe a single transport-level
// connection and must not be forwarded by a proxy (RFC 9110, section 7.6.1),
// plus the non-standard but widely sent Proxy-Connection.
var hopByHopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// removeHopByHopHeaders deletes hop-by-hop headers from h, including any
// additional header names the sender listed in its Connection header.
func removeHopByHopHeaders(h http.Header) {
	for _, value := range h.Values("Connection") {
		for _, name := range strings.Split(value, ",") {
			if name = strings.TrimSpace(name); name != "" {
				h.Del(name)
			}
		}
	}
	for _, name := range hopByHopHeaders {
		h.Del(name)
	}
}
// Addr reports the proxy's listening address in host:port form, for example
// "127.0.0.1:12345".
func (fp *FilteringProxy) Addr() string {
	addr := fp.listener.Addr()
	return addr.String()
}
// Port reports just the port component of the listening address as a
// decimal string, or "" if the address cannot be split into host and port.
func (fp *FilteringProxy) Port() string {
	hostport := fp.listener.Addr().String()
	if _, port, err := net.SplitHostPort(hostport); err == nil {
		return port
	}
	return ""
}
// Shutdown stops the proxy. Only the first call has any effect; later calls
// return immediately.
//
// It first attempts a graceful http.Server.Shutdown, which closes the
// listener and waits up to 5 seconds for in-flight requests to finish. If
// that fails or times out, the server is closed forcibly so lingering
// connections are not left open.
//
// NOTE(review): net/http does not track hijacked connections, so CONNECT
// tunnels that are still relaying data are not closed by this method.
func (fp *FilteringProxy) Shutdown() {
	fp.mu.Lock()
	defer fp.mu.Unlock()
	if fp.closed {
		return
	}
	fp.closed = true

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := fp.server.Shutdown(ctx); err != nil {
		fp.logDebug("Graceful shutdown incomplete, forcing close: %v", err)
		_ = fp.server.Close() // best effort; nothing more can be done on failure
	}
	fp.logDebug("Filtering proxy stopped")
}
// logDenied prints a red "domain denied" line for host to stderr. Unlike
// logDebug, it is printed whether or not debug logging is enabled.
//
// The host originates from the sandboxed client (Host header, CONNECT
// authority, or the percent-decoded DNS query parameter). If it contains
// characters a terminal could treat as control sequences, it is printed
// Go-quoted instead, so sandboxed code cannot inject escape sequences into
// the user's terminal. Ordinary hostnames are printed unchanged.
func (fp *FilteringProxy) logDenied(host string) {
	display := host
	if strings.IndexFunc(host, isTerminalControl) >= 0 {
		display = fmt.Sprintf("%q", host)
	}
	fmt.Fprintf(os.Stderr, "\033[31m[greywall:proxy] domain denied: %s\033[0m\n", display)
}

// isTerminalControl reports whether r may start or form part of a terminal
// control sequence: C0 controls, DEL, C1 controls, or U+FFFD, which is what
// strings.IndexFunc yields for invalid UTF-8 bytes (such as a raw 0x9B CSI).
func isTerminalControl(r rune) bool {
	return r < 0x20 || r == 0x7f || (r >= 0x80 && r < 0xa0) || r == '\uFFFD'
}
// logDebug writes one "[greywall:proxy] "-prefixed, newline-terminated line
// to stderr, formatted from format and args as with fmt.Printf. It does
// nothing unless debug logging is enabled.
func (fp *FilteringProxy) logDebug(format string, args ...interface{}) {
	if !fp.debug {
		return
	}
	fmt.Fprintf(os.Stderr, "[greywall:proxy] "+format+"\n", args...)
}
// extractHost returns the host part of a "host:port" authority, dropping the
// port if one is present, normalized for policy matching:
//   - IPv6 brackets are removed even when no port is given, so "[::1]"
//     yields "::1" just as "[::1]:443" already does via net.SplitHostPort;
//   - one trailing dot is dropped, since "example.com." and "example.com"
//     name the same host and the FQDN form must not evade exact-name rules.
func extractHost(hostport string) string {
	host, _, err := net.SplitHostPort(hostport)
	if err != nil {
		// No port (or not splittable): treat the whole string as the host.
		host = hostport
		if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
			host = host[1 : len(host)-1]
		}
	}
	return strings.TrimSuffix(host, ".")
}