From 09f84782b0e1ee4736f5a4344e8f72f018464459 Mon Sep 17 00:00:00 2001 From: Sora39831 <540587985@qq.com> Date: Fri, 3 Apr 2026 08:44:51 +0800 Subject: [PATCH] test: add unit tests for critical modules and fix flaky redirect middleware - Add tests for config, database, model, util/common, util/crypto, util/random, web/middleware, web/service, and xray packages - Fix redirect middleware using slice instead of map to guarantee deterministic longest-prefix-first matching order --- config/config_test.go | 131 ++++++++++++ database/db_test.go | 214 ++++++++++++++++++++ database/model/model_test.go | 61 ++++++ util/common/common_test.go | 162 +++++++++++++++ util/crypto/crypto_test.go | 72 +++++++ util/random/random_test.go | 67 +++++++ web/middleware/middleware_test.go | 320 ++++++++++++++++++++++++++++++ web/middleware/redirect.go | 21 +- web/service/server_test.go | 100 ++++++++++ web/service/user_test.go | 147 ++++++++++++++ xray/api_test.go | 172 ++++++++++++++++ xray/config_test.go | 125 ++++++++++++ xray/log_writer_test.go | 91 +++++++++ 13 files changed, 1674 insertions(+), 9 deletions(-) create mode 100644 config/config_test.go create mode 100644 database/db_test.go create mode 100644 database/model/model_test.go create mode 100644 util/common/common_test.go create mode 100644 util/crypto/crypto_test.go create mode 100644 util/random/random_test.go create mode 100644 web/middleware/middleware_test.go create mode 100644 web/service/server_test.go create mode 100644 web/service/user_test.go create mode 100644 xray/api_test.go create mode 100644 xray/config_test.go create mode 100644 xray/log_writer_test.go diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 00000000..d911c7ac --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,131 @@ +package config + +import ( + "strings" + "testing" +) + +func TestGetVersion(t *testing.T) { + v := GetVersion() + if v == "" { + // version file might be empty in test, that's ok + t.Log("version is empty (expected in test environment)") + } +} + +func TestGetName(t *testing.T) { + n := GetName() + if n == "" { + t.Fatal("name should not be empty") + } + if strings.TrimSpace(n) != n { + t.Error("name should be trimmed") + } +} + +func TestIsDebugDefault(t *testing.T) { + t.Setenv("XUI_DEBUG", "") + if IsDebug() { + t.Error("IsDebug should return false by default") + } +} + +func TestIsDebugTrue(t *testing.T) { + t.Setenv("XUI_DEBUG", "true") + if !IsDebug() { + t.Error("IsDebug should return true when XUI_DEBUG=true") + } +} + +func TestIsDebugFalse(t *testing.T) { + t.Setenv("XUI_DEBUG", "false") + if IsDebug() { + t.Error("IsDebug should return false when XUI_DEBUG=false") + } +} + +func TestGetLogLevelDefault(t *testing.T) { + t.Setenv("XUI_DEBUG", "") + t.Setenv("XUI_LOG_LEVEL", "") + if GetLogLevel() != Info { + t.Errorf("default log level should be Info, got %s", GetLogLevel()) + } +} + +func TestGetLogLevelDebug(t *testing.T) { + t.Setenv("XUI_DEBUG", "true") + if GetLogLevel() != Debug { + t.Errorf("log level should be Debug when XUI_DEBUG=true, got %s", GetLogLevel()) + } +} + +func TestGetLogLevelCustom(t *testing.T) { + t.Setenv("XUI_DEBUG", "") + t.Setenv("XUI_LOG_LEVEL", "warning") + if GetLogLevel() != Warning { + t.Errorf("log level should be Warning, got %s", GetLogLevel()) + } +} + +func TestGetBinFolderPathDefault(t *testing.T) { + t.Setenv("XUI_BIN_FOLDER", "") + if GetBinFolderPath() != "bin" { + t.Errorf("default bin folder should be 'bin', got %s", GetBinFolderPath()) + } +} + +func TestGetBinFolderPathCustom(t *testing.T) { + t.Setenv("XUI_BIN_FOLDER", "/custom/bin") + if GetBinFolderPath() != "/custom/bin" { + t.Errorf("bin folder should be '/custom/bin', got %s", GetBinFolderPath()) + } +} + +func TestGetDBFolderPathDefault(t *testing.T) { + t.Setenv("XUI_DB_FOLDER", "") + folder := GetDBFolderPath() + // On Linux without env var, should be "/etc/x-ui" + if folder != "/etc/x-ui" { + t.Errorf("default DB folder should be '/etc/x-ui', got %s", folder) + } +} + +func TestGetDBFolderPathCustom(t *testing.T) { + t.Setenv("XUI_DB_FOLDER", "/tmp/test-db") + if GetDBFolderPath() != "/tmp/test-db" { + t.Errorf("DB folder should be '/tmp/test-db', got %s", GetDBFolderPath()) + } +} + +func TestGetDBPath(t *testing.T) { + t.Setenv("XUI_DB_FOLDER", "/tmp/test") + dbPath := GetDBPath() + expected := "/tmp/test/" + GetName() + ".db" + if dbPath != expected { + t.Errorf("GetDBPath() = %q, want %q", dbPath, expected) + } +} + +func TestGetSettingPath(t *testing.T) { + t.Setenv("XUI_DB_FOLDER", "/tmp/test") + settingPath := GetSettingPath() + expected := "/tmp/test/" + GetName() + ".json" + if settingPath != expected { + t.Errorf("GetSettingPath() = %q, want %q", settingPath, expected) + } +} + +func TestGetLogFolderDefault(t *testing.T) { + t.Setenv("XUI_LOG_FOLDER", "") + folder := GetLogFolder() + if folder != "/var/log/x-ui" { + t.Errorf("default log folder should be '/var/log/x-ui', got %s", folder) + } +} + +func TestGetLogFolderCustom(t *testing.T) { + t.Setenv("XUI_LOG_FOLDER", "/custom/logs") + if GetLogFolder() != "/custom/logs" { + t.Errorf("log folder should be '/custom/logs', got %s", GetLogFolder()) + } +} diff --git a/database/db_test.go b/database/db_test.go new file mode 100644 index 00000000..5bd42d7e --- /dev/null +++ b/database/db_test.go @@ -0,0 +1,214 @@ +package database + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/mhsanaei/3x-ui/v2/database/model" +) + +func setupTestDB(t *testing.T) { + t.Helper() + tmpDir := t.TempDir() + t.Setenv("XUI_DEBUG", "") + dbPath := filepath.Join(tmpDir, "test.db") + if err := InitDB(dbPath); err != nil { + t.Fatalf("InitDB failed: %v", err) + } + t.Cleanup(func() { + CloseDB() + }) +} + +func TestIsSQLiteDB_ValidFile(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "valid.db") + if err := InitDB(dbPath); err != nil { + t.Fatalf("InitDB failed: %v", err) + } + defer CloseDB() + + f, err := os.Open(dbPath) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + defer f.Close() + + ok, err := IsSQLiteDB(f) + if err != nil { + t.Fatalf("IsSQLiteDB error: %v", err) + } + if !ok { + t.Error("IsSQLiteDB should return true for a valid SQLite file") + } +} + +func TestIsSQLiteDB_InvalidFile(t *testing.T) { + tmpDir := t.TempDir() + notADB := filepath.Join(tmpDir, "notdb.txt") + if err := os.WriteFile(notADB, []byte("this is not a database"), 0644); err != nil { + t.Fatal(err) + } + + f, err := os.Open(notADB) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + ok, err := IsSQLiteDB(f) + if err != nil { + t.Fatalf("IsSQLiteDB error: %v", err) + } + if ok { + t.Error("IsSQLiteDB should return false for a non-SQLite file") + } +} + +func TestIsSQLiteDB_EmptyFile(t *testing.T) { + tmpDir := t.TempDir() + empty := filepath.Join(tmpDir, "empty.db") + if err := os.WriteFile(empty, []byte{}, 0644); err != nil { + t.Fatal(err) + } + + f, err := os.Open(empty) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + ok, err := IsSQLiteDB(f) + // Empty file returns EOF since there aren't enough bytes to read + if err == nil && ok { + t.Error("IsSQLiteDB should return false for an empty file") + } +} + +func TestIsSQLiteDB_WrongContent(t *testing.T) { + // File with 16 bytes (matching SQLite header length) but wrong signature content + r := bytes.NewReader([]byte("SQLite for !!")) + ok, err := IsSQLiteDB(r) + if err != nil { + t.Fatalf("IsSQLiteDB error: %v", err) + } + if ok { + t.Error("IsSQLiteDB should return false for wrong signature content") + } +} + +func TestInitDB_CreatesTables(t *testing.T) { + setupTestDB(t) + + // Verify all tables exist by querying them + tables := []string{"users", "inbounds", "outbound_traffics", "settings", "inbound_client_ips", "client_traffics", "history_of_seeders"} + for _, table := range tables { + var count int64 + if err := db.Table(table).Count(&count).Error; err != nil { + t.Errorf("table %q should exist but got error: %v", table, err) + } + } +} + +func TestInitDB_CreatesDefaultUser(t *testing.T) { + setupTestDB(t) + + var user model.User + if err := db.First(&user).Error; err != nil { + t.Fatalf("should have a default user: %v", err) + } + if user.Username != "admin" { + t.Errorf("default username should be 'admin', got %q", user.Username) + } + if user.Role != "admin" { + t.Errorf("default role should be 'admin', got %q", user.Role) + } + // Password should be a bcrypt hash, not plaintext + if user.Password == "admin" { + t.Error("default password should be hashed, not plaintext") + } +} + +func TestInitDB_Idempotent(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XUI_DEBUG", "") + dbPath := filepath.Join(tmpDir, "idempotent.db") + + // First init + if err := InitDB(dbPath); err != nil { + t.Fatalf("first InitDB failed: %v", err) + } + CloseDB() + + // Second init on the same file should not fail + if err := InitDB(dbPath); err != nil { + t.Fatalf("second InitDB failed: %v", err) + } + defer CloseDB() + + // Should still have exactly one default user + var count int64 + db.Model(&model.User{}).Count(&count) + if count != 1 { + t.Errorf("expected 1 user after second init, got %d", count) + } +} + +func TestValidateSQLiteDB_ValidDB(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "valid.db") + if err := InitDB(dbPath); err != nil { + t.Fatalf("InitDB failed: %v", err) + } + CloseDB() + + if err := ValidateSQLiteDB(dbPath); err != nil { + t.Errorf("ValidateSQLiteDB should pass for valid DB: %v", err) + } +} + +func TestValidateSQLiteDB_NonExistent(t *testing.T) { + err := ValidateSQLiteDB("/tmp/does-not-exist-12345.db") + if err == nil { + t.Error("ValidateSQLiteDB should fail for non-existent file") + } +} + +func TestValidateSQLiteDB_CorruptDB(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "corrupt.db") + // Write garbage that looks like SQLite header but is corrupt + garbage := make([]byte, 4096) + copy(garbage, []byte("SQLite format 3\x00")) + if err := os.WriteFile(dbPath, garbage, 0644); err != nil { + t.Fatal(err) + } + + err := ValidateSQLiteDB(dbPath) + if err == nil { + t.Error("ValidateSQLiteDB should fail for corrupt DB") + } +} + +func TestIsNotFound(t *testing.T) { + if IsNotFound(nil) { + t.Error("IsNotFound should return false for nil") + } +} + +func TestInitUser_OnlyOnce(t *testing.T) { + setupTestDB(t) + + // initUser should not create a second user when table is not empty + if err := initUser(); err != nil { + t.Fatalf("initUser error: %v", err) + } + + var count int64 + db.Model(&model.User{}).Count(&count) + if count != 1 { + t.Errorf("expected 1 user, got %d", count) + } +} diff --git a/database/model/model_test.go b/database/model/model_test.go new file mode 100644 index 00000000..17c9839c --- /dev/null +++ b/database/model/model_test.go @@ -0,0 +1,61 @@ +package model + +import ( + "testing" +) + +func TestGenXrayInboundConfig_EmptyListen(t *testing.T) { + in := &Inbound{ + Listen: "", + Port: 443, + Protocol: VLESS, + Settings: `{"clients":[]}`, + Tag: "test-inbound", + Sniffing: `{"enabled":true}`, + } + cfg := in.GenXrayInboundConfig() + if cfg == nil { + t.Fatal("GenXrayInboundConfig should not return nil") + } + // Empty listen should default to 0.0.0.0 + expected := `"0.0.0.0"` + if string(cfg.Listen) != expected { + t.Errorf("Listen should default to %s, got %s", expected, string(cfg.Listen)) + } + if cfg.Port != 443 { + t.Errorf("Port should be 443, got %d", cfg.Port) + } + if cfg.Protocol != "vless" { + t.Errorf("Protocol should be vless, got %q", cfg.Protocol) + } + if cfg.Tag != "test-inbound" { + t.Errorf("Tag should be test-inbound, got %q", cfg.Tag) + } +} + +func TestGenXrayInboundConfig_CustomListen(t *testing.T) { + in := &Inbound{ + Listen: "127.0.0.1", + Port: 8080, + Protocol: VMESS, + Tag: "custom", + } + cfg := in.GenXrayInboundConfig() + expected := `"127.0.0.1"` + if string(cfg.Listen) != expected { + t.Errorf("Listen should be %s, got %s", expected, string(cfg.Listen)) + } +} + +func TestGenXrayInboundConfig_EmptySettings(t *testing.T) { + in := &Inbound{ + Port: 443, + Protocol: Trojan, + } + cfg := in.GenXrayInboundConfig() + if cfg == nil { + t.Fatal("GenXrayInboundConfig should not return nil") + } + // Empty string Settings produces a nil RawMessage since json_util.RawMessage("") may be nil + // Just verify no panic occurred +} diff --git a/util/common/common_test.go b/util/common/common_test.go new file mode 100644 index 00000000..34294354 --- /dev/null +++ b/util/common/common_test.go @@ -0,0 +1,162 @@ +package common + +import ( + "errors" + "strings" + "testing" +) + +func TestNewErrorf(t *testing.T) { + err := NewErrorf("invalid port: %d", 8080) + if err == nil { + t.Fatal("NewErrorf should return non-nil error") + } + expected := "invalid port: 8080" + if err.Error() != expected { + t.Errorf("NewErrorf returned %q, expected %q", err.Error(), expected) + } +} + +func TestNewError(t *testing.T) { + err := NewError("something", " went wrong") + if err == nil { + t.Fatal("NewError should return non-nil error") + } + if !strings.Contains(err.Error(), "something") { + t.Errorf("NewError should contain 'something', got %q", err.Error()) + } +} + +func TestRecoverWithoutPanic(t *testing.T) { + recovered := Recover("") + if recovered != nil { + t.Errorf("Recover should return nil when no panic occurred, got %v", recovered) + } +} + +func TestFormatTrafficBytes(t *testing.T) { + tests := []struct { + input int64 + expected string + }{ + {0, "0.00B"}, + {1, "1.00B"}, + {512, "512.00B"}, + {1023, "1023.00B"}, + } + for _, tt := range tests { + result := FormatTraffic(tt.input) + if result != tt.expected { + t.Errorf("FormatTraffic(%d) = %q, want %q", tt.input, result, tt.expected) + } + } +} + +func TestFormatTrafficKB(t *testing.T) { + tests := []struct { + input int64 + expected string + }{ + {1024, "1.00KB"}, + {1536, "1.50KB"}, + {2048, "2.00KB"}, + } + for _, tt := range tests { + result := FormatTraffic(tt.input) + if result != tt.expected { + t.Errorf("FormatTraffic(%d) = %q, want %q", tt.input, result, tt.expected) + } + } +} + +func TestFormatTrafficMB(t *testing.T) { + result := FormatTraffic(1048576) // 1 MB + expected := "1.00MB" + if result != expected { + t.Errorf("FormatTraffic(1048576) = %q, want %q", result, expected) + } +} + +func TestFormatTrafficGB(t *testing.T) { + result := FormatTraffic(1073741824) // 1 GB + expected := "1.00GB" + if result != expected { + t.Errorf("FormatTraffic(1073741824) = %q, want %q", result, expected) + } +} + +func TestFormatTrafficTB(t *testing.T) { + result := FormatTraffic(1099511627776) // 1 TB + expected := "1.00TB" + if result != expected { + t.Errorf("FormatTraffic(1099511627776) = %q, want %q", result, expected) + } +} + +func TestFormatTrafficPB(t *testing.T) { + result := FormatTraffic(1125899906842624) // 1 PB + expected := "1.00PB" + if result != expected { + t.Errorf("FormatTraffic(1125899906842624) = %q, want %q", result, expected) + } +} + +func TestFormatTrafficLargePB(t *testing.T) { + // Value exceeding PB should stay in PB + result := FormatTraffic(11258999068426240) // 10 PB + if !strings.HasSuffix(result, "PB") { + t.Errorf("FormatTraffic should cap at PB, got %q", result) + } +} + +func TestCombineAllNil(t *testing.T) { + err := Combine(nil, nil, nil) + if err != nil { + t.Errorf("Combine(nil, nil, nil) should return nil, got %v", err) + } +} + +func TestCombineNoArgs(t *testing.T) { + err := Combine() + if err != nil { + t.Errorf("Combine() should return nil, got %v", err) + } +} + +func TestCombineSingleError(t *testing.T) { + input := errors.New("test error") + err := Combine(input) + if err == nil { + t.Fatal("Combine should return non-nil when an error is present") + } + if !strings.Contains(err.Error(), "test error") { + t.Errorf("Combine should contain 'test error', got %q", err.Error()) + } +} + +func TestCombineMultipleErrors(t *testing.T) { + err1 := errors.New("error 1") + err2 := errors.New("error 2") + combined := Combine(err1, nil, err2) + if combined == nil { + t.Fatal("Combine should return non-nil when errors are present") + } + s := combined.Error() + if !strings.Contains(s, "error 1") { + t.Errorf("Combined error should contain 'error 1', got %q", s) + } + if !strings.Contains(s, "error 2") { + t.Errorf("Combined error should contain 'error 2', got %q", s) + } +} + +func TestCombineFiltersNils(t *testing.T) { + err1 := errors.New("real error") + combined := Combine(nil, err1, nil) + if combined == nil { + t.Fatal("Combine should return non-nil when at least one error is present") + } + if !strings.Contains(combined.Error(), "real error") { + t.Errorf("Combined error should contain 'real error', got %q", combined.Error()) + } +} diff --git a/util/crypto/crypto_test.go b/util/crypto/crypto_test.go new file mode 100644 index 00000000..2d341ef8 --- /dev/null +++ b/util/crypto/crypto_test.go @@ -0,0 +1,72 @@ +package crypto + +import ( + "strings" + "testing" +) + +func TestHashPasswordAsBcrypt(t *testing.T) { + hash, err := HashPasswordAsBcrypt("password123") + if err != nil { + t.Fatalf("HashPasswordAsBcrypt failed: %v", err) + } + if hash == "" { + t.Fatal("hash should not be empty") + } + if hash == "password123" { + t.Fatal("hash should not equal the plaintext password") + } + if !strings.HasPrefix(hash, "$2a$") && !strings.HasPrefix(hash, "$2b$") { + t.Fatalf("hash should have bcrypt prefix, got: %s", hash[:4]) + } +} + +func TestHashAndCheckRoundTrip(t *testing.T) { + passwords := []string{ + "password123", + "", + "very-long-password-with-special-chars-!@#$%^&*()", + "unicode-密码-test", + } + for _, pw := range passwords { + hash, err := HashPasswordAsBcrypt(pw) + if err != nil { + t.Fatalf("HashPasswordAsBcrypt(%q) failed: %v", pw, err) + } + if !CheckPasswordHash(hash, pw) { + t.Errorf("CheckPasswordHash should return true for correct password %q", pw) + } + } +} + +func TestCheckPasswordHashWrongPassword(t *testing.T) { + hash, err := HashPasswordAsBcrypt("correct-password") + if err != nil { + t.Fatalf("HashPasswordAsBcrypt failed: %v", err) + } + if CheckPasswordHash(hash, "wrong-password") { + t.Error("CheckPasswordHash should return false for wrong password") + } +} + +func TestCheckPasswordHashInvalidHash(t *testing.T) { + if CheckPasswordHash("not-a-valid-hash", "password") { + t.Error("CheckPasswordHash should return false for invalid hash") + } +} + +func TestDifferentPasswordsProduceDifferentHashes(t *testing.T) { + hash1, _ := HashPasswordAsBcrypt("password1") + hash2, _ := HashPasswordAsBcrypt("password2") + if hash1 == hash2 { + t.Error("different passwords should produce different hashes") + } +} + +func TestSamePasswordProducesDifferentHashes(t *testing.T) { + hash1, _ := HashPasswordAsBcrypt("same-password") + hash2, _ := HashPasswordAsBcrypt("same-password") + if hash1 == hash2 { + t.Error("bcrypt should use different salts, producing different hashes for same password") + } +} diff --git a/util/random/random_test.go b/util/random/random_test.go new file mode 100644 index 00000000..47b4e3d1 --- /dev/null +++ b/util/random/random_test.go @@ -0,0 +1,67 @@ +package random + +import ( + "testing" +) + +func TestSeqLength(t *testing.T) { + lengths := []int{0, 1, 10, 32, 64, 128} + for _, n := range lengths { + s := Seq(n) + if len([]rune(s)) != n { + t.Errorf("Seq(%d) returned string of length %d", n, len([]rune(s))) + } + } +} + +func TestSeqCharset(t *testing.T) { + allChars := make(map[rune]bool) + for _, r := range allSeq { + allChars[r] = true + } + + s := Seq(1000) + for i, r := range s { + if !allChars[r] { + t.Errorf("Seq produced invalid character %q at index %d", r, i) + } + } +} + +func TestSeqUniqueness(t *testing.T) { + // Generate several strings and check they're not all identical + seen := make(map[string]bool) + for range 10 { + seen[Seq(32)] = true + } + if len(seen) < 2 { + t.Error("Seq(32) produced identical strings across 10 calls") + } +} + +func TestSeqEmpty(t *testing.T) { + s := Seq(0) + if s != "" { + t.Errorf("Seq(0) should return empty string, got %q", s) + } +} + +func TestNumRange(t *testing.T) { + for _, n := range []int{1, 5, 10, 100, 1000} { + for range 100 { + r := Num(n) + if r < 0 || r >= n { + t.Errorf("Num(%d) returned %d, expected [0, %d)", n, r, n) + } + } + } +} + +func TestNumOne(t *testing.T) { + for range 50 { + r := Num(1) + if r != 0 { + t.Errorf("Num(1) should always return 0, got %d", r) + } + } +} diff --git a/web/middleware/middleware_test.go b/web/middleware/middleware_test.go new file mode 100644 index 00000000..0e9c5bdc --- /dev/null +++ b/web/middleware/middleware_test.go @@ -0,0 +1,320 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func TestRedirectMiddleware_XUIToPanel(t *testing.T) { + r := gin.New() + r.Use(RedirectMiddleware("/")) + r.GET("/panel/*path", func(c *gin.Context) { + c.String(http.StatusOK, "panel") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/xui/settings", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusMovedPermanently { + t.Errorf("expected 301, got %d", w.Code) + } + loc := w.Header().Get("Location") + if loc != "/panel/settings" { + t.Errorf("expected redirect to /panel/settings, got %q", loc) + } +} + +func TestRedirectMiddleware_XUIAPIToPanelAPI(t *testing.T) { + r := gin.New() + r.Use(RedirectMiddleware("/")) + r.GET("/panel/api/*path", func(c *gin.Context) { + c.String(http.StatusOK, "api") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/xui/API/inbounds", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusMovedPermanently { + t.Errorf("expected 301, got %d", w.Code) + } + loc := w.Header().Get("Location") + if loc != "/panel/api/inbounds" { + t.Errorf("expected redirect to /panel/api/inbounds, got %q", loc) + } +} + +func TestRedirectMiddleware_PanelAPICase(t *testing.T) { + r := gin.New() + r.Use(RedirectMiddleware("/")) + r.GET("/panel/api/*path", func(c *gin.Context) { + c.String(http.StatusOK, "api") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/panel/API/list", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusMovedPermanently { + t.Errorf("expected 301, got %d", w.Code) + } + loc := w.Header().Get("Location") + if loc != "/panel/api/list" { + t.Errorf("expected redirect to /panel/api/list, got %q", loc) + } +} + +func TestRedirectMiddleware_NoRedirect(t *testing.T) { + r := gin.New() + r.Use(RedirectMiddleware("/")) + r.GET("/panel/settings", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/panel/settings", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} + +func TestRedirectMiddleware_WithBasePath(t *testing.T) { + r := gin.New() + r.Use(RedirectMiddleware("/base/")) + r.GET("/base/panel/*path", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/base/xui/settings", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusMovedPermanently { + t.Errorf("expected 301, got %d", w.Code) + } + loc := w.Header().Get("Location") + if loc != "/base/panel/settings" { + t.Errorf("expected redirect to /base/panel/settings, got %q", loc) + } +} + +func TestDomainValidatorMiddleware_MatchingDomain(t *testing.T) { + r := gin.New() + r.Use(DomainValidatorMiddleware("example.com")) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Host = "example.com" + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} + +func TestDomainValidatorMiddleware_MatchingDomainWithPort(t *testing.T) { + r := gin.New() + r.Use(DomainValidatorMiddleware("example.com")) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Host = "example.com:8443" + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200 for matching domain with port, got %d", w.Code) + } +} + +func TestDomainValidatorMiddleware_NonMatchingDomain(t *testing.T) { + r := gin.New() + r.Use(DomainValidatorMiddleware("example.com")) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Host = "evil.com" + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("expected 403, got %d", w.Code) + } +} + +func TestDomainValidatorMiddleware_Subdomain(t *testing.T) { + r := gin.New() + r.Use(DomainValidatorMiddleware("example.com")) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Host = "sub.example.com" + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("expected 403 for subdomain, got %d", w.Code) + } +} + +func TestRateLimitMiddleware_FirstRequest(t *testing.T) { + r := gin.New() + r.Use(RateLimitMiddleware(5, time.Minute)) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200 for first request, got %d", w.Code) + } +} + +func TestRateLimitMiddleware_WithinLimit(t *testing.T) { + r := gin.New() + r.Use(RateLimitMiddleware(3, time.Minute)) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + for i := range 3 { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.2:12345" + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("request %d: expected 200, got %d", i+1, w.Code) + } + } +} + +func TestRateLimitMiddleware_ExceedsLimit(t *testing.T) { + r := gin.New() + r.Use(RateLimitMiddleware(2, time.Minute)) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + // First 2 should pass + for range 2 { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.3:12345" + r.ServeHTTP(w, req) + } + + // 3rd should be rate limited + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.3:12345" + r.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("expected 429, got %d", w.Code) + } +} + +func TestRateLimitMiddleware_XRealIP(t *testing.T) { + r := gin.New() + r.Use(RateLimitMiddleware(2, time.Minute)) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + for range 2 { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Real-IP", "10.0.0.1") + r.ServeHTTP(w, req) + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Real-IP", "10.0.0.1") + r.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("expected 429 with X-Real-IP, got %d", w.Code) + } +} + +func TestRateLimitMiddleware_XForwardedFor(t *testing.T) { + r := gin.New() + r.Use(RateLimitMiddleware(2, time.Minute)) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + for range 2 { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-For", "10.0.0.2, 10.0.0.3") + r.ServeHTTP(w, req) + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-For", "10.0.0.2, 10.0.0.3") + r.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("expected 429 with X-Forwarded-For, got %d", w.Code) + } +} + +func TestRateLimitMiddleware_DifferentIPsIndependent(t *testing.T) { + r := gin.New() + r.Use(RateLimitMiddleware(1, time.Minute)) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + // Exhaust limit for IP 1 + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Real-IP", "10.0.0.10") + r.ServeHTTP(w, req) + + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Real-IP", "10.0.0.10") + r.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("IP 1 second request should be 429, got %d", w.Code) + } + + // IP 2 should still be allowed + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Real-IP", "10.0.0.20") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("different IP should get 200, got %d", w.Code) + } +} diff --git a/web/middleware/redirect.go b/web/middleware/redirect.go index 966d897c..5c6cff1d 100644 --- a/web/middleware/redirect.go +++ b/web/middleware/redirect.go @@ -11,17 +11,20 @@ import ( // It provides backward compatibility by redirecting old '/xui' paths to new '/panel' paths, // including API endpoints. The middleware performs permanent redirects (301) for SEO purposes. func RedirectMiddleware(basePath string) gin.HandlerFunc { - return func(c *gin.Context) { - // Redirect from old '/xui' path to '/panel' - redirects := map[string]string{ - "panel/API": "panel/api", - "xui/API": "panel/api", - "xui": "panel", - } + // Use a slice to guarantee longest-prefix-first matching order. + // A map would have nondeterministic iteration, causing "/xui/API" to + // sometimes match the shorter "/xui" rule instead. + redirects := []struct{ from, to string }{ + {"panel/API", "panel/api"}, + {"xui/API", "panel/api"}, + {"xui", "panel"}, + } + return func(c *gin.Context) { path := c.Request.URL.Path - for from, to := range redirects { - from, to = basePath+from, basePath+to + for _, r := range redirects { + from := basePath + r.from + to := basePath + r.to if strings.HasPrefix(path, from) { newPath := to + path[len(from):] diff --git a/web/service/server_test.go b/web/service/server_test.go new file mode 100644 index 00000000..a04b5e89 --- /dev/null +++ b/web/service/server_test.go @@ -0,0 +1,100 @@ +package service + +import ( + "testing" +) + +func TestIsValidGeofileName_Valid(t *testing.T) { + svc := &ServerService{} + valid := []string{ + "geoip.dat", + "geosite.dat", + "geoip_IR.dat", + "geoip_RU.dat", + "geosite_IR.dat", + "custom-file_v2.dat", + } + for _, name := range valid { + if !svc.IsValidGeofileName(name) { + t.Errorf("IsValidGeofileName(%q) should return true", name) + } + } +} + +func TestIsValidGeofileName_PathTraversal(t *testing.T) { + svc := &ServerService{} + invalid := []string{ + "../geoip.dat", + "../../etc/passwd", + "subdir/geoip.dat", + "geoip.dat/../../../etc", + "..\\geoip.dat", + } + for _, name := range invalid { + if svc.IsValidGeofileName(name) { + t.Errorf("IsValidGeofileName(%q) should return false (path traversal)", name) + } + } +} + +func TestIsValidGeofileName_Empty(t *testing.T) { + svc := &ServerService{} + if svc.IsValidGeofileName("") { + t.Error("IsValidGeofileName(\"\") should return false") + } +} + +func TestIsValidGeofileName_NoDatExtension(t *testing.T) { + svc := &ServerService{} + invalid := []string{ + "geoip.txt", + "geosite", + "file.exe", + "script.sh", + } + for _, name := range invalid { + if svc.IsValidGeofileName(name) { + t.Errorf("IsValidGeofileName(%q) should return false (no .dat extension)", name) + } + } +} + +func TestIsValidGeofileName_SpecialChars(t *testing.T) { + svc := &ServerService{} + invalid := []string{ + "geoip$.dat", + "geoip!.dat", + "geoip;.dat", + "geoip .dat", + "geoip@attack.dat", + } + for _, name := range invalid { + if svc.IsValidGeofileName(name) { + t.Errorf("IsValidGeofileName(%q) should return false (special chars)", name) + } + } +} + +func TestLogEntryContains(t *testing.T) { + tests := []struct { + line string + suffixes []string + want bool + }{ + // The implementation checks strings.Contains(line, sfx+"]") + {"line with freedom]", []string{"freedom"}, true}, + {"line with blackhole]", []string{"blackhole"}, true}, + {"freedom outbound", []string{"freedom"}, false}, + {"blackhole outbound", []string{"blackhole"}, false}, + {"freedom outbound", []string{"blackhole"}, false}, + {"some log line", []string{}, false}, + {"line with freedom] and blackhole]", []string{"freedom", "blackhole"}, true}, + {"line with freedom] and blackhole]", []string{"other"}, false}, + } + for _, tt := range tests { + got := logEntryContains(tt.line, tt.suffixes) + if got != tt.want { + t.Errorf("logEntryContains(%q, %v) = %v, want %v", tt.line, tt.suffixes, got, tt.want) + } + } +} diff --git a/web/service/user_test.go b/web/service/user_test.go new file mode 100644 index 00000000..26a83c83 --- /dev/null +++ b/web/service/user_test.go @@ -0,0 +1,147 @@ +package service + +import ( + "os" + "path/filepath" + "testing" + + "github.com/mhsanaei/3x-ui/v2/database" + "github.com/mhsanaei/3x-ui/v2/database/model" + "github.com/mhsanaei/3x-ui/v2/util/crypto" +) + +func setupTestDB(t *testing.T) { + t.Helper() + tmpDir := t.TempDir() + t.Setenv("XUI_DEBUG", "") + t.Setenv("XUI_DB_FOLDER", tmpDir) + dbPath := filepath.Join(tmpDir, "test.db") + if err := database.InitDB(dbPath); err != nil { + t.Fatalf("InitDB failed: %v", err) + } + t.Cleanup(func() { + database.CloseDB() + }) +} + +func TestGetFirstUser(t *testing.T) { + setupTestDB(t) + + svc := &UserService{} + user, err := svc.GetFirstUser() + if err != nil { + t.Fatalf("GetFirstUser error: %v", err) + } + if user.Username != "admin" { + t.Errorf("expected username 'admin', got %q", user.Username) + } +} + +func TestCheckUser_ValidCredentials(t *testing.T) { + setupTestDB(t) + + svc := &UserService{} + user, err := svc.CheckUser("admin", "admin", "") + if err != nil { + t.Fatalf("CheckUser error: %v", err) + } + if user.Username != "admin" { + t.Errorf("expected username 'admin', got %q", user.Username) + } +} + +func TestCheckUser_WrongPassword(t *testing.T) { + setupTestDB(t) + + svc := &UserService{} + _, err := svc.CheckUser("admin", "wrongpassword", "") + if err == nil { + t.Error("CheckUser should fail with wrong password") + } +} + +func TestCheckUser_NonExistentUser(t *testing.T) { + setupTestDB(t) + + svc := &UserService{} + _, err := svc.CheckUser("nonexistent", "password", "") + if err == nil { + t.Error("CheckUser should fail for non-existent user") + } +} + +func TestUpdateFirstUser(t *testing.T) { + setupTestDB(t) + + svc := &UserService{} + err := svc.UpdateFirstUser("newadmin", "newpassword123") + if err != nil { + t.Fatalf("UpdateFirstUser error: %v", err) + } + + // Verify login with new credentials + user, err := svc.CheckUser("newadmin", "newpassword123", "") + if err != nil { + t.Fatalf("CheckUser with new credentials error: %v", err) + } + if user.Username != "newadmin" { + t.Errorf("expected username 'newadmin', got %q", user.Username) + } +} + +func TestUpdateFirstUser_EmptyUsername(t *testing.T) { + setupTestDB(t) + + svc := &UserService{} + err := svc.UpdateFirstUser("", "password") + if err == nil { + t.Error("UpdateFirstUser should fail with empty username") + } +} + +func TestUpdateFirstUser_EmptyPassword(t *testing.T) { + setupTestDB(t) + + svc := &UserService{} + err := svc.UpdateFirstUser("admin", "") + if err == nil { + t.Error("UpdateFirstUser should fail with empty password") + } +} + +func TestUpdateFirstUser_CreateWhenNone(t *testing.T) { + // Use a fresh temp dir so no users table data exists + tmpDir := t.TempDir() + os.Setenv("XUI_DEBUG", "") + os.Setenv("XUI_DB_FOLDER", tmpDir) + defer func() { + os.Unsetenv("XUI_DEBUG") + os.Unsetenv("XUI_DB_FOLDER") + }() + + dbPath := filepath.Join(tmpDir, "empty.db") + if err := database.InitDB(dbPath); err != nil { + t.Fatalf("InitDB failed: %v", err) + } + defer database.CloseDB() + + // Delete all users to simulate empty table + database.GetDB().Where("1 = 1").Delete(&model.User{}) + + svc := &UserService{} + err := svc.UpdateFirstUser("firstadmin", "firstpass") + if err != nil { + t.Fatalf("UpdateFirstUser should create user when table is empty: %v", err) + } + + user, err := svc.GetFirstUser() + if err != nil { + t.Fatalf("GetFirstUser error: %v", err) + } + if user.Username != "firstadmin" { + t.Errorf("expected username 'firstadmin', got %q", user.Username) + } + if !crypto.CheckPasswordHash(user.Password, "firstpass") { + t.Error("password hash should match 'firstpass'") + } +} diff --git a/xray/api_test.go b/xray/api_test.go new file mode 100644 index 00000000..97d2b757 --- /dev/null +++ b/xray/api_test.go @@ -0,0 +1,172 @@ +package xray + +import ( + "testing" +) + +func TestProcessTraffic_Inbound(t *testing.T) { + matches := []string{ + "inbound>>>vmess-in>>>traffic>>>uplink", + "inbound", + "vmess-in", + "uplink", + } + trafficMap := make(map[string]*Traffic) + processTraffic(matches, 1024, trafficMap) + + tr, ok := trafficMap["vmess-in"] + if !ok { + t.Fatal("should have vmess-in entry") + } + if !tr.IsInbound { + t.Error("should be inbound") + } + if tr.IsOutbound { + t.Error("should not be outbound") + } + if tr.Tag != "vmess-in" { + t.Errorf("tag should be vmess-in, got %q", tr.Tag) + } + if tr.Up != 1024 { + t.Errorf("up should be 1024, got %d", tr.Up) + } + if tr.Down != 0 { + t.Errorf("down should be 0, got %d", tr.Down) + } +} + +func TestProcessTraffic_Outbound(t *testing.T) { + matches := []string{ + "outbound>>>direct>>>traffic>>>downlink", + "outbound", + "direct", + "downlink", + } + trafficMap := make(map[string]*Traffic) + processTraffic(matches, 2048, trafficMap) + + tr, ok := trafficMap["direct"] + if !ok { + t.Fatal("should have direct entry") + } + if tr.IsOutbound != true { + t.Error("should be outbound") + } + if tr.IsInbound != false { + t.Error("should not be inbound") + } + if tr.Down != 2048 { + t.Errorf("down should be 2048, got %d", tr.Down) + } +} + +func TestProcessTraffic_ApiTagSkipped(t *testing.T) { + matches := []string{ + "inbound>>>api>>>traffic>>>uplink", + "inbound", + "api", + "uplink", + } + trafficMap := make(map[string]*Traffic) + processTraffic(matches, 1024, trafficMap) + + if _, ok := trafficMap["api"]; ok { + t.Error("api tag should be skipped") + } +} + +func TestProcessTraffic_Aggregates(t *testing.T) { + trafficMap := make(map[string]*Traffic) + + // First: uplink + processTraffic([]string{"", "inbound", "test-tag", "uplink"}, 100, trafficMap) + // Second: downlink on same tag + processTraffic([]string{"", "inbound", "test-tag", "downlink"}, 200, trafficMap) + + tr := trafficMap["test-tag"] + if tr.Up != 100 { + t.Errorf("expected up=100, got %d", tr.Up) + } + if tr.Down != 200 { + t.Errorf("expected down=200, got %d", tr.Down) + } +} + +func TestProcessClientTraffic(t *testing.T) { + clientMap := make(map[string]*ClientTraffic) + + processClientTraffic([]string{"", "user@example.com", "uplink"}, 500, clientMap) + processClientTraffic([]string{"", "user@example.com", "downlink"}, 1500, clientMap) + + ct, ok := clientMap["user@example.com"] + if !ok { + t.Fatal("should have client entry") + } + if ct.Email != "user@example.com" { + t.Errorf("email should be user@example.com, got %q", ct.Email) + } + if ct.Up != 500 { + t.Errorf("up should be 500, got %d", ct.Up) + } + if ct.Down != 1500 { + t.Errorf("down should be 1500, got %d", ct.Down) + } +} + +func TestProcessClientTraffic_MultipleClients(t *testing.T) { + clientMap := make(map[string]*ClientTraffic) + + processClientTraffic([]string{"", "user1@test.com", "uplink"}, 100, clientMap) + processClientTraffic([]string{"", "user2@test.com", "uplink"}, 200, clientMap) + + if len(clientMap) != 2 { + t.Errorf("expected 2 clients, got %d", len(clientMap)) + } + if clientMap["user1@test.com"].Up != 100 { + t.Error("user1 up mismatch") + } + if clientMap["user2@test.com"].Up != 200 { + t.Error("user2 up mismatch") + } +} + +func TestMapToSlice_Empty(t *testing.T) { + m := make(map[string]*Traffic) + result := mapToSlice(m) + if len(result) != 0 { + t.Errorf("expected empty slice, got length %d", len(result)) + } +} + +func TestMapToSlice_Nil(t *testing.T) { + var m map[string]*Traffic + result := mapToSlice(m) + if len(result) != 0 { + t.Errorf("expected empty slice for nil map, got length %d", len(result)) + } +} + +func TestMapToSlice_Multiple(t *testing.T) { + m := map[string]*Traffic{ + "a": {Tag: "a", Up: 1}, + "b": {Tag: "b", Up: 2}, + "c": {Tag: "c", Up: 3}, + } + result := mapToSlice(m) + if len(result) != 3 { + t.Errorf("expected 3 elements, got %d", len(result)) + } +} + +func TestXrayAPI_Init_InvalidPort(t *testing.T) { + api := &XrayAPI{} + if err := api.Init(0); err == nil { + t.Error("Init with port 0 should return error") + } + if err := api.Init(-1); err == nil { + t.Error("Init with negative port should return error") + } + if err := api.Init(70000); err == nil { + t.Error("Init with port > 65535 should return error") + } +} diff --git a/xray/config_test.go b/xray/config_test.go new file mode 100644 index 00000000..55005749 --- /dev/null +++ b/xray/config_test.go @@ -0,0 +1,125 @@ +package xray + +import ( + "testing" +) + +func TestInboundConfig_Equals_Equal(t *testing.T) { + a := &InboundConfig{ + Listen: []byte(`"0.0.0.0"`), + Port: 443, + Protocol: "vless", + Settings: []byte(`{}`), + StreamSettings: []byte(`{"network":"tcp"}`), + Tag: "inbound-443", + Sniffing: []byte(`{"enabled":true}`), + } + b := &InboundConfig{ + Listen: []byte(`"0.0.0.0"`), + Port: 443, + Protocol: "vless", + Settings: []byte(`{}`), + StreamSettings: []byte(`{"network":"tcp"}`), + Tag: "inbound-443", + Sniffing: []byte(`{"enabled":true}`), + } + if !a.Equals(b) { + t.Error("identical InboundConfigs should be equal") + } +} + +func TestInboundConfig_Equals_DifferentPort(t *testing.T) { + a := &InboundConfig{Port: 443, Protocol: "vless"} + b := &InboundConfig{Port: 8443, Protocol: "vless"} + if a.Equals(b) { + t.Error("InboundConfigs with different ports should not be equal") + } +} + +func TestInboundConfig_Equals_DifferentProtocol(t *testing.T) { + a := &InboundConfig{Port: 443, Protocol: "vless"} + b := &InboundConfig{Port: 443, Protocol: "trojan"} + if a.Equals(b) { + t.Error("InboundConfigs with different protocols should not be equal") + } +} + +func TestInboundConfig_Equals_DifferentTag(t *testing.T) { + a := &InboundConfig{Port: 443, Protocol: "vless", Tag: "tag-a"} + b := &InboundConfig{Port: 443, Protocol: "vless", Tag: "tag-b"} + if a.Equals(b) { + t.Error("InboundConfigs with different tags should not be equal") + } +} + +func TestInboundConfig_Equals_NilRawMessages(t *testing.T) { + a := &InboundConfig{Port: 443, Protocol: "vless", Listen: nil, Settings: nil} + b := &InboundConfig{Port: 443, Protocol: "vless", Listen: nil, Settings: nil} + if !a.Equals(b) { + t.Error("InboundConfigs with nil RawMessages should be equal") + } +} + +func TestInboundConfig_Equals_DifferentListen(t *testing.T) { + a := &InboundConfig{Listen: []byte(`"0.0.0.0"`), Port: 443} + b := &InboundConfig{Listen: []byte(`"127.0.0.1"`), Port: 443} + if a.Equals(b) { + t.Error("InboundConfigs with different Listen should not be equal") + } +} + +func TestConfig_Equals_Equal(t *testing.T) { + a := &Config{ + LogConfig: []byte(`{"loglevel":"info"}`), + RouterConfig: []byte(`{}`), + InboundConfigs: []InboundConfig{ + {Port: 443, Protocol: "vless"}, + }, + } + b := &Config{ + LogConfig: []byte(`{"loglevel":"info"}`), + RouterConfig: []byte(`{}`), + InboundConfigs: []InboundConfig{ + {Port: 443, Protocol: "vless"}, + }, + } + if !a.Equals(b) { + t.Error("identical Configs should be equal") + } +} + +func TestConfig_Equals_DifferentInboundCount(t *testing.T) { + a := &Config{ + InboundConfigs: []InboundConfig{{Port: 443}}, + } + b := &Config{ + InboundConfigs: []InboundConfig{}, + } + if a.Equals(b) { + t.Error("Configs with different inbound counts should not be equal") + } +} + +func TestConfig_Equals_DifferentLogConfig(t *testing.T) { + a := &Config{LogConfig: []byte(`{"loglevel":"info"}`)} + b := &Config{LogConfig: []byte(`{"loglevel":"debug"}`)} + if a.Equals(b) { + t.Error("Configs with different LogConfig should not be equal") + } +} + +func TestConfig_Equals_EmptyConfigs(t *testing.T) { + a := &Config{} + b := &Config{} + if !a.Equals(b) { + t.Error("two empty Configs should be equal") + } +} + +func TestConfig_Equals_NilVsEmpty(t *testing.T) { + a := &Config{} + b := &Config{InboundConfigs: []InboundConfig{}} + if !a.Equals(b) { + t.Error("nil and empty slice InboundConfigs should be equal") + } +} diff --git a/xray/log_writer_test.go b/xray/log_writer_test.go new file mode 100644 index 00000000..7578c1c5 --- /dev/null +++ b/xray/log_writer_test.go @@ -0,0 +1,91 @@ +package xray + +import ( + "testing" + + "github.com/mhsanaei/3x-ui/v2/logger" + "github.com/op/go-logging" +) + +func init() { + // Initialize logger for tests that use LogWriter (which calls logger.Debug/Error/etc.) + logger.InitLogger(logging.DEBUG) +} + +func TestNewLogWriter(t *testing.T) { + lw := NewLogWriter() + if lw == nil { + t.Fatal("NewLogWriter should not return nil") + } +} + +func TestLogWriter_Write_CrashDetection(t *testing.T) { + lw := NewLogWriter() + msg := []byte("panic: runtime error: index out of range") + n, err := lw.Write(msg) + if err != nil { + t.Fatalf("Write should not return error: %v", err) + } + if n != len(msg) { + t.Errorf("Write returned %d, expected %d", n, len(msg)) + } +} + +func TestLogWriter_Write_FatalError(t *testing.T) { + lw := NewLogWriter() + msg := []byte("fatal error: concurrent map writes") + n, err := lw.Write(msg) + if err != nil { + t.Fatalf("Write should not return error: %v", err) + } + if n != len(msg) { + t.Errorf("Write returned %d, expected %d", n, len(msg)) + } +} + +func TestLogWriter_Write_Exception(t *testing.T) { + lw := NewLogWriter() + msg := []byte("unhandled exception occurred") + n, err := lw.Write(msg) + if err != nil { + t.Fatalf("Write should not return error: %v", err) + } + if n != len(msg) { + t.Errorf("Write returned %d, expected %d", n, len(msg)) + } +} + +func TestLogWriter_Write_EmptyMessage(t *testing.T) { + lw := NewLogWriter() + n, err := lw.Write([]byte("")) + if err != nil { + t.Fatalf("Write should not error: %v", err) + } + if n != 0 { + t.Errorf("Write returned %d, expected 0", n) + } +} + +func TestLogWriter_Write_TLSErrorSuppressed(t *testing.T) { + lw := NewLogWriter() + msg := []byte("some tls handshake error occurred") + n, err := lw.Write(msg) + if err != nil { + t.Fatalf("Write should not return error: %v", err) + } + if n != len(msg) { + t.Errorf("Write returned %d, expected %d", n, len(msg)) + } +} + +func TestLogWriter_Write_FailedKeyword(t *testing.T) { + lw := NewLogWriter() + msg := []byte("connection failed to remote") + n, err := lw.Write(msg) + if err != nil { + t.Fatalf("Write should not return error: %v", err) + } + if n != len(msg) { + t.Errorf("Write returned %d, expected %d", n, len(msg)) + } +}