Add unit tests
This commit is contained in:
170
internal/sandbox/dangerous_test.go
Normal file
170
internal/sandbox/dangerous_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDefaultWritePaths(t *testing.T) {
|
||||
paths := GetDefaultWritePaths()
|
||||
|
||||
if len(paths) == 0 {
|
||||
t.Error("GetDefaultWritePaths() returned empty slice")
|
||||
}
|
||||
|
||||
essentialPaths := []string{"/dev/stdout", "/dev/stderr", "/dev/null", "/tmp/fence"}
|
||||
for _, essential := range essentialPaths {
|
||||
found := slices.Contains(paths, essential)
|
||||
if !found {
|
||||
t.Errorf("GetDefaultWritePaths() missing essential path %q", essential)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMandatoryDenyPatterns(t *testing.T) {
|
||||
cwd := "/home/user/project"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cwd string
|
||||
allowGitConfig bool
|
||||
shouldContain []string
|
||||
shouldNotContain []string
|
||||
}{
|
||||
{
|
||||
name: "with git config denied",
|
||||
cwd: cwd,
|
||||
allowGitConfig: false,
|
||||
shouldContain: []string{
|
||||
filepath.Join(cwd, ".gitconfig"),
|
||||
filepath.Join(cwd, ".bashrc"),
|
||||
filepath.Join(cwd, ".zshrc"),
|
||||
filepath.Join(cwd, ".git/hooks"),
|
||||
filepath.Join(cwd, ".git/config"),
|
||||
"**/.gitconfig",
|
||||
"**/.bashrc",
|
||||
"**/.git/hooks/**",
|
||||
"**/.git/config",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with git config allowed",
|
||||
cwd: cwd,
|
||||
allowGitConfig: true,
|
||||
shouldContain: []string{
|
||||
filepath.Join(cwd, ".gitconfig"),
|
||||
filepath.Join(cwd, ".git/hooks"),
|
||||
"**/.git/hooks/**",
|
||||
},
|
||||
shouldNotContain: []string{
|
||||
filepath.Join(cwd, ".git/config"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
patterns := GetMandatoryDenyPatterns(tt.cwd, tt.allowGitConfig)
|
||||
|
||||
for _, expected := range tt.shouldContain {
|
||||
found := slices.Contains(patterns, expected)
|
||||
if !found {
|
||||
t.Errorf("GetMandatoryDenyPatterns() missing pattern %q", expected)
|
||||
}
|
||||
}
|
||||
|
||||
for _, notExpected := range tt.shouldNotContain {
|
||||
found := slices.Contains(patterns, notExpected)
|
||||
if found {
|
||||
t.Errorf("GetMandatoryDenyPatterns() should not contain %q when allowGitConfig=%v", notExpected, tt.allowGitConfig)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMandatoryDenyPatternsContainsDangerousFiles(t *testing.T) {
|
||||
cwd := "/test/project"
|
||||
patterns := GetMandatoryDenyPatterns(cwd, false)
|
||||
|
||||
// Each dangerous file should appear both as a cwd-relative path and as a glob pattern
|
||||
for _, file := range DangerousFiles {
|
||||
cwdPath := filepath.Join(cwd, file)
|
||||
globPattern := "**/" + file
|
||||
|
||||
foundCwd := false
|
||||
foundGlob := false
|
||||
|
||||
for _, p := range patterns {
|
||||
if p == cwdPath {
|
||||
foundCwd = true
|
||||
}
|
||||
if p == globPattern {
|
||||
foundGlob = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundCwd {
|
||||
t.Errorf("Missing cwd-relative pattern for dangerous file %q", file)
|
||||
}
|
||||
if !foundGlob {
|
||||
t.Errorf("Missing glob pattern for dangerous file %q", file)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMandatoryDenyPatternsContainsDangerousDirectories(t *testing.T) {
|
||||
cwd := "/test/project"
|
||||
patterns := GetMandatoryDenyPatterns(cwd, false)
|
||||
|
||||
for _, dir := range DangerousDirectories {
|
||||
cwdPath := filepath.Join(cwd, dir)
|
||||
globPattern := "**/" + dir + "/**"
|
||||
|
||||
foundCwd := false
|
||||
foundGlob := false
|
||||
|
||||
for _, p := range patterns {
|
||||
if p == cwdPath {
|
||||
foundCwd = true
|
||||
}
|
||||
if p == globPattern {
|
||||
foundGlob = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundCwd {
|
||||
t.Errorf("Missing cwd-relative pattern for dangerous directory %q", dir)
|
||||
}
|
||||
if !foundGlob {
|
||||
t.Errorf("Missing glob pattern for dangerous directory %q", dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMandatoryDenyPatternsGitHooksAlwaysBlocked(t *testing.T) {
|
||||
cwd := "/test/project"
|
||||
|
||||
// Git hooks should be blocked regardless of allowGitConfig
|
||||
for _, allowGitConfig := range []bool{true, false} {
|
||||
patterns := GetMandatoryDenyPatterns(cwd, allowGitConfig)
|
||||
|
||||
foundHooksPath := false
|
||||
foundHooksGlob := false
|
||||
|
||||
for _, p := range patterns {
|
||||
if p == filepath.Join(cwd, ".git/hooks") {
|
||||
foundHooksPath = true
|
||||
}
|
||||
if strings.Contains(p, ".git/hooks") && strings.HasPrefix(p, "**") {
|
||||
foundHooksGlob = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundHooksPath || !foundHooksGlob {
|
||||
t.Errorf("Git hooks should always be blocked (allowGitConfig=%v)", allowGitConfig)
|
||||
}
|
||||
}
|
||||
}
|
||||
278
internal/sandbox/utils_test.go
Normal file
278
internal/sandbox/utils_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestContainsGlobChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
pattern string
|
||||
want bool
|
||||
}{
|
||||
{"/path/to/file", false},
|
||||
{"/path/to/dir/", false},
|
||||
{"relative/path", false},
|
||||
{"/path/with/asterisk/*", true},
|
||||
{"/path/with/question?", true},
|
||||
{"/path/with/brackets[a-z]", true},
|
||||
{"/path/**/*.go", true},
|
||||
{"*.txt", true},
|
||||
{"file[0-9].txt", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pattern, func(t *testing.T) {
|
||||
got := ContainsGlobChars(tt.pattern)
|
||||
if got != tt.want {
|
||||
t.Errorf("ContainsGlobChars(%q) = %v, want %v", tt.pattern, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveTrailingGlobSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"/path/to/dir/**", "/path/to/dir"},
|
||||
{"/path/to/dir", "/path/to/dir"},
|
||||
{"/path/**/**", "/path/**"},
|
||||
{"/**", ""},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := RemoveTrailingGlobSuffix(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("RemoveTrailingGlobSuffix(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizePath(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
cwd, _ := os.Getwd()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "tilde alone",
|
||||
input: "~",
|
||||
want: home,
|
||||
},
|
||||
{
|
||||
name: "tilde with path",
|
||||
input: "~/Documents",
|
||||
want: filepath.Join(home, "Documents"),
|
||||
},
|
||||
{
|
||||
name: "absolute path",
|
||||
input: "/usr/bin",
|
||||
want: "/usr/bin",
|
||||
},
|
||||
{
|
||||
name: "relative dot path",
|
||||
input: "./subdir",
|
||||
want: filepath.Join(cwd, "subdir"),
|
||||
},
|
||||
{
|
||||
name: "relative parent path",
|
||||
input: "../sibling",
|
||||
want: filepath.Join(filepath.Dir(cwd), "sibling"),
|
||||
},
|
||||
{
|
||||
name: "glob pattern preserved",
|
||||
input: "/path/**/*.go",
|
||||
want: "/path/**/*.go",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NormalizePath(tt.input)
|
||||
|
||||
// For paths that involve symlink resolution, we just check the result is reasonable
|
||||
if strings.Contains(tt.input, "**") || strings.Contains(tt.input, "*") {
|
||||
if got != tt.want {
|
||||
t.Errorf("NormalizePath(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For tilde and relative paths, we check prefixes since symlinks may resolve differently
|
||||
if tt.input == "~" {
|
||||
if got != home && !strings.HasPrefix(got, "/") {
|
||||
t.Errorf("NormalizePath(%q) = %q, expected home directory", tt.input, got)
|
||||
}
|
||||
} else if strings.HasPrefix(tt.input, "~/") {
|
||||
if !strings.HasPrefix(got, home) && !strings.HasPrefix(got, "/") {
|
||||
t.Errorf("NormalizePath(%q) = %q, expected path under home", tt.input, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateProxyEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
httpPort int
|
||||
socksPort int
|
||||
wantEnvs []string
|
||||
dontWant []string
|
||||
}{
|
||||
{
|
||||
name: "no ports",
|
||||
httpPort: 0,
|
||||
socksPort: 0,
|
||||
wantEnvs: []string{
|
||||
"FENCE_SANDBOX=1",
|
||||
"TMPDIR=/tmp/fence",
|
||||
},
|
||||
dontWant: []string{
|
||||
"HTTP_PROXY=",
|
||||
"HTTPS_PROXY=",
|
||||
"ALL_PROXY=",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "http port only",
|
||||
httpPort: 8080,
|
||||
socksPort: 0,
|
||||
wantEnvs: []string{
|
||||
"FENCE_SANDBOX=1",
|
||||
"HTTP_PROXY=http://localhost:8080",
|
||||
"HTTPS_PROXY=http://localhost:8080",
|
||||
"http_proxy=http://localhost:8080",
|
||||
"https_proxy=http://localhost:8080",
|
||||
"NO_PROXY=",
|
||||
"no_proxy=",
|
||||
},
|
||||
dontWant: []string{
|
||||
"ALL_PROXY=",
|
||||
"all_proxy=",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "socks port only",
|
||||
httpPort: 0,
|
||||
socksPort: 1080,
|
||||
wantEnvs: []string{
|
||||
"FENCE_SANDBOX=1",
|
||||
"ALL_PROXY=socks5h://localhost:1080",
|
||||
"all_proxy=socks5h://localhost:1080",
|
||||
"FTP_PROXY=socks5h://localhost:1080",
|
||||
"GIT_SSH_COMMAND=",
|
||||
},
|
||||
dontWant: []string{
|
||||
"HTTP_PROXY=",
|
||||
"HTTPS_PROXY=",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "both ports",
|
||||
httpPort: 8080,
|
||||
socksPort: 1080,
|
||||
wantEnvs: []string{
|
||||
"FENCE_SANDBOX=1",
|
||||
"HTTP_PROXY=http://localhost:8080",
|
||||
"HTTPS_PROXY=http://localhost:8080",
|
||||
"ALL_PROXY=socks5h://localhost:1080",
|
||||
"GIT_SSH_COMMAND=",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GenerateProxyEnvVars(tt.httpPort, tt.socksPort)
|
||||
|
||||
// Check expected env vars are present
|
||||
for _, want := range tt.wantEnvs {
|
||||
found := false
|
||||
for _, env := range got {
|
||||
if strings.HasPrefix(env, want) || env == want {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("GenerateProxyEnvVars(%d, %d) missing %q", tt.httpPort, tt.socksPort, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check unwanted env vars are not present
|
||||
for _, dontWant := range tt.dontWant {
|
||||
for _, env := range got {
|
||||
if strings.HasPrefix(env, dontWant) {
|
||||
t.Errorf("GenerateProxyEnvVars(%d, %d) should not contain %q, got %q", tt.httpPort, tt.socksPort, dontWant, env)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeSandboxedCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
}{
|
||||
{"simple command", "ls -la"},
|
||||
{"command with spaces", "grep -r 'pattern' /path/to/dir"},
|
||||
{"empty command", ""},
|
||||
{"special chars", "echo $HOME && ls | grep foo"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
encoded := EncodeSandboxedCommand(tt.command)
|
||||
if encoded == "" && tt.command != "" {
|
||||
t.Error("EncodeSandboxedCommand returned empty string")
|
||||
}
|
||||
|
||||
// Roundtrip test
|
||||
decoded, err := DecodeSandboxedCommand(encoded)
|
||||
if err != nil {
|
||||
t.Errorf("DecodeSandboxedCommand failed: %v", err)
|
||||
}
|
||||
|
||||
// Commands are truncated to 100 chars
|
||||
expected := tt.command
|
||||
if len(expected) > 100 {
|
||||
expected = expected[:100]
|
||||
}
|
||||
if decoded != expected {
|
||||
t.Errorf("Roundtrip failed: got %q, want %q", decoded, expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeSandboxedCommandTruncation(t *testing.T) {
|
||||
// Test that long commands are truncated
|
||||
longCommand := strings.Repeat("a", 200)
|
||||
encoded := EncodeSandboxedCommand(longCommand)
|
||||
decoded, _ := DecodeSandboxedCommand(encoded)
|
||||
|
||||
if len(decoded) != 100 {
|
||||
t.Errorf("Expected truncated command of 100 chars, got %d", len(decoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeSandboxedCommandInvalid(t *testing.T) {
|
||||
_, err := DecodeSandboxedCommand("not-valid-base64!!!")
|
||||
if err == nil {
|
||||
t.Error("DecodeSandboxedCommand should fail on invalid base64")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user