Add unit tests
This commit is contained in:
293
internal/config/config_test.go
Normal file
293
internal/config/config_test.go
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateDomainPattern(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pattern string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
// Valid patterns
|
||||||
|
{"valid domain", "example.com", false},
|
||||||
|
{"valid subdomain", "api.example.com", false},
|
||||||
|
{"valid wildcard", "*.example.com", false},
|
||||||
|
{"valid wildcard subdomain", "*.api.example.com", false},
|
||||||
|
{"localhost", "localhost", false},
|
||||||
|
|
||||||
|
// Invalid patterns
|
||||||
|
{"protocol included", "https://example.com", true},
|
||||||
|
{"path included", "example.com/path", true},
|
||||||
|
{"port included", "example.com:443", true},
|
||||||
|
{"wildcard too broad", "*.com", true},
|
||||||
|
{"invalid wildcard position", "example.*.com", true},
|
||||||
|
{"trailing wildcard", "example.com.*", true},
|
||||||
|
{"leading dot", ".example.com", true},
|
||||||
|
{"trailing dot", "example.com.", true},
|
||||||
|
{"no TLD", "example", true},
|
||||||
|
{"empty wildcard domain part", "*.", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validateDomainPattern(tt.pattern)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("validateDomainPattern(%q) error = %v, wantErr %v", tt.pattern, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchesDomain(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hostname string
|
||||||
|
pattern string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
// Exact matches
|
||||||
|
{"exact match", "example.com", "example.com", true},
|
||||||
|
{"exact match case insensitive", "Example.COM", "example.com", true},
|
||||||
|
{"exact no match", "other.com", "example.com", false},
|
||||||
|
|
||||||
|
// Wildcard matches
|
||||||
|
{"wildcard match subdomain", "api.example.com", "*.example.com", true},
|
||||||
|
{"wildcard match deep subdomain", "deep.api.example.com", "*.example.com", true},
|
||||||
|
{"wildcard no match base domain", "example.com", "*.example.com", false},
|
||||||
|
{"wildcard no match different domain", "api.other.com", "*.example.com", false},
|
||||||
|
{"wildcard case insensitive", "API.Example.COM", "*.example.com", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MatchesDomain(tt.hostname, tt.pattern)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("MatchesDomain(%q, %q) = %v, want %v", tt.hostname, tt.pattern, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigValidate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid empty config",
|
||||||
|
config: Config{},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid config with domains",
|
||||||
|
config: Config{
|
||||||
|
Network: NetworkConfig{
|
||||||
|
AllowedDomains: []string{"example.com", "*.github.com"},
|
||||||
|
DeniedDomains: []string{"blocked.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid allowed domain",
|
||||||
|
config: Config{
|
||||||
|
Network: NetworkConfig{
|
||||||
|
AllowedDomains: []string{"https://example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid denied domain",
|
||||||
|
config: Config{
|
||||||
|
Network: NetworkConfig{
|
||||||
|
DeniedDomains: []string{"*.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty denyRead path",
|
||||||
|
config: Config{
|
||||||
|
Filesystem: FilesystemConfig{
|
||||||
|
DenyRead: []string{""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty allowWrite path",
|
||||||
|
config: Config{
|
||||||
|
Filesystem: FilesystemConfig{
|
||||||
|
AllowWrite: []string{""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty denyWrite path",
|
||||||
|
config: Config{
|
||||||
|
Filesystem: FilesystemConfig{
|
||||||
|
DenyWrite: []string{""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.config.Validate()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefault(t *testing.T) {
|
||||||
|
cfg := Default()
|
||||||
|
if cfg == nil {
|
||||||
|
t.Fatal("Default() returned nil")
|
||||||
|
}
|
||||||
|
if cfg.Network.AllowedDomains == nil {
|
||||||
|
t.Error("AllowedDomains should not be nil")
|
||||||
|
}
|
||||||
|
if cfg.Network.DeniedDomains == nil {
|
||||||
|
t.Error("DeniedDomains should not be nil")
|
||||||
|
}
|
||||||
|
if cfg.Filesystem.DenyRead == nil {
|
||||||
|
t.Error("DenyRead should not be nil")
|
||||||
|
}
|
||||||
|
if cfg.Filesystem.AllowWrite == nil {
|
||||||
|
t.Error("AllowWrite should not be nil")
|
||||||
|
}
|
||||||
|
if cfg.Filesystem.DenyWrite == nil {
|
||||||
|
t.Error("DenyWrite should not be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoad(t *testing.T) {
|
||||||
|
// Create temp directory for test files
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
setup func(string) string // returns path
|
||||||
|
wantNil bool
|
||||||
|
wantErr bool
|
||||||
|
checkConfig func(*testing.T, *Config)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nonexistent file",
|
||||||
|
setup: func(dir string) string { return filepath.Join(dir, "nonexistent.json") },
|
||||||
|
wantNil: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty file",
|
||||||
|
content: "",
|
||||||
|
setup: func(dir string) string {
|
||||||
|
path := filepath.Join(dir, "empty.json")
|
||||||
|
_ = os.WriteFile(path, []byte(""), 0o644)
|
||||||
|
return path
|
||||||
|
},
|
||||||
|
wantNil: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace only file",
|
||||||
|
content: " \n\t ",
|
||||||
|
setup: func(dir string) string {
|
||||||
|
path := filepath.Join(dir, "whitespace.json")
|
||||||
|
_ = os.WriteFile(path, []byte(" \n\t "), 0o644)
|
||||||
|
return path
|
||||||
|
},
|
||||||
|
wantNil: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
setup: func(dir string) string {
|
||||||
|
path := filepath.Join(dir, "valid.json")
|
||||||
|
content := `{"network":{"allowedDomains":["example.com"]}}`
|
||||||
|
_ = os.WriteFile(path, []byte(content), 0o644)
|
||||||
|
return path
|
||||||
|
},
|
||||||
|
wantNil: false,
|
||||||
|
wantErr: false,
|
||||||
|
checkConfig: func(t *testing.T, cfg *Config) {
|
||||||
|
if len(cfg.Network.AllowedDomains) != 1 {
|
||||||
|
t.Errorf("expected 1 allowed domain, got %d", len(cfg.Network.AllowedDomains))
|
||||||
|
}
|
||||||
|
if cfg.Network.AllowedDomains[0] != "example.com" {
|
||||||
|
t.Errorf("expected example.com, got %s", cfg.Network.AllowedDomains[0])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON",
|
||||||
|
setup: func(dir string) string {
|
||||||
|
path := filepath.Join(dir, "invalid.json")
|
||||||
|
_ = os.WriteFile(path, []byte("{invalid json}"), 0o644)
|
||||||
|
return path
|
||||||
|
},
|
||||||
|
wantNil: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid domain in config",
|
||||||
|
setup: func(dir string) string {
|
||||||
|
path := filepath.Join(dir, "invalid_domain.json")
|
||||||
|
content := `{"network":{"allowedDomains":["*.com"]}}`
|
||||||
|
_ = os.WriteFile(path, []byte(content), 0o644)
|
||||||
|
return path
|
||||||
|
},
|
||||||
|
wantNil: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
path := tt.setup(tmpDir)
|
||||||
|
cfg, err := Load(path)
|
||||||
|
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantNil && cfg != nil {
|
||||||
|
t.Error("Load() expected nil config")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantNil && !tt.wantErr && cfg == nil {
|
||||||
|
t.Error("Load() returned nil config unexpectedly")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.checkConfig != nil && cfg != nil {
|
||||||
|
tt.checkConfig(t, cfg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultConfigPath(t *testing.T) {
|
||||||
|
path := DefaultConfigPath()
|
||||||
|
if path == "" {
|
||||||
|
t.Error("DefaultConfigPath() returned empty string")
|
||||||
|
}
|
||||||
|
// Should end with .fence.json
|
||||||
|
if filepath.Base(path) != ".fence.json" {
|
||||||
|
t.Errorf("DefaultConfigPath() = %q, expected to end with .fence.json", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
273
internal/proxy/http_test.go
Normal file
273
internal/proxy/http_test.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Use-Tusk/fence/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTruncateURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
url string
|
||||||
|
maxLen int
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"short url", "https://example.com", 50, "https://example.com"},
|
||||||
|
{"exact length", "https://example.com", 19, "https://example.com"},
|
||||||
|
{"needs truncation", "https://example.com/very/long/path/to/resource", 30, "https://example.com/very/lo..."},
|
||||||
|
{"empty url", "", 50, ""},
|
||||||
|
{"very short max", "https://example.com", 10, "https:/..."},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := truncateURL(tt.url, tt.maxLen)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("truncateURL(%q, %d) = %q, want %q", tt.url, tt.maxLen, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetHostFromRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
urlStr string
|
||||||
|
wantHost string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "host header only",
|
||||||
|
host: "example.com",
|
||||||
|
urlStr: "/path",
|
||||||
|
wantHost: "example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "host header with port",
|
||||||
|
host: "example.com:8080",
|
||||||
|
urlStr: "/path",
|
||||||
|
wantHost: "example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full URL overrides host",
|
||||||
|
host: "other.com",
|
||||||
|
urlStr: "http://example.com/path",
|
||||||
|
wantHost: "example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "url with port",
|
||||||
|
host: "other.com",
|
||||||
|
urlStr: "http://example.com:9000/path",
|
||||||
|
wantHost: "example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6 host",
|
||||||
|
host: "[::1]:8080",
|
||||||
|
urlStr: "/path",
|
||||||
|
wantHost: "[::1]",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parsedURL, _ := url.Parse(tt.urlStr)
|
||||||
|
req := &http.Request{
|
||||||
|
Host: tt.host,
|
||||||
|
URL: parsedURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
got := GetHostFromRequest(req)
|
||||||
|
if got != tt.wantHost {
|
||||||
|
t.Errorf("GetHostFromRequest() = %q, want %q", got, tt.wantHost)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateDomainFilter(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg *config.Config
|
||||||
|
host string
|
||||||
|
port int
|
||||||
|
allowed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil config denies all",
|
||||||
|
cfg: nil,
|
||||||
|
host: "example.com",
|
||||||
|
port: 443,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allowed domain",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Network: config.NetworkConfig{
|
||||||
|
AllowedDomains: []string{"example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
host: "example.com",
|
||||||
|
port: 443,
|
||||||
|
allowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denied domain takes precedence",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Network: config.NetworkConfig{
|
||||||
|
AllowedDomains: []string{"example.com"},
|
||||||
|
DeniedDomains: []string{"example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
host: "example.com",
|
||||||
|
port: 443,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allowed",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Network: config.NetworkConfig{
|
||||||
|
AllowedDomains: []string{"*.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
host: "api.example.com",
|
||||||
|
port: 443,
|
||||||
|
allowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard denied",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Network: config.NetworkConfig{
|
||||||
|
AllowedDomains: []string{"*.example.com"},
|
||||||
|
DeniedDomains: []string{"*.blocked.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
host: "api.blocked.example.com",
|
||||||
|
port: 443,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unmatched domain denied",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Network: config.NetworkConfig{
|
||||||
|
AllowedDomains: []string{"example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
host: "other.com",
|
||||||
|
port: 443,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty allowed list denies all",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Network: config.NetworkConfig{
|
||||||
|
AllowedDomains: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
host: "example.com",
|
||||||
|
port: 443,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
filter := CreateDomainFilter(tt.cfg, false)
|
||||||
|
got := filter(tt.host, tt.port)
|
||||||
|
if got != tt.allowed {
|
||||||
|
t.Errorf("CreateDomainFilter() filter(%q, %d) = %v, want %v", tt.host, tt.port, got, tt.allowed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateDomainFilterCaseInsensitive(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Network: config.NetworkConfig{
|
||||||
|
AllowedDomains: []string{"Example.COM"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := CreateDomainFilter(cfg, false)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
host string
|
||||||
|
allowed bool
|
||||||
|
}{
|
||||||
|
{"example.com", true},
|
||||||
|
{"EXAMPLE.COM", true},
|
||||||
|
{"Example.Com", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.host, func(t *testing.T) {
|
||||||
|
got := filter(tt.host, 443)
|
||||||
|
if got != tt.allowed {
|
||||||
|
t.Errorf("filter(%q) = %v, want %v", tt.host, got, tt.allowed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHTTPProxy(t *testing.T) {
|
||||||
|
filter := func(host string, port int) bool { return true }
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
debug bool
|
||||||
|
monitor bool
|
||||||
|
}{
|
||||||
|
{"default", false, false},
|
||||||
|
{"debug mode", true, false},
|
||||||
|
{"monitor mode", false, true},
|
||||||
|
{"both modes", true, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
proxy := NewHTTPProxy(filter, tt.debug, tt.monitor)
|
||||||
|
if proxy == nil {
|
||||||
|
t.Error("NewHTTPProxy() returned nil")
|
||||||
|
}
|
||||||
|
if proxy.debug != tt.debug {
|
||||||
|
t.Errorf("debug = %v, want %v", proxy.debug, tt.debug)
|
||||||
|
}
|
||||||
|
if proxy.monitor != tt.monitor {
|
||||||
|
t.Errorf("monitor = %v, want %v", proxy.monitor, tt.monitor)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPProxyStartStop(t *testing.T) {
|
||||||
|
filter := func(host string, port int) bool { return true }
|
||||||
|
proxy := NewHTTPProxy(filter, false, false)
|
||||||
|
|
||||||
|
port, err := proxy.Start()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Start() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if port <= 0 {
|
||||||
|
t.Errorf("Start() returned invalid port: %d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if proxy.Port() != port {
|
||||||
|
t.Errorf("Port() = %d, want %d", proxy.Port(), port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := proxy.Stop(); err != nil {
|
||||||
|
t.Errorf("Stop() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPProxyPortBeforeStart(t *testing.T) {
|
||||||
|
filter := func(host string, port int) bool { return true }
|
||||||
|
proxy := NewHTTPProxy(filter, false, false)
|
||||||
|
|
||||||
|
if proxy.Port() != 0 {
|
||||||
|
t.Errorf("Port() before Start() = %d, want 0", proxy.Port())
|
||||||
|
}
|
||||||
|
}
|
||||||
130
internal/proxy/socks_test.go
Normal file
130
internal/proxy/socks_test.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/things-go/go-socks5"
|
||||||
|
"github.com/things-go/go-socks5/statute"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFenceRuleSetAllow(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fqdn string
|
||||||
|
ip net.IP
|
||||||
|
port int
|
||||||
|
allowed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "allow by FQDN",
|
||||||
|
fqdn: "allowed.com",
|
||||||
|
port: 443,
|
||||||
|
allowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deny by FQDN",
|
||||||
|
fqdn: "blocked.com",
|
||||||
|
port: 443,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fallback to IP when FQDN empty",
|
||||||
|
fqdn: "",
|
||||||
|
ip: net.ParseIP("1.2.3.4"),
|
||||||
|
port: 80,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow with IP fallback",
|
||||||
|
fqdn: "",
|
||||||
|
ip: net.ParseIP("127.0.0.1"),
|
||||||
|
port: 8080,
|
||||||
|
allowed: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := func(host string, port int) bool {
|
||||||
|
return host == "allowed.com" || host == "127.0.0.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rs := &fenceRuleSet{filter: filter, debug: false, monitor: false}
|
||||||
|
req := &socks5.Request{
|
||||||
|
DestAddr: &statute.AddrSpec{
|
||||||
|
FQDN: tt.fqdn,
|
||||||
|
IP: tt.ip,
|
||||||
|
Port: tt.port,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, allowed := rs.Allow(context.Background(), req)
|
||||||
|
if allowed != tt.allowed {
|
||||||
|
t.Errorf("Allow() = %v, want %v", allowed, tt.allowed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSOCKSProxy(t *testing.T) {
|
||||||
|
filter := func(host string, port int) bool { return true }
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
debug bool
|
||||||
|
monitor bool
|
||||||
|
}{
|
||||||
|
{"default", false, false},
|
||||||
|
{"debug mode", true, false},
|
||||||
|
{"monitor mode", false, true},
|
||||||
|
{"both modes", true, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
proxy := NewSOCKSProxy(filter, tt.debug, tt.monitor)
|
||||||
|
if proxy == nil {
|
||||||
|
t.Error("NewSOCKSProxy() returned nil")
|
||||||
|
}
|
||||||
|
if proxy.debug != tt.debug {
|
||||||
|
t.Errorf("debug = %v, want %v", proxy.debug, tt.debug)
|
||||||
|
}
|
||||||
|
if proxy.monitor != tt.monitor {
|
||||||
|
t.Errorf("monitor = %v, want %v", proxy.monitor, tt.monitor)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSOCKSProxyStartStop(t *testing.T) {
|
||||||
|
filter := func(host string, port int) bool { return true }
|
||||||
|
proxy := NewSOCKSProxy(filter, false, false)
|
||||||
|
|
||||||
|
port, err := proxy.Start()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Start() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if port <= 0 {
|
||||||
|
t.Errorf("Start() returned invalid port: %d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if proxy.Port() != port {
|
||||||
|
t.Errorf("Port() = %d, want %d", proxy.Port(), port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := proxy.Stop(); err != nil {
|
||||||
|
t.Errorf("Stop() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSOCKSProxyPortBeforeStart(t *testing.T) {
|
||||||
|
filter := func(host string, port int) bool { return true }
|
||||||
|
proxy := NewSOCKSProxy(filter, false, false)
|
||||||
|
|
||||||
|
if proxy.Port() != 0 {
|
||||||
|
t.Errorf("Port() before Start() = %d, want 0", proxy.Port())
|
||||||
|
}
|
||||||
|
}
|
||||||
170
internal/sandbox/dangerous_test.go
Normal file
170
internal/sandbox/dangerous_test.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package sandbox
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetDefaultWritePaths(t *testing.T) {
|
||||||
|
paths := GetDefaultWritePaths()
|
||||||
|
|
||||||
|
if len(paths) == 0 {
|
||||||
|
t.Error("GetDefaultWritePaths() returned empty slice")
|
||||||
|
}
|
||||||
|
|
||||||
|
essentialPaths := []string{"/dev/stdout", "/dev/stderr", "/dev/null", "/tmp/fence"}
|
||||||
|
for _, essential := range essentialPaths {
|
||||||
|
found := slices.Contains(paths, essential)
|
||||||
|
if !found {
|
||||||
|
t.Errorf("GetDefaultWritePaths() missing essential path %q", essential)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMandatoryDenyPatterns(t *testing.T) {
|
||||||
|
cwd := "/home/user/project"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cwd string
|
||||||
|
allowGitConfig bool
|
||||||
|
shouldContain []string
|
||||||
|
shouldNotContain []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with git config denied",
|
||||||
|
cwd: cwd,
|
||||||
|
allowGitConfig: false,
|
||||||
|
shouldContain: []string{
|
||||||
|
filepath.Join(cwd, ".gitconfig"),
|
||||||
|
filepath.Join(cwd, ".bashrc"),
|
||||||
|
filepath.Join(cwd, ".zshrc"),
|
||||||
|
filepath.Join(cwd, ".git/hooks"),
|
||||||
|
filepath.Join(cwd, ".git/config"),
|
||||||
|
"**/.gitconfig",
|
||||||
|
"**/.bashrc",
|
||||||
|
"**/.git/hooks/**",
|
||||||
|
"**/.git/config",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with git config allowed",
|
||||||
|
cwd: cwd,
|
||||||
|
allowGitConfig: true,
|
||||||
|
shouldContain: []string{
|
||||||
|
filepath.Join(cwd, ".gitconfig"),
|
||||||
|
filepath.Join(cwd, ".git/hooks"),
|
||||||
|
"**/.git/hooks/**",
|
||||||
|
},
|
||||||
|
shouldNotContain: []string{
|
||||||
|
filepath.Join(cwd, ".git/config"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
patterns := GetMandatoryDenyPatterns(tt.cwd, tt.allowGitConfig)
|
||||||
|
|
||||||
|
for _, expected := range tt.shouldContain {
|
||||||
|
found := slices.Contains(patterns, expected)
|
||||||
|
if !found {
|
||||||
|
t.Errorf("GetMandatoryDenyPatterns() missing pattern %q", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, notExpected := range tt.shouldNotContain {
|
||||||
|
found := slices.Contains(patterns, notExpected)
|
||||||
|
if found {
|
||||||
|
t.Errorf("GetMandatoryDenyPatterns() should not contain %q when allowGitConfig=%v", notExpected, tt.allowGitConfig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMandatoryDenyPatternsContainsDangerousFiles(t *testing.T) {
|
||||||
|
cwd := "/test/project"
|
||||||
|
patterns := GetMandatoryDenyPatterns(cwd, false)
|
||||||
|
|
||||||
|
// Each dangerous file should appear both as a cwd-relative path and as a glob pattern
|
||||||
|
for _, file := range DangerousFiles {
|
||||||
|
cwdPath := filepath.Join(cwd, file)
|
||||||
|
globPattern := "**/" + file
|
||||||
|
|
||||||
|
foundCwd := false
|
||||||
|
foundGlob := false
|
||||||
|
|
||||||
|
for _, p := range patterns {
|
||||||
|
if p == cwdPath {
|
||||||
|
foundCwd = true
|
||||||
|
}
|
||||||
|
if p == globPattern {
|
||||||
|
foundGlob = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundCwd {
|
||||||
|
t.Errorf("Missing cwd-relative pattern for dangerous file %q", file)
|
||||||
|
}
|
||||||
|
if !foundGlob {
|
||||||
|
t.Errorf("Missing glob pattern for dangerous file %q", file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMandatoryDenyPatternsContainsDangerousDirectories(t *testing.T) {
|
||||||
|
cwd := "/test/project"
|
||||||
|
patterns := GetMandatoryDenyPatterns(cwd, false)
|
||||||
|
|
||||||
|
for _, dir := range DangerousDirectories {
|
||||||
|
cwdPath := filepath.Join(cwd, dir)
|
||||||
|
globPattern := "**/" + dir + "/**"
|
||||||
|
|
||||||
|
foundCwd := false
|
||||||
|
foundGlob := false
|
||||||
|
|
||||||
|
for _, p := range patterns {
|
||||||
|
if p == cwdPath {
|
||||||
|
foundCwd = true
|
||||||
|
}
|
||||||
|
if p == globPattern {
|
||||||
|
foundGlob = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundCwd {
|
||||||
|
t.Errorf("Missing cwd-relative pattern for dangerous directory %q", dir)
|
||||||
|
}
|
||||||
|
if !foundGlob {
|
||||||
|
t.Errorf("Missing glob pattern for dangerous directory %q", dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMandatoryDenyPatternsGitHooksAlwaysBlocked(t *testing.T) {
|
||||||
|
cwd := "/test/project"
|
||||||
|
|
||||||
|
// Git hooks should be blocked regardless of allowGitConfig
|
||||||
|
for _, allowGitConfig := range []bool{true, false} {
|
||||||
|
patterns := GetMandatoryDenyPatterns(cwd, allowGitConfig)
|
||||||
|
|
||||||
|
foundHooksPath := false
|
||||||
|
foundHooksGlob := false
|
||||||
|
|
||||||
|
for _, p := range patterns {
|
||||||
|
if p == filepath.Join(cwd, ".git/hooks") {
|
||||||
|
foundHooksPath = true
|
||||||
|
}
|
||||||
|
if strings.Contains(p, ".git/hooks") && strings.HasPrefix(p, "**") {
|
||||||
|
foundHooksGlob = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundHooksPath || !foundHooksGlob {
|
||||||
|
t.Errorf("Git hooks should always be blocked (allowGitConfig=%v)", allowGitConfig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
278
internal/sandbox/utils_test.go
Normal file
278
internal/sandbox/utils_test.go
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
package sandbox
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContainsGlobChars(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
pattern string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"/path/to/file", false},
|
||||||
|
{"/path/to/dir/", false},
|
||||||
|
{"relative/path", false},
|
||||||
|
{"/path/with/asterisk/*", true},
|
||||||
|
{"/path/with/question?", true},
|
||||||
|
{"/path/with/brackets[a-z]", true},
|
||||||
|
{"/path/**/*.go", true},
|
||||||
|
{"*.txt", true},
|
||||||
|
{"file[0-9].txt", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.pattern, func(t *testing.T) {
|
||||||
|
got := ContainsGlobChars(tt.pattern)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ContainsGlobChars(%q) = %v, want %v", tt.pattern, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoveTrailingGlobSuffix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"/path/to/dir/**", "/path/to/dir"},
|
||||||
|
{"/path/to/dir", "/path/to/dir"},
|
||||||
|
{"/path/**/**", "/path/**"},
|
||||||
|
{"/**", ""},
|
||||||
|
{"", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := RemoveTrailingGlobSuffix(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("RemoveTrailingGlobSuffix(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizePath(t *testing.T) {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
cwd, _ := os.Getwd()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "tilde alone",
|
||||||
|
input: "~",
|
||||||
|
want: home,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tilde with path",
|
||||||
|
input: "~/Documents",
|
||||||
|
want: filepath.Join(home, "Documents"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "absolute path",
|
||||||
|
input: "/usr/bin",
|
||||||
|
want: "/usr/bin",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "relative dot path",
|
||||||
|
input: "./subdir",
|
||||||
|
want: filepath.Join(cwd, "subdir"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "relative parent path",
|
||||||
|
input: "../sibling",
|
||||||
|
want: filepath.Join(filepath.Dir(cwd), "sibling"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "glob pattern preserved",
|
||||||
|
input: "/path/**/*.go",
|
||||||
|
want: "/path/**/*.go",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := NormalizePath(tt.input)
|
||||||
|
|
||||||
|
// For paths that involve symlink resolution, we just check the result is reasonable
|
||||||
|
if strings.Contains(tt.input, "**") || strings.Contains(tt.input, "*") {
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("NormalizePath(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For tilde and relative paths, we check prefixes since symlinks may resolve differently
|
||||||
|
if tt.input == "~" {
|
||||||
|
if got != home && !strings.HasPrefix(got, "/") {
|
||||||
|
t.Errorf("NormalizePath(%q) = %q, expected home directory", tt.input, got)
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(tt.input, "~/") {
|
||||||
|
if !strings.HasPrefix(got, home) && !strings.HasPrefix(got, "/") {
|
||||||
|
t.Errorf("NormalizePath(%q) = %q, expected path under home", tt.input, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateProxyEnvVars(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
httpPort int
|
||||||
|
socksPort int
|
||||||
|
wantEnvs []string
|
||||||
|
dontWant []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no ports",
|
||||||
|
httpPort: 0,
|
||||||
|
socksPort: 0,
|
||||||
|
wantEnvs: []string{
|
||||||
|
"FENCE_SANDBOX=1",
|
||||||
|
"TMPDIR=/tmp/fence",
|
||||||
|
},
|
||||||
|
dontWant: []string{
|
||||||
|
"HTTP_PROXY=",
|
||||||
|
"HTTPS_PROXY=",
|
||||||
|
"ALL_PROXY=",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http port only",
|
||||||
|
httpPort: 8080,
|
||||||
|
socksPort: 0,
|
||||||
|
wantEnvs: []string{
|
||||||
|
"FENCE_SANDBOX=1",
|
||||||
|
"HTTP_PROXY=http://localhost:8080",
|
||||||
|
"HTTPS_PROXY=http://localhost:8080",
|
||||||
|
"http_proxy=http://localhost:8080",
|
||||||
|
"https_proxy=http://localhost:8080",
|
||||||
|
"NO_PROXY=",
|
||||||
|
"no_proxy=",
|
||||||
|
},
|
||||||
|
dontWant: []string{
|
||||||
|
"ALL_PROXY=",
|
||||||
|
"all_proxy=",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "socks port only",
|
||||||
|
httpPort: 0,
|
||||||
|
socksPort: 1080,
|
||||||
|
wantEnvs: []string{
|
||||||
|
"FENCE_SANDBOX=1",
|
||||||
|
"ALL_PROXY=socks5h://localhost:1080",
|
||||||
|
"all_proxy=socks5h://localhost:1080",
|
||||||
|
"FTP_PROXY=socks5h://localhost:1080",
|
||||||
|
"GIT_SSH_COMMAND=",
|
||||||
|
},
|
||||||
|
dontWant: []string{
|
||||||
|
"HTTP_PROXY=",
|
||||||
|
"HTTPS_PROXY=",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both ports",
|
||||||
|
httpPort: 8080,
|
||||||
|
socksPort: 1080,
|
||||||
|
wantEnvs: []string{
|
||||||
|
"FENCE_SANDBOX=1",
|
||||||
|
"HTTP_PROXY=http://localhost:8080",
|
||||||
|
"HTTPS_PROXY=http://localhost:8080",
|
||||||
|
"ALL_PROXY=socks5h://localhost:1080",
|
||||||
|
"GIT_SSH_COMMAND=",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GenerateProxyEnvVars(tt.httpPort, tt.socksPort)
|
||||||
|
|
||||||
|
// Check expected env vars are present
|
||||||
|
for _, want := range tt.wantEnvs {
|
||||||
|
found := false
|
||||||
|
for _, env := range got {
|
||||||
|
if strings.HasPrefix(env, want) || env == want {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("GenerateProxyEnvVars(%d, %d) missing %q", tt.httpPort, tt.socksPort, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check unwanted env vars are not present
|
||||||
|
for _, dontWant := range tt.dontWant {
|
||||||
|
for _, env := range got {
|
||||||
|
if strings.HasPrefix(env, dontWant) {
|
||||||
|
t.Errorf("GenerateProxyEnvVars(%d, %d) should not contain %q, got %q", tt.httpPort, tt.socksPort, dontWant, env)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeSandboxedCommand(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
}{
|
||||||
|
{"simple command", "ls -la"},
|
||||||
|
{"command with spaces", "grep -r 'pattern' /path/to/dir"},
|
||||||
|
{"empty command", ""},
|
||||||
|
{"special chars", "echo $HOME && ls | grep foo"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
encoded := EncodeSandboxedCommand(tt.command)
|
||||||
|
if encoded == "" && tt.command != "" {
|
||||||
|
t.Error("EncodeSandboxedCommand returned empty string")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Roundtrip test
|
||||||
|
decoded, err := DecodeSandboxedCommand(encoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("DecodeSandboxedCommand failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commands are truncated to 100 chars
|
||||||
|
expected := tt.command
|
||||||
|
if len(expected) > 100 {
|
||||||
|
expected = expected[:100]
|
||||||
|
}
|
||||||
|
if decoded != expected {
|
||||||
|
t.Errorf("Roundtrip failed: got %q, want %q", decoded, expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeSandboxedCommandTruncation(t *testing.T) {
|
||||||
|
// Test that long commands are truncated
|
||||||
|
longCommand := strings.Repeat("a", 200)
|
||||||
|
encoded := EncodeSandboxedCommand(longCommand)
|
||||||
|
decoded, _ := DecodeSandboxedCommand(encoded)
|
||||||
|
|
||||||
|
if len(decoded) != 100 {
|
||||||
|
t.Errorf("Expected truncated command of 100 chars, got %d", len(decoded))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeSandboxedCommandInvalid(t *testing.T) {
|
||||||
|
_, err := DecodeSandboxedCommand("not-valid-base64!!!")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("DecodeSandboxedCommand should fail on invalid base64")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user