From 549c504585838450831b5d59690556c9f43e0cc5 Mon Sep 17 00:00:00 2001 From: JY Tan Date: Thu, 18 Dec 2025 17:50:04 -0800 Subject: [PATCH] Add unit tests --- internal/config/config_test.go | 293 +++++++++++++++++++++++++++++ internal/proxy/http_test.go | 273 +++++++++++++++++++++++++++ internal/proxy/socks_test.go | 130 +++++++++++++ internal/sandbox/dangerous_test.go | 170 +++++++++++++++++ internal/sandbox/utils_test.go | 278 +++++++++++++++++++++++++++ 5 files changed, 1144 insertions(+) create mode 100644 internal/config/config_test.go create mode 100644 internal/proxy/http_test.go create mode 100644 internal/proxy/socks_test.go create mode 100644 internal/sandbox/dangerous_test.go create mode 100644 internal/sandbox/utils_test.go diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..f56bb5b --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,293 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestValidateDomainPattern(t *testing.T) { + tests := []struct { + name string + pattern string + wantErr bool + }{ + // Valid patterns + {"valid domain", "example.com", false}, + {"valid subdomain", "api.example.com", false}, + {"valid wildcard", "*.example.com", false}, + {"valid wildcard subdomain", "*.api.example.com", false}, + {"localhost", "localhost", false}, + + // Invalid patterns + {"protocol included", "https://example.com", true}, + {"path included", "example.com/path", true}, + {"port included", "example.com:443", true}, + {"wildcard too broad", "*.com", true}, + {"invalid wildcard position", "example.*.com", true}, + {"trailing wildcard", "example.com.*", true}, + {"leading dot", ".example.com", true}, + {"trailing dot", "example.com.", true}, + {"no TLD", "example", true}, + {"empty wildcard domain part", "*.", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateDomainPattern(tt.pattern) + if (err != nil) != tt.wantErr { + t.Errorf("validateDomainPattern(%q) error = %v, wantErr %v", tt.pattern, err, tt.wantErr) + } + }) + } +} + +func TestMatchesDomain(t *testing.T) { + tests := []struct { + name string + hostname string + pattern string + want bool + }{ + // Exact matches + {"exact match", "example.com", "example.com", true}, + {"exact match case insensitive", "Example.COM", "example.com", true}, + {"exact no match", "other.com", "example.com", false}, + + // Wildcard matches + {"wildcard match subdomain", "api.example.com", "*.example.com", true}, + {"wildcard match deep subdomain", "deep.api.example.com", "*.example.com", true}, + {"wildcard no match base domain", "example.com", "*.example.com", false}, + {"wildcard no match different domain", "api.other.com", "*.example.com", false}, + {"wildcard case insensitive", "API.Example.COM", "*.example.com", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MatchesDomain(tt.hostname, tt.pattern) + if got != tt.want { + t.Errorf("MatchesDomain(%q, %q) = %v, want %v", tt.hostname, tt.pattern, got, tt.want) + } + }) + } +} + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + config Config + wantErr bool + }{ + { + name: "valid empty config", + config: Config{}, + wantErr: false, + }, + { + name: "valid config with domains", + config: Config{ + Network: NetworkConfig{ + AllowedDomains: []string{"example.com", "*.github.com"}, + DeniedDomains: []string{"blocked.com"}, + }, + }, + wantErr: false, + }, + { + name: "invalid allowed domain", + config: Config{ + Network: NetworkConfig{ + AllowedDomains: []string{"https://example.com"}, + }, + }, + wantErr: true, + }, + { + name: "invalid denied domain", + config: Config{ + Network: NetworkConfig{ + DeniedDomains: []string{"*.com"}, + }, + }, + wantErr: true, + }, + { + name: "empty denyRead path", + config: Config{ + Filesystem: FilesystemConfig{ + DenyRead: []string{""}, + }, + }, + wantErr: true, + }, + { + name: "empty allowWrite path", + config: Config{ + Filesystem: FilesystemConfig{ + AllowWrite: []string{""}, + }, + }, + wantErr: true, + }, + { + name: "empty denyWrite path", + config: Config{ + Filesystem: FilesystemConfig{ + DenyWrite: []string{""}, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDefault(t *testing.T) { + cfg := Default() + if cfg == nil { + t.Fatal("Default() returned nil") + } + if cfg.Network.AllowedDomains == nil { + t.Error("AllowedDomains should not be nil") + } + if cfg.Network.DeniedDomains == nil { + t.Error("DeniedDomains should not be nil") + } + if cfg.Filesystem.DenyRead == nil { + t.Error("DenyRead should not be nil") + } + if cfg.Filesystem.AllowWrite == nil { + t.Error("AllowWrite should not be nil") + } + if cfg.Filesystem.DenyWrite == nil { + t.Error("DenyWrite should not be nil") + } +} + +func TestLoad(t *testing.T) { + // Create temp directory for test files + tmpDir := t.TempDir() + + tests := []struct { + name string + content string + setup func(string) string // returns path + wantNil bool + wantErr bool + checkConfig func(*testing.T, *Config) + }{ + { + name: "nonexistent file", + setup: func(dir string) string { return filepath.Join(dir, "nonexistent.json") }, + wantNil: true, + wantErr: false, + }, + { + name: "empty file", + content: "", + setup: func(dir string) string { + path := filepath.Join(dir, "empty.json") + _ = os.WriteFile(path, []byte(""), 0o644) + return path + }, + wantNil: true, + wantErr: false, + }, + { + name: "whitespace only file", + content: " \n\t ", + setup: func(dir string) string { + path := filepath.Join(dir, "whitespace.json") + _ = os.WriteFile(path, []byte(" \n\t "), 0o644) + return path + }, + wantNil: true, + wantErr: false, + }, + { + name: "valid config", + setup: func(dir string) string { + path := filepath.Join(dir, "valid.json") + content := `{"network":{"allowedDomains":["example.com"]}}` + _ = os.WriteFile(path, []byte(content), 0o644) + return path + }, + wantNil: false, + wantErr: false, + checkConfig: func(t *testing.T, cfg *Config) { + if len(cfg.Network.AllowedDomains) != 1 { + t.Errorf("expected 1 allowed domain, got %d", len(cfg.Network.AllowedDomains)) + } + if cfg.Network.AllowedDomains[0] != "example.com" { + t.Errorf("expected example.com, got %s", cfg.Network.AllowedDomains[0]) + } + }, + }, + { + name: "invalid JSON", + setup: func(dir string) string { + path := filepath.Join(dir, "invalid.json") + _ = os.WriteFile(path, []byte("{invalid json}"), 0o644) + return path + }, + wantNil: false, + wantErr: true, + }, + { + name: "invalid domain in config", + setup: func(dir string) string { + path := filepath.Join(dir, "invalid_domain.json") + content := `{"network":{"allowedDomains":["*.com"]}}` + _ = os.WriteFile(path, []byte(content), 0o644) + return path + }, + wantNil: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := tt.setup(tmpDir) + cfg, err := Load(path) + + if (err != nil) != tt.wantErr { + t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantNil && cfg != nil { + t.Error("Load() expected nil config") + return + } + + if !tt.wantNil && !tt.wantErr && cfg == nil { + t.Error("Load() returned nil config unexpectedly") + return + } + + if tt.checkConfig != nil && cfg != nil { + tt.checkConfig(t, cfg) + } + }) + } +} + +func TestDefaultConfigPath(t *testing.T) { + path := DefaultConfigPath() + if path == "" { + t.Error("DefaultConfigPath() returned empty string") + } + // Should end with .fence.json + if filepath.Base(path) != ".fence.json" { + t.Errorf("DefaultConfigPath() = %q, expected to end with .fence.json", path) + } +} diff --git a/internal/proxy/http_test.go b/internal/proxy/http_test.go new file mode 100644 index 0000000..6d80c50 --- /dev/null +++ b/internal/proxy/http_test.go @@ -0,0 +1,273 @@ +package proxy + +import ( + "net/http" + "net/url" + "testing" + + "github.com/Use-Tusk/fence/internal/config" +) + +func TestTruncateURL(t *testing.T) { + tests := []struct { + name string + url string + maxLen int + want string + }{ + {"short url", "https://example.com", 50, "https://example.com"}, + {"exact length", "https://example.com", 19, "https://example.com"}, + {"needs truncation", "https://example.com/very/long/path/to/resource", 30, "https://example.com/very/lo..."}, + {"empty url", "", 50, ""}, + {"very short max", "https://example.com", 10, "https:/..."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := truncateURL(tt.url, tt.maxLen) + if got != tt.want { + t.Errorf("truncateURL(%q, %d) = %q, want %q", tt.url, tt.maxLen, got, tt.want) + } + }) + } +} + +func TestGetHostFromRequest(t *testing.T) { + tests := []struct { + name string + host string + urlStr string + wantHost string + }{ + { + name: "host header only", + host: "example.com", + urlStr: "/path", + wantHost: "example.com", + }, + { + name: "host header with port", + host: "example.com:8080", + urlStr: "/path", + wantHost: "example.com", + }, + { + name: "full URL overrides host", + host: "other.com", + urlStr: "http://example.com/path", + wantHost: "example.com", + }, + { + name: "url with port", + host: "other.com", + urlStr: "http://example.com:9000/path", + wantHost: "example.com", + }, + { + name: "ipv6 host", + host: "[::1]:8080", + urlStr: "/path", + wantHost: "[::1]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parsedURL, _ := url.Parse(tt.urlStr) + req := &http.Request{ + Host: tt.host, + URL: parsedURL, + } + + got := GetHostFromRequest(req) + if got != tt.wantHost { + t.Errorf("GetHostFromRequest() = %q, want %q", got, tt.wantHost) + } + }) + } +} + +func TestCreateDomainFilter(t *testing.T) { + tests := []struct { + name string + cfg *config.Config + host string + port int + allowed bool + }{ + { + name: "nil config denies all", + cfg: nil, + host: "example.com", + port: 443, + allowed: false, + }, + { + name: "allowed domain", + cfg: &config.Config{ + Network: config.NetworkConfig{ + AllowedDomains: []string{"example.com"}, + }, + }, + host: "example.com", + port: 443, + allowed: true, + }, + { + name: "denied domain takes precedence", + cfg: &config.Config{ + Network: config.NetworkConfig{ + AllowedDomains: []string{"example.com"}, + DeniedDomains: []string{"example.com"}, + }, + }, + host: "example.com", + port: 443, + allowed: false, + }, + { + name: "wildcard allowed", + cfg: &config.Config{ + Network: config.NetworkConfig{ + AllowedDomains: []string{"*.example.com"}, + }, + }, + host: "api.example.com", + port: 443, + allowed: true, + }, + { + name: "wildcard denied", + cfg: &config.Config{ + Network: config.NetworkConfig{ + AllowedDomains: []string{"*.example.com"}, + DeniedDomains: []string{"*.blocked.example.com"}, + }, + }, + host: "api.blocked.example.com", + port: 443, + allowed: false, + }, + { + name: "unmatched domain denied", + cfg: &config.Config{ + Network: config.NetworkConfig{ + AllowedDomains: []string{"example.com"}, + }, + }, + host: "other.com", + port: 443, + allowed: false, + }, + { + name: "empty allowed list denies all", + cfg: &config.Config{ + Network: config.NetworkConfig{ + AllowedDomains: []string{}, + }, + }, + host: "example.com", + port: 443, + allowed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filter := CreateDomainFilter(tt.cfg, false) + got := filter(tt.host, tt.port) + if got != tt.allowed { + t.Errorf("CreateDomainFilter() filter(%q, %d) = %v, want %v", tt.host, tt.port, got, tt.allowed) + } + }) + } +} + +func TestCreateDomainFilterCaseInsensitive(t *testing.T) { + cfg := &config.Config{ + Network: config.NetworkConfig{ + AllowedDomains: []string{"Example.COM"}, + }, + } + + filter := CreateDomainFilter(cfg, false) + + tests := []struct { + host string + allowed bool + }{ + {"example.com", true}, + {"EXAMPLE.COM", true}, + {"Example.Com", true}, + } + + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + got := filter(tt.host, 443) + if got != tt.allowed { + t.Errorf("filter(%q) = %v, want %v", tt.host, got, tt.allowed) + } + }) + } +} + +func TestNewHTTPProxy(t *testing.T) { + filter := func(host string, port int) bool { return true } + + tests := []struct { + name string + debug bool + monitor bool + }{ + {"default", false, false}, + {"debug mode", true, false}, + {"monitor mode", false, true}, + {"both modes", true, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + proxy := NewHTTPProxy(filter, tt.debug, tt.monitor) + if proxy == nil { + t.Error("NewHTTPProxy() returned nil") + } + if proxy.debug != tt.debug { + t.Errorf("debug = %v, want %v", proxy.debug, tt.debug) + } + if proxy.monitor != tt.monitor { + t.Errorf("monitor = %v, want %v", proxy.monitor, tt.monitor) + } + }) + } +} + +func TestHTTPProxyStartStop(t *testing.T) { + filter := func(host string, port int) bool { return true } + proxy := NewHTTPProxy(filter, false, false) + + port, err := proxy.Start() + if err != nil { + t.Fatalf("Start() error = %v", err) + } + + if port <= 0 { + t.Errorf("Start() returned invalid port: %d", port) + } + + if proxy.Port() != port { + t.Errorf("Port() = %d, want %d", proxy.Port(), port) + } + + if err := proxy.Stop(); err != nil { + t.Errorf("Stop() error = %v", err) + } +} + +func TestHTTPProxyPortBeforeStart(t *testing.T) { + filter := func(host string, port int) bool { return true } + proxy := NewHTTPProxy(filter, false, false) + + if proxy.Port() != 0 { + t.Errorf("Port() before Start() = %d, want 0", proxy.Port()) + } +} diff --git a/internal/proxy/socks_test.go b/internal/proxy/socks_test.go new file mode 100644 index 0000000..b123c48 --- /dev/null +++ b/internal/proxy/socks_test.go @@ -0,0 +1,130 @@ +package proxy + +import ( + "context" + "net" + "testing" + + "github.com/things-go/go-socks5" + "github.com/things-go/go-socks5/statute" +) + +func TestFenceRuleSetAllow(t *testing.T) { + tests := []struct { + name string + fqdn string + ip net.IP + port int + allowed bool + }{ + { + name: "allow by FQDN", + fqdn: "allowed.com", + port: 443, + allowed: true, + }, + { + name: "deny by FQDN", + fqdn: "blocked.com", + port: 443, + allowed: false, + }, + { + name: "fallback to IP when FQDN empty", + fqdn: "", + ip: net.ParseIP("1.2.3.4"), + port: 80, + allowed: false, + }, + { + name: "allow with IP fallback", + fqdn: "", + ip: net.ParseIP("127.0.0.1"), + port: 8080, + allowed: true, + }, + } + + filter := func(host string, port int) bool { + return host == "allowed.com" || host == "127.0.0.1" + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := &fenceRuleSet{filter: filter, debug: false, monitor: false} + req := &socks5.Request{ + DestAddr: &statute.AddrSpec{ + FQDN: tt.fqdn, + IP: tt.ip, + Port: tt.port, + }, + } + + _, allowed := rs.Allow(context.Background(), req) + if allowed != tt.allowed { + t.Errorf("Allow() = %v, want %v", allowed, tt.allowed) + } + }) + } +} + +func TestNewSOCKSProxy(t *testing.T) { + filter := func(host string, port int) bool { return true } + + tests := []struct { + name string + debug bool + monitor bool + }{ + {"default", false, false}, + {"debug mode", true, false}, + {"monitor mode", false, true}, + {"both modes", true, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + proxy := NewSOCKSProxy(filter, tt.debug, tt.monitor) + if proxy == nil { + t.Error("NewSOCKSProxy() returned nil") + } + if proxy.debug != tt.debug { + t.Errorf("debug = %v, want %v", proxy.debug, tt.debug) + } + if proxy.monitor != tt.monitor { + t.Errorf("monitor = %v, want %v", proxy.monitor, tt.monitor) + } + }) + } +} + +func TestSOCKSProxyStartStop(t *testing.T) { + filter := func(host string, port int) bool { return true } + proxy := NewSOCKSProxy(filter, false, false) + + port, err := proxy.Start() + if err != nil { + t.Fatalf("Start() error = %v", err) + } + + if port <= 0 { + t.Errorf("Start() returned invalid port: %d", port) + } + + if proxy.Port() != port { + t.Errorf("Port() = %d, want %d", proxy.Port(), port) + } + + if err := proxy.Stop(); err != nil { + t.Errorf("Stop() error = %v", err) + } +} + +func TestSOCKSProxyPortBeforeStart(t *testing.T) { + filter := func(host string, port int) bool { return true } + proxy := NewSOCKSProxy(filter, false, false) + + if proxy.Port() != 0 { + t.Errorf("Port() before Start() = %d, want 0", proxy.Port()) + } +} diff --git a/internal/sandbox/dangerous_test.go b/internal/sandbox/dangerous_test.go new file mode 100644 index 0000000..b67b827 --- /dev/null +++ b/internal/sandbox/dangerous_test.go @@ -0,0 +1,170 @@ +package sandbox + +import ( + "path/filepath" + "slices" + "strings" + "testing" +) + +func TestGetDefaultWritePaths(t *testing.T) { + paths := GetDefaultWritePaths() + + if len(paths) == 0 { + t.Error("GetDefaultWritePaths() returned empty slice") + } + + essentialPaths := []string{"/dev/stdout", "/dev/stderr", "/dev/null", "/tmp/fence"} + for _, essential := range essentialPaths { + found := slices.Contains(paths, essential) + if !found { + t.Errorf("GetDefaultWritePaths() missing essential path %q", essential) + } + } +} + +func TestGetMandatoryDenyPatterns(t *testing.T) { + cwd := "/home/user/project" + + tests := []struct { + name string + cwd string + allowGitConfig bool + shouldContain []string + shouldNotContain []string + }{ + { + name: "with git config denied", + cwd: cwd, + allowGitConfig: false, + shouldContain: []string{ + filepath.Join(cwd, ".gitconfig"), + filepath.Join(cwd, ".bashrc"), + filepath.Join(cwd, ".zshrc"), + filepath.Join(cwd, ".git/hooks"), + filepath.Join(cwd, ".git/config"), + "**/.gitconfig", + "**/.bashrc", + "**/.git/hooks/**", + "**/.git/config", + }, + }, + { + name: "with git config allowed", + cwd: cwd, + allowGitConfig: true, + shouldContain: []string{ + filepath.Join(cwd, ".gitconfig"), + filepath.Join(cwd, ".git/hooks"), + "**/.git/hooks/**", + }, + shouldNotContain: []string{ + filepath.Join(cwd, ".git/config"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patterns := GetMandatoryDenyPatterns(tt.cwd, tt.allowGitConfig) + + for _, expected := range tt.shouldContain { + found := slices.Contains(patterns, expected) + if !found { + t.Errorf("GetMandatoryDenyPatterns() missing pattern %q", expected) + } + } + + for _, notExpected := range tt.shouldNotContain { + found := slices.Contains(patterns, notExpected) + if found { + t.Errorf("GetMandatoryDenyPatterns() should not contain %q when allowGitConfig=%v", notExpected, tt.allowGitConfig) + } + } + }) + } +} + +func TestGetMandatoryDenyPatternsContainsDangerousFiles(t *testing.T) { + cwd := "/test/project" + patterns := GetMandatoryDenyPatterns(cwd, false) + + // Each dangerous file should appear both as a cwd-relative path and as a glob pattern + for _, file := range DangerousFiles { + cwdPath := filepath.Join(cwd, file) + globPattern := "**/" + file + + foundCwd := false + foundGlob := false + + for _, p := range patterns { + if p == cwdPath { + foundCwd = true + } + if p == globPattern { + foundGlob = true + } + } + + if !foundCwd { + t.Errorf("Missing cwd-relative pattern for dangerous file %q", file) + } + if !foundGlob { + t.Errorf("Missing glob pattern for dangerous file %q", file) + } + } +} + +func TestGetMandatoryDenyPatternsContainsDangerousDirectories(t *testing.T) { + cwd := "/test/project" + patterns := GetMandatoryDenyPatterns(cwd, false) + + for _, dir := range DangerousDirectories { + cwdPath := filepath.Join(cwd, dir) + globPattern := "**/" + dir + "/**" + + foundCwd := false + foundGlob := false + + for _, p := range patterns { + if p == cwdPath { + foundCwd = true + } + if p == globPattern { + foundGlob = true + } + } + + if !foundCwd { + t.Errorf("Missing cwd-relative pattern for dangerous directory %q", dir) + } + if !foundGlob { + t.Errorf("Missing glob pattern for dangerous directory %q", dir) + } + } +} + +func TestGetMandatoryDenyPatternsGitHooksAlwaysBlocked(t *testing.T) { + cwd := "/test/project" + + // Git hooks should be blocked regardless of allowGitConfig + for _, allowGitConfig := range []bool{true, false} { + patterns := GetMandatoryDenyPatterns(cwd, allowGitConfig) + + foundHooksPath := false + foundHooksGlob := false + + for _, p := range patterns { + if p == filepath.Join(cwd, ".git/hooks") { + foundHooksPath = true + } + if strings.Contains(p, ".git/hooks") && strings.HasPrefix(p, "**") { + foundHooksGlob = true + } + } + + if !foundHooksPath || !foundHooksGlob { + t.Errorf("Git hooks should always be blocked (allowGitConfig=%v)", allowGitConfig) + } + } +} diff --git a/internal/sandbox/utils_test.go b/internal/sandbox/utils_test.go new file mode 100644 index 0000000..f4290be --- /dev/null +++ b/internal/sandbox/utils_test.go @@ -0,0 +1,278 @@ +package sandbox + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestContainsGlobChars(t *testing.T) { + tests := []struct { + pattern string + want bool + }{ + {"/path/to/file", false}, + {"/path/to/dir/", false}, + {"relative/path", false}, + {"/path/with/asterisk/*", true}, + {"/path/with/question?", true}, + {"/path/with/brackets[a-z]", true}, + {"/path/**/*.go", true}, + {"*.txt", true}, + {"file[0-9].txt", true}, + } + + for _, tt := range tests { + t.Run(tt.pattern, func(t *testing.T) { + got := ContainsGlobChars(tt.pattern) + if got != tt.want { + t.Errorf("ContainsGlobChars(%q) = %v, want %v", tt.pattern, got, tt.want) + } + }) + } +} + +func TestRemoveTrailingGlobSuffix(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"/path/to/dir/**", "/path/to/dir"}, + {"/path/to/dir", "/path/to/dir"}, + {"/path/**/**", "/path/**"}, + {"/**", ""}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := RemoveTrailingGlobSuffix(tt.input) + if got != tt.want { + t.Errorf("RemoveTrailingGlobSuffix(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestNormalizePath(t *testing.T) { + home, _ := os.UserHomeDir() + cwd, _ := os.Getwd() + + tests := []struct { + name string + input string + want string + wantErr bool + }{ + { + name: "tilde alone", + input: "~", + want: home, + }, + { + name: "tilde with path", + input: "~/Documents", + want: filepath.Join(home, "Documents"), + }, + { + name: "absolute path", + input: "/usr/bin", + want: "/usr/bin", + }, + { + name: "relative dot path", + input: "./subdir", + want: filepath.Join(cwd, "subdir"), + }, + { + name: "relative parent path", + input: "../sibling", + want: filepath.Join(filepath.Dir(cwd), "sibling"), + }, + { + name: "glob pattern preserved", + input: "/path/**/*.go", + want: "/path/**/*.go", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NormalizePath(tt.input) + + // For paths that involve symlink resolution, we just check the result is reasonable + if strings.Contains(tt.input, "**") || strings.Contains(tt.input, "*") { + if got != tt.want { + t.Errorf("NormalizePath(%q) = %q, want %q", tt.input, got, tt.want) + } + return + } + + // For tilde and relative paths, we check prefixes since symlinks may resolve differently + if tt.input == "~" { + if got != home && !strings.HasPrefix(got, "/") { + t.Errorf("NormalizePath(%q) = %q, expected home directory", tt.input, got) + } + } else if strings.HasPrefix(tt.input, "~/") { + if !strings.HasPrefix(got, home) && !strings.HasPrefix(got, "/") { + t.Errorf("NormalizePath(%q) = %q, expected path under home", tt.input, got) + } + } + }) + } +} + +func TestGenerateProxyEnvVars(t *testing.T) { + tests := []struct { + name string + httpPort int + socksPort int + wantEnvs []string + dontWant []string + }{ + { + name: "no ports", + httpPort: 0, + socksPort: 0, + wantEnvs: []string{ + "FENCE_SANDBOX=1", + "TMPDIR=/tmp/fence", + }, + dontWant: []string{ + "HTTP_PROXY=", + "HTTPS_PROXY=", + "ALL_PROXY=", + }, + }, + { + name: "http port only", + httpPort: 8080, + socksPort: 0, + wantEnvs: []string{ + "FENCE_SANDBOX=1", + "HTTP_PROXY=http://localhost:8080", + "HTTPS_PROXY=http://localhost:8080", + "http_proxy=http://localhost:8080", + "https_proxy=http://localhost:8080", + "NO_PROXY=", + "no_proxy=", + }, + dontWant: []string{ + "ALL_PROXY=", + "all_proxy=", + }, + }, + { + name: "socks port only", + httpPort: 0, + socksPort: 1080, + wantEnvs: []string{ + "FENCE_SANDBOX=1", + "ALL_PROXY=socks5h://localhost:1080", + "all_proxy=socks5h://localhost:1080", + "FTP_PROXY=socks5h://localhost:1080", + "GIT_SSH_COMMAND=", + }, + dontWant: []string{ + "HTTP_PROXY=", + "HTTPS_PROXY=", + }, + }, + { + name: "both ports", + httpPort: 8080, + socksPort: 1080, + wantEnvs: []string{ + "FENCE_SANDBOX=1", + "HTTP_PROXY=http://localhost:8080", + "HTTPS_PROXY=http://localhost:8080", + "ALL_PROXY=socks5h://localhost:1080", + "GIT_SSH_COMMAND=", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GenerateProxyEnvVars(tt.httpPort, tt.socksPort) + + // Check expected env vars are present + for _, want := range tt.wantEnvs { + found := false + for _, env := range got { + if strings.HasPrefix(env, want) || env == want { + found = true + break + } + } + if !found { + t.Errorf("GenerateProxyEnvVars(%d, %d) missing %q", tt.httpPort, tt.socksPort, want) + } + } + + // Check unwanted env vars are not present + for _, dontWant := range tt.dontWant { + for _, env := range got { + if strings.HasPrefix(env, dontWant) { + t.Errorf("GenerateProxyEnvVars(%d, %d) should not contain %q, got %q", tt.httpPort, tt.socksPort, dontWant, env) + } + } + } + }) + } +} + +func TestEncodeSandboxedCommand(t *testing.T) { + tests := []struct { + name string + command string + }{ + {"simple command", "ls -la"}, + {"command with spaces", "grep -r 'pattern' /path/to/dir"}, + {"empty command", ""}, + {"special chars", "echo $HOME && ls | grep foo"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encoded := EncodeSandboxedCommand(tt.command) + if encoded == "" && tt.command != "" { + t.Error("EncodeSandboxedCommand returned empty string") + } + + // Roundtrip test + decoded, err := DecodeSandboxedCommand(encoded) + if err != nil { + t.Errorf("DecodeSandboxedCommand failed: %v", err) + } + + // Commands are truncated to 100 chars + expected := tt.command + if len(expected) > 100 { + expected = expected[:100] + } + if decoded != expected { + t.Errorf("Roundtrip failed: got %q, want %q", decoded, expected) + } + }) + } +} + +func TestEncodeSandboxedCommandTruncation(t *testing.T) { + // Test that long commands are truncated + longCommand := strings.Repeat("a", 200) + encoded := EncodeSandboxedCommand(longCommand) + decoded, _ := DecodeSandboxedCommand(encoded) + + if len(decoded) != 100 { + t.Errorf("Expected truncated command of 100 chars, got %d", len(decoded)) + } +} + +func TestDecodeSandboxedCommandInvalid(t *testing.T) { + _, err := DecodeSandboxedCommand("not-valid-base64!!!") + if err == nil { + t.Error("DecodeSandboxedCommand should fail on invalid base64") + } +}