test: add unit tests for critical modules and fix flaky redirect middleware

- Add tests for config, database, model, util/common, util/crypto,
  util/random, web/middleware, web/service, and xray packages
- Fix redirect middleware using slice instead of map to guarantee
  deterministic longest-prefix-first matching order
This commit is contained in:
Sora39831 2026-04-03 08:44:51 +08:00
parent d43f8683b0
commit 09f84782b0
13 changed files with 1674 additions and 9 deletions

131
config/config_test.go Normal file
View file

@ -0,0 +1,131 @@
package config
import (
"strings"
"testing"
)
func TestGetVersion(t *testing.T) {
v := GetVersion()
if v == "" {
// version file might be empty in test, that's ok
t.Log("version is empty (expected in test environment)")
}
}
func TestGetName(t *testing.T) {
n := GetName()
if n == "" {
t.Fatal("name should not be empty")
}
if strings.TrimSpace(n) != n {
t.Error("name should be trimmed")
}
}
func TestIsDebugDefault(t *testing.T) {
t.Setenv("XUI_DEBUG", "")
if IsDebug() {
t.Error("IsDebug should return false by default")
}
}
func TestIsDebugTrue(t *testing.T) {
t.Setenv("XUI_DEBUG", "true")
if !IsDebug() {
t.Error("IsDebug should return true when XUI_DEBUG=true")
}
}
func TestIsDebugFalse(t *testing.T) {
t.Setenv("XUI_DEBUG", "false")
if IsDebug() {
t.Error("IsDebug should return false when XUI_DEBUG=false")
}
}
func TestGetLogLevelDefault(t *testing.T) {
t.Setenv("XUI_DEBUG", "")
t.Setenv("XUI_LOG_LEVEL", "")
if GetLogLevel() != Info {
t.Errorf("default log level should be Info, got %s", GetLogLevel())
}
}
func TestGetLogLevelDebug(t *testing.T) {
t.Setenv("XUI_DEBUG", "true")
if GetLogLevel() != Debug {
t.Errorf("log level should be Debug when XUI_DEBUG=true, got %s", GetLogLevel())
}
}
func TestGetLogLevelCustom(t *testing.T) {
t.Setenv("XUI_DEBUG", "")
t.Setenv("XUI_LOG_LEVEL", "warning")
if GetLogLevel() != Warning {
t.Errorf("log level should be Warning, got %s", GetLogLevel())
}
}
func TestGetBinFolderPathDefault(t *testing.T) {
t.Setenv("XUI_BIN_FOLDER", "")
if GetBinFolderPath() != "bin" {
t.Errorf("default bin folder should be 'bin', got %s", GetBinFolderPath())
}
}
func TestGetBinFolderPathCustom(t *testing.T) {
t.Setenv("XUI_BIN_FOLDER", "/custom/bin")
if GetBinFolderPath() != "/custom/bin" {
t.Errorf("bin folder should be '/custom/bin', got %s", GetBinFolderPath())
}
}
func TestGetDBFolderPathDefault(t *testing.T) {
t.Setenv("XUI_DB_FOLDER", "")
folder := GetDBFolderPath()
// On Linux without env var, should be "/etc/x-ui"
if folder != "/etc/x-ui" {
t.Errorf("default DB folder should be '/etc/x-ui', got %s", folder)
}
}
func TestGetDBFolderPathCustom(t *testing.T) {
t.Setenv("XUI_DB_FOLDER", "/tmp/test-db")
if GetDBFolderPath() != "/tmp/test-db" {
t.Errorf("DB folder should be '/tmp/test-db', got %s", GetDBFolderPath())
}
}
func TestGetDBPath(t *testing.T) {
t.Setenv("XUI_DB_FOLDER", "/tmp/test")
dbPath := GetDBPath()
expected := "/tmp/test/" + GetName() + ".db"
if dbPath != expected {
t.Errorf("GetDBPath() = %q, want %q", dbPath, expected)
}
}
func TestGetSettingPath(t *testing.T) {
t.Setenv("XUI_DB_FOLDER", "/tmp/test")
settingPath := GetSettingPath()
expected := "/tmp/test/" + GetName() + ".json"
if settingPath != expected {
t.Errorf("GetSettingPath() = %q, want %q", settingPath, expected)
}
}
func TestGetLogFolderDefault(t *testing.T) {
t.Setenv("XUI_LOG_FOLDER", "")
folder := GetLogFolder()
if folder != "/var/log/x-ui" {
t.Errorf("default log folder should be '/var/log/x-ui', got %s", folder)
}
}
func TestGetLogFolderCustom(t *testing.T) {
t.Setenv("XUI_LOG_FOLDER", "/custom/logs")
if GetLogFolder() != "/custom/logs" {
t.Errorf("log folder should be '/custom/logs', got %s", GetLogFolder())
}
}

214
database/db_test.go Normal file
View file

@ -0,0 +1,214 @@
package database
import (
"bytes"
"os"
"path/filepath"
"testing"
"github.com/mhsanaei/3x-ui/v2/database/model"
)
func setupTestDB(t *testing.T) {
t.Helper()
tmpDir := t.TempDir()
t.Setenv("XUI_DEBUG", "")
dbPath := filepath.Join(tmpDir, "test.db")
if err := InitDB(dbPath); err != nil {
t.Fatalf("InitDB failed: %v", err)
}
t.Cleanup(func() {
CloseDB()
})
}
func TestIsSQLiteDB_ValidFile(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "valid.db")
if err := InitDB(dbPath); err != nil {
t.Fatalf("InitDB failed: %v", err)
}
defer CloseDB()
f, err := os.Open(dbPath)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
defer f.Close()
ok, err := IsSQLiteDB(f)
if err != nil {
t.Fatalf("IsSQLiteDB error: %v", err)
}
if !ok {
t.Error("IsSQLiteDB should return true for a valid SQLite file")
}
}
func TestIsSQLiteDB_InvalidFile(t *testing.T) {
tmpDir := t.TempDir()
notADB := filepath.Join(tmpDir, "notdb.txt")
if err := os.WriteFile(notADB, []byte("this is not a database"), 0644); err != nil {
t.Fatal(err)
}
f, err := os.Open(notADB)
if err != nil {
t.Fatal(err)
}
defer f.Close()
ok, err := IsSQLiteDB(f)
if err != nil {
t.Fatalf("IsSQLiteDB error: %v", err)
}
if ok {
t.Error("IsSQLiteDB should return false for a non-SQLite file")
}
}
func TestIsSQLiteDB_EmptyFile(t *testing.T) {
tmpDir := t.TempDir()
empty := filepath.Join(tmpDir, "empty.db")
if err := os.WriteFile(empty, []byte{}, 0644); err != nil {
t.Fatal(err)
}
f, err := os.Open(empty)
if err != nil {
t.Fatal(err)
}
defer f.Close()
ok, err := IsSQLiteDB(f)
// Empty file returns EOF since there aren't enough bytes to read
if err == nil && ok {
t.Error("IsSQLiteDB should return false for an empty file")
}
}
func TestIsSQLiteDB_WrongContent(t *testing.T) {
// File with 16 bytes (matching SQLite header length) but wrong signature content
r := bytes.NewReader([]byte("SQLite for !!"))
ok, err := IsSQLiteDB(r)
if err != nil {
t.Fatalf("IsSQLiteDB error: %v", err)
}
if ok {
t.Error("IsSQLiteDB should return false for wrong signature content")
}
}
func TestInitDB_CreatesTables(t *testing.T) {
setupTestDB(t)
// Verify all tables exist by querying them
tables := []string{"users", "inbounds", "outbound_traffics", "settings", "inbound_client_ips", "client_traffics", "history_of_seeders"}
for _, table := range tables {
var count int64
if err := db.Table(table).Count(&count).Error; err != nil {
t.Errorf("table %q should exist but got error: %v", table, err)
}
}
}
func TestInitDB_CreatesDefaultUser(t *testing.T) {
setupTestDB(t)
var user model.User
if err := db.First(&user).Error; err != nil {
t.Fatalf("should have a default user: %v", err)
}
if user.Username != "admin" {
t.Errorf("default username should be 'admin', got %q", user.Username)
}
if user.Role != "admin" {
t.Errorf("default role should be 'admin', got %q", user.Role)
}
// Password should be a bcrypt hash, not plaintext
if user.Password == "admin" {
t.Error("default password should be hashed, not plaintext")
}
}
func TestInitDB_Idempotent(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XUI_DEBUG", "")
dbPath := filepath.Join(tmpDir, "idempotent.db")
// First init
if err := InitDB(dbPath); err != nil {
t.Fatalf("first InitDB failed: %v", err)
}
CloseDB()
// Second init on the same file should not fail
if err := InitDB(dbPath); err != nil {
t.Fatalf("second InitDB failed: %v", err)
}
defer CloseDB()
// Should still have exactly one default user
var count int64
db.Model(&model.User{}).Count(&count)
if count != 1 {
t.Errorf("expected 1 user after second init, got %d", count)
}
}
func TestValidateSQLiteDB_ValidDB(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "valid.db")
if err := InitDB(dbPath); err != nil {
t.Fatalf("InitDB failed: %v", err)
}
CloseDB()
if err := ValidateSQLiteDB(dbPath); err != nil {
t.Errorf("ValidateSQLiteDB should pass for valid DB: %v", err)
}
}
func TestValidateSQLiteDB_NonExistent(t *testing.T) {
err := ValidateSQLiteDB("/tmp/does-not-exist-12345.db")
if err == nil {
t.Error("ValidateSQLiteDB should fail for non-existent file")
}
}
func TestValidateSQLiteDB_CorruptDB(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "corrupt.db")
// Write garbage that looks like SQLite header but is corrupt
garbage := make([]byte, 4096)
copy(garbage, []byte("SQLite format 3\x00"))
if err := os.WriteFile(dbPath, garbage, 0644); err != nil {
t.Fatal(err)
}
err := ValidateSQLiteDB(dbPath)
if err == nil {
t.Error("ValidateSQLiteDB should fail for corrupt DB")
}
}
func TestIsNotFound(t *testing.T) {
if IsNotFound(nil) {
t.Error("IsNotFound should return false for nil")
}
}
func TestInitUser_OnlyOnce(t *testing.T) {
setupTestDB(t)
// initUser should not create a second user when table is not empty
if err := initUser(); err != nil {
t.Fatalf("initUser error: %v", err)
}
var count int64
db.Model(&model.User{}).Count(&count)
if count != 1 {
t.Errorf("expected 1 user, got %d", count)
}
}

View file

@ -0,0 +1,61 @@
package model
import (
"testing"
)
func TestGenXrayInboundConfig_EmptyListen(t *testing.T) {
in := &Inbound{
Listen: "",
Port: 443,
Protocol: VLESS,
Settings: `{"clients":[]}`,
Tag: "test-inbound",
Sniffing: `{"enabled":true}`,
}
cfg := in.GenXrayInboundConfig()
if cfg == nil {
t.Fatal("GenXrayInboundConfig should not return nil")
}
// Empty listen should default to 0.0.0.0
expected := `"0.0.0.0"`
if string(cfg.Listen) != expected {
t.Errorf("Listen should default to %s, got %s", expected, string(cfg.Listen))
}
if cfg.Port != 443 {
t.Errorf("Port should be 443, got %d", cfg.Port)
}
if cfg.Protocol != "vless" {
t.Errorf("Protocol should be vless, got %q", cfg.Protocol)
}
if cfg.Tag != "test-inbound" {
t.Errorf("Tag should be test-inbound, got %q", cfg.Tag)
}
}
func TestGenXrayInboundConfig_CustomListen(t *testing.T) {
in := &Inbound{
Listen: "127.0.0.1",
Port: 8080,
Protocol: VMESS,
Tag: "custom",
}
cfg := in.GenXrayInboundConfig()
expected := `"127.0.0.1"`
if string(cfg.Listen) != expected {
t.Errorf("Listen should be %s, got %s", expected, string(cfg.Listen))
}
}
func TestGenXrayInboundConfig_EmptySettings(t *testing.T) {
in := &Inbound{
Port: 443,
Protocol: Trojan,
}
cfg := in.GenXrayInboundConfig()
if cfg == nil {
t.Fatal("GenXrayInboundConfig should not return nil")
}
// Empty string Settings produces a nil RawMessage since json_util.RawMessage("") may be nil
// Just verify no panic occurred
}

162
util/common/common_test.go Normal file
View file

@ -0,0 +1,162 @@
package common
import (
"errors"
"strings"
"testing"
)
func TestNewErrorf(t *testing.T) {
err := NewErrorf("invalid port: %d", 8080)
if err == nil {
t.Fatal("NewErrorf should return non-nil error")
}
expected := "invalid port: 8080"
if err.Error() != expected {
t.Errorf("NewErrorf returned %q, expected %q", err.Error(), expected)
}
}
func TestNewError(t *testing.T) {
err := NewError("something", " went wrong")
if err == nil {
t.Fatal("NewError should return non-nil error")
}
if !strings.Contains(err.Error(), "something") {
t.Errorf("NewError should contain 'something', got %q", err.Error())
}
}
func TestRecoverWithoutPanic(t *testing.T) {
recovered := Recover("")
if recovered != nil {
t.Errorf("Recover should return nil when no panic occurred, got %v", recovered)
}
}
func TestFormatTrafficBytes(t *testing.T) {
tests := []struct {
input int64
expected string
}{
{0, "0.00B"},
{1, "1.00B"},
{512, "512.00B"},
{1023, "1023.00B"},
}
for _, tt := range tests {
result := FormatTraffic(tt.input)
if result != tt.expected {
t.Errorf("FormatTraffic(%d) = %q, want %q", tt.input, result, tt.expected)
}
}
}
func TestFormatTrafficKB(t *testing.T) {
tests := []struct {
input int64
expected string
}{
{1024, "1.00KB"},
{1536, "1.50KB"},
{2048, "2.00KB"},
}
for _, tt := range tests {
result := FormatTraffic(tt.input)
if result != tt.expected {
t.Errorf("FormatTraffic(%d) = %q, want %q", tt.input, result, tt.expected)
}
}
}
func TestFormatTrafficMB(t *testing.T) {
result := FormatTraffic(1048576) // 1 MB
expected := "1.00MB"
if result != expected {
t.Errorf("FormatTraffic(1048576) = %q, want %q", result, expected)
}
}
func TestFormatTrafficGB(t *testing.T) {
result := FormatTraffic(1073741824) // 1 GB
expected := "1.00GB"
if result != expected {
t.Errorf("FormatTraffic(1073741824) = %q, want %q", result, expected)
}
}
func TestFormatTrafficTB(t *testing.T) {
result := FormatTraffic(1099511627776) // 1 TB
expected := "1.00TB"
if result != expected {
t.Errorf("FormatTraffic(1099511627776) = %q, want %q", result, expected)
}
}
func TestFormatTrafficPB(t *testing.T) {
result := FormatTraffic(1125899906842624) // 1 PB
expected := "1.00PB"
if result != expected {
t.Errorf("FormatTraffic(1125899906842624) = %q, want %q", result, expected)
}
}
func TestFormatTrafficLargePB(t *testing.T) {
// Value exceeding PB should stay in PB
result := FormatTraffic(11258999068426240) // 10 PB
if !strings.HasSuffix(result, "PB") {
t.Errorf("FormatTraffic should cap at PB, got %q", result)
}
}
func TestCombineAllNil(t *testing.T) {
err := Combine(nil, nil, nil)
if err != nil {
t.Errorf("Combine(nil, nil, nil) should return nil, got %v", err)
}
}
func TestCombineNoArgs(t *testing.T) {
err := Combine()
if err != nil {
t.Errorf("Combine() should return nil, got %v", err)
}
}
func TestCombineSingleError(t *testing.T) {
input := errors.New("test error")
err := Combine(input)
if err == nil {
t.Fatal("Combine should return non-nil when an error is present")
}
if !strings.Contains(err.Error(), "test error") {
t.Errorf("Combine should contain 'test error', got %q", err.Error())
}
}
func TestCombineMultipleErrors(t *testing.T) {
err1 := errors.New("error 1")
err2 := errors.New("error 2")
combined := Combine(err1, nil, err2)
if combined == nil {
t.Fatal("Combine should return non-nil when errors are present")
}
s := combined.Error()
if !strings.Contains(s, "error 1") {
t.Errorf("Combined error should contain 'error 1', got %q", s)
}
if !strings.Contains(s, "error 2") {
t.Errorf("Combined error should contain 'error 2', got %q", s)
}
}
func TestCombineFiltersNils(t *testing.T) {
err1 := errors.New("real error")
combined := Combine(nil, err1, nil)
if combined == nil {
t.Fatal("Combine should return non-nil when at least one error is present")
}
if !strings.Contains(combined.Error(), "real error") {
t.Errorf("Combined error should contain 'real error', got %q", combined.Error())
}
}

View file

@ -0,0 +1,72 @@
package crypto
import (
"strings"
"testing"
)
func TestHashPasswordAsBcrypt(t *testing.T) {
hash, err := HashPasswordAsBcrypt("password123")
if err != nil {
t.Fatalf("HashPasswordAsBcrypt failed: %v", err)
}
if hash == "" {
t.Fatal("hash should not be empty")
}
if hash == "password123" {
t.Fatal("hash should not equal the plaintext password")
}
if !strings.HasPrefix(hash, "$2a$") && !strings.HasPrefix(hash, "$2b$") {
t.Fatalf("hash should have bcrypt prefix, got: %s", hash[:4])
}
}
func TestHashAndCheckRoundTrip(t *testing.T) {
passwords := []string{
"password123",
"",
"very-long-password-with-special-chars-!@#$%^&*()",
"unicode-密码-test",
}
for _, pw := range passwords {
hash, err := HashPasswordAsBcrypt(pw)
if err != nil {
t.Fatalf("HashPasswordAsBcrypt(%q) failed: %v", pw, err)
}
if !CheckPasswordHash(hash, pw) {
t.Errorf("CheckPasswordHash should return true for correct password %q", pw)
}
}
}
func TestCheckPasswordHashWrongPassword(t *testing.T) {
hash, err := HashPasswordAsBcrypt("correct-password")
if err != nil {
t.Fatalf("HashPasswordAsBcrypt failed: %v", err)
}
if CheckPasswordHash(hash, "wrong-password") {
t.Error("CheckPasswordHash should return false for wrong password")
}
}
func TestCheckPasswordHashInvalidHash(t *testing.T) {
if CheckPasswordHash("not-a-valid-hash", "password") {
t.Error("CheckPasswordHash should return false for invalid hash")
}
}
func TestDifferentPasswordsProduceDifferentHashes(t *testing.T) {
hash1, _ := HashPasswordAsBcrypt("password1")
hash2, _ := HashPasswordAsBcrypt("password2")
if hash1 == hash2 {
t.Error("different passwords should produce different hashes")
}
}
func TestSamePasswordProducesDifferentHashes(t *testing.T) {
hash1, _ := HashPasswordAsBcrypt("same-password")
hash2, _ := HashPasswordAsBcrypt("same-password")
if hash1 == hash2 {
t.Error("bcrypt should use different salts, producing different hashes for same password")
}
}

View file

@ -0,0 +1,67 @@
package random
import (
"testing"
)
func TestSeqLength(t *testing.T) {
lengths := []int{0, 1, 10, 32, 64, 128}
for _, n := range lengths {
s := Seq(n)
if len([]rune(s)) != n {
t.Errorf("Seq(%d) returned string of length %d", n, len([]rune(s)))
}
}
}
func TestSeqCharset(t *testing.T) {
allChars := make(map[rune]bool)
for _, r := range allSeq {
allChars[r] = true
}
s := Seq(1000)
for i, r := range s {
if !allChars[r] {
t.Errorf("Seq produced invalid character %q at index %d", r, i)
}
}
}
func TestSeqUniqueness(t *testing.T) {
// Generate several strings and check they're not all identical
seen := make(map[string]bool)
for range 10 {
seen[Seq(32)] = true
}
if len(seen) < 2 {
t.Error("Seq(32) produced identical strings across 10 calls")
}
}
func TestSeqEmpty(t *testing.T) {
s := Seq(0)
if s != "" {
t.Errorf("Seq(0) should return empty string, got %q", s)
}
}
func TestNumRange(t *testing.T) {
for _, n := range []int{1, 5, 10, 100, 1000} {
for range 100 {
r := Num(n)
if r < 0 || r >= n {
t.Errorf("Num(%d) returned %d, expected [0, %d)", n, r, n)
}
}
}
}
func TestNumOne(t *testing.T) {
for range 50 {
r := Num(1)
if r != 0 {
t.Errorf("Num(1) should always return 0, got %d", r)
}
}
}

View file

@ -0,0 +1,320 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
)
func init() {
gin.SetMode(gin.TestMode)
}
func TestRedirectMiddleware_XUIToPanel(t *testing.T) {
r := gin.New()
r.Use(RedirectMiddleware("/"))
r.GET("/panel/*path", func(c *gin.Context) {
c.String(http.StatusOK, "panel")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/xui/settings", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusMovedPermanently {
t.Errorf("expected 301, got %d", w.Code)
}
loc := w.Header().Get("Location")
if loc != "/panel/settings" {
t.Errorf("expected redirect to /panel/settings, got %q", loc)
}
}
func TestRedirectMiddleware_XUIAPIToPanelAPI(t *testing.T) {
r := gin.New()
r.Use(RedirectMiddleware("/"))
r.GET("/panel/api/*path", func(c *gin.Context) {
c.String(http.StatusOK, "api")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/xui/API/inbounds", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusMovedPermanently {
t.Errorf("expected 301, got %d", w.Code)
}
loc := w.Header().Get("Location")
if loc != "/panel/api/inbounds" {
t.Errorf("expected redirect to /panel/api/inbounds, got %q", loc)
}
}
func TestRedirectMiddleware_PanelAPICase(t *testing.T) {
r := gin.New()
r.Use(RedirectMiddleware("/"))
r.GET("/panel/api/*path", func(c *gin.Context) {
c.String(http.StatusOK, "api")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/panel/API/list", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusMovedPermanently {
t.Errorf("expected 301, got %d", w.Code)
}
loc := w.Header().Get("Location")
if loc != "/panel/api/list" {
t.Errorf("expected redirect to /panel/api/list, got %q", loc)
}
}
func TestRedirectMiddleware_NoRedirect(t *testing.T) {
r := gin.New()
r.Use(RedirectMiddleware("/"))
r.GET("/panel/settings", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/panel/settings", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
}
func TestRedirectMiddleware_WithBasePath(t *testing.T) {
r := gin.New()
r.Use(RedirectMiddleware("/base/"))
r.GET("/base/panel/*path", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/base/xui/settings", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusMovedPermanently {
t.Errorf("expected 301, got %d", w.Code)
}
loc := w.Header().Get("Location")
if loc != "/base/panel/settings" {
t.Errorf("expected redirect to /base/panel/settings, got %q", loc)
}
}
func TestDomainValidatorMiddleware_MatchingDomain(t *testing.T) {
r := gin.New()
r.Use(DomainValidatorMiddleware("example.com"))
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Host = "example.com"
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
}
func TestDomainValidatorMiddleware_MatchingDomainWithPort(t *testing.T) {
r := gin.New()
r.Use(DomainValidatorMiddleware("example.com"))
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Host = "example.com:8443"
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200 for matching domain with port, got %d", w.Code)
}
}
func TestDomainValidatorMiddleware_NonMatchingDomain(t *testing.T) {
r := gin.New()
r.Use(DomainValidatorMiddleware("example.com"))
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Host = "evil.com"
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("expected 403, got %d", w.Code)
}
}
func TestDomainValidatorMiddleware_Subdomain(t *testing.T) {
r := gin.New()
r.Use(DomainValidatorMiddleware("example.com"))
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Host = "sub.example.com"
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("expected 403 for subdomain, got %d", w.Code)
}
}
func TestRateLimitMiddleware_FirstRequest(t *testing.T) {
r := gin.New()
r.Use(RateLimitMiddleware(5, time.Minute))
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200 for first request, got %d", w.Code)
}
}
func TestRateLimitMiddleware_WithinLimit(t *testing.T) {
r := gin.New()
r.Use(RateLimitMiddleware(3, time.Minute))
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
for i := range 3 {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.2:12345"
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("request %d: expected 200, got %d", i+1, w.Code)
}
}
}
func TestRateLimitMiddleware_ExceedsLimit(t *testing.T) {
r := gin.New()
r.Use(RateLimitMiddleware(2, time.Minute))
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
// First 2 should pass
for range 2 {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.3:12345"
r.ServeHTTP(w, req)
}
// 3rd should be rate limited
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.3:12345"
r.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("expected 429, got %d", w.Code)
}
}
func TestRateLimitMiddleware_XRealIP(t *testing.T) {
r := gin.New()
r.Use(RateLimitMiddleware(2, time.Minute))
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
for range 2 {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "10.0.0.1")
r.ServeHTTP(w, req)
}
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "10.0.0.1")
r.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("expected 429 with X-Real-IP, got %d", w.Code)
}
}
func TestRateLimitMiddleware_XForwardedFor(t *testing.T) {
r := gin.New()
r.Use(RateLimitMiddleware(2, time.Minute))
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
for range 2 {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("X-Forwarded-For", "10.0.0.2, 10.0.0.3")
r.ServeHTTP(w, req)
}
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("X-Forwarded-For", "10.0.0.2, 10.0.0.3")
r.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("expected 429 with X-Forwarded-For, got %d", w.Code)
}
}
func TestRateLimitMiddleware_DifferentIPsIndependent(t *testing.T) {
r := gin.New()
r.Use(RateLimitMiddleware(1, time.Minute))
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
// Exhaust limit for IP 1
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "10.0.0.10")
r.ServeHTTP(w, req)
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "10.0.0.10")
r.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("IP 1 second request should be 429, got %d", w.Code)
}
// IP 2 should still be allowed
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "10.0.0.20")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("different IP should get 200, got %d", w.Code)
}
}

View file

@ -11,17 +11,20 @@ import (
// It provides backward compatibility by redirecting old '/xui' paths to new '/panel' paths, // It provides backward compatibility by redirecting old '/xui' paths to new '/panel' paths,
// including API endpoints. The middleware performs permanent redirects (301) for SEO purposes. // including API endpoints. The middleware performs permanent redirects (301) for SEO purposes.
func RedirectMiddleware(basePath string) gin.HandlerFunc { func RedirectMiddleware(basePath string) gin.HandlerFunc {
return func(c *gin.Context) { // Use a slice to guarantee longest-prefix-first matching order.
// Redirect from old '/xui' path to '/panel' // A map would have nondeterministic iteration, causing "/xui/API" to
redirects := map[string]string{ // sometimes match the shorter "/xui" rule instead.
"panel/API": "panel/api", redirects := []struct{ from, to string }{
"xui/API": "panel/api", {"panel/API", "panel/api"},
"xui": "panel", {"xui/API", "panel/api"},
} {"xui", "panel"},
}
return func(c *gin.Context) {
path := c.Request.URL.Path path := c.Request.URL.Path
for from, to := range redirects { for _, r := range redirects {
from, to = basePath+from, basePath+to from := basePath + r.from
to := basePath + r.to
if strings.HasPrefix(path, from) { if strings.HasPrefix(path, from) {
newPath := to + path[len(from):] newPath := to + path[len(from):]

100
web/service/server_test.go Normal file
View file

@ -0,0 +1,100 @@
package service
import (
"testing"
)
func TestIsValidGeofileName_Valid(t *testing.T) {
svc := &ServerService{}
valid := []string{
"geoip.dat",
"geosite.dat",
"geoip_IR.dat",
"geoip_RU.dat",
"geosite_IR.dat",
"custom-file_v2.dat",
}
for _, name := range valid {
if !svc.IsValidGeofileName(name) {
t.Errorf("IsValidGeofileName(%q) should return true", name)
}
}
}
func TestIsValidGeofileName_PathTraversal(t *testing.T) {
svc := &ServerService{}
invalid := []string{
"../geoip.dat",
"../../etc/passwd",
"subdir/geoip.dat",
"geoip.dat/../../../etc",
"..\\geoip.dat",
}
for _, name := range invalid {
if svc.IsValidGeofileName(name) {
t.Errorf("IsValidGeofileName(%q) should return false (path traversal)", name)
}
}
}
func TestIsValidGeofileName_Empty(t *testing.T) {
svc := &ServerService{}
if svc.IsValidGeofileName("") {
t.Error("IsValidGeofileName(\"\") should return false")
}
}
func TestIsValidGeofileName_NoDatExtension(t *testing.T) {
svc := &ServerService{}
invalid := []string{
"geoip.txt",
"geosite",
"file.exe",
"script.sh",
}
for _, name := range invalid {
if svc.IsValidGeofileName(name) {
t.Errorf("IsValidGeofileName(%q) should return false (no .dat extension)", name)
}
}
}
func TestIsValidGeofileName_SpecialChars(t *testing.T) {
svc := &ServerService{}
invalid := []string{
"geoip$.dat",
"geoip!.dat",
"geoip;.dat",
"geoip .dat",
"geoip@attack.dat",
}
for _, name := range invalid {
if svc.IsValidGeofileName(name) {
t.Errorf("IsValidGeofileName(%q) should return false (special chars)", name)
}
}
}
func TestLogEntryContains(t *testing.T) {
tests := []struct {
line string
suffixes []string
want bool
}{
// The implementation checks strings.Contains(line, sfx+"]")
{"line with freedom]", []string{"freedom"}, true},
{"line with blackhole]", []string{"blackhole"}, true},
{"freedom outbound", []string{"freedom"}, false},
{"blackhole outbound", []string{"blackhole"}, false},
{"freedom outbound", []string{"blackhole"}, false},
{"some log line", []string{}, false},
{"line with freedom] and blackhole]", []string{"freedom", "blackhole"}, true},
{"line with freedom] and blackhole]", []string{"other"}, false},
}
for _, tt := range tests {
got := logEntryContains(tt.line, tt.suffixes)
if got != tt.want {
t.Errorf("logEntryContains(%q, %v) = %v, want %v", tt.line, tt.suffixes, got, tt.want)
}
}
}

147
web/service/user_test.go Normal file
View file

@ -0,0 +1,147 @@
package service
import (
"os"
"path/filepath"
"testing"
"github.com/mhsanaei/3x-ui/v2/database"
"github.com/mhsanaei/3x-ui/v2/database/model"
"github.com/mhsanaei/3x-ui/v2/util/crypto"
)
func setupTestDB(t *testing.T) {
t.Helper()
tmpDir := t.TempDir()
t.Setenv("XUI_DEBUG", "")
t.Setenv("XUI_DB_FOLDER", tmpDir)
dbPath := filepath.Join(tmpDir, "test.db")
if err := database.InitDB(dbPath); err != nil {
t.Fatalf("InitDB failed: %v", err)
}
t.Cleanup(func() {
database.CloseDB()
})
}
func TestGetFirstUser(t *testing.T) {
setupTestDB(t)
svc := &UserService{}
user, err := svc.GetFirstUser()
if err != nil {
t.Fatalf("GetFirstUser error: %v", err)
}
if user.Username != "admin" {
t.Errorf("expected username 'admin', got %q", user.Username)
}
}
func TestCheckUser_ValidCredentials(t *testing.T) {
setupTestDB(t)
svc := &UserService{}
user, err := svc.CheckUser("admin", "admin", "")
if err != nil {
t.Fatalf("CheckUser error: %v", err)
}
if user.Username != "admin" {
t.Errorf("expected username 'admin', got %q", user.Username)
}
}
func TestCheckUser_WrongPassword(t *testing.T) {
setupTestDB(t)
svc := &UserService{}
_, err := svc.CheckUser("admin", "wrongpassword", "")
if err == nil {
t.Error("CheckUser should fail with wrong password")
}
}
func TestCheckUser_NonExistentUser(t *testing.T) {
setupTestDB(t)
svc := &UserService{}
_, err := svc.CheckUser("nonexistent", "password", "")
if err == nil {
t.Error("CheckUser should fail for non-existent user")
}
}
func TestUpdateFirstUser(t *testing.T) {
setupTestDB(t)
svc := &UserService{}
err := svc.UpdateFirstUser("newadmin", "newpassword123")
if err != nil {
t.Fatalf("UpdateFirstUser error: %v", err)
}
// Verify login with new credentials
user, err := svc.CheckUser("newadmin", "newpassword123", "")
if err != nil {
t.Fatalf("CheckUser with new credentials error: %v", err)
}
if user.Username != "newadmin" {
t.Errorf("expected username 'newadmin', got %q", user.Username)
}
}
func TestUpdateFirstUser_EmptyUsername(t *testing.T) {
setupTestDB(t)
svc := &UserService{}
err := svc.UpdateFirstUser("", "password")
if err == nil {
t.Error("UpdateFirstUser should fail with empty username")
}
}
func TestUpdateFirstUser_EmptyPassword(t *testing.T) {
setupTestDB(t)
svc := &UserService{}
err := svc.UpdateFirstUser("admin", "")
if err == nil {
t.Error("UpdateFirstUser should fail with empty password")
}
}
func TestUpdateFirstUser_CreateWhenNone(t *testing.T) {
// Use a fresh temp dir so no users table data exists
tmpDir := t.TempDir()
os.Setenv("XUI_DEBUG", "")
os.Setenv("XUI_DB_FOLDER", tmpDir)
defer func() {
os.Unsetenv("XUI_DEBUG")
os.Unsetenv("XUI_DB_FOLDER")
}()
dbPath := filepath.Join(tmpDir, "empty.db")
if err := database.InitDB(dbPath); err != nil {
t.Fatalf("InitDB failed: %v", err)
}
defer database.CloseDB()
// Delete all users to simulate empty table
database.GetDB().Where("1 = 1").Delete(&model.User{})
svc := &UserService{}
err := svc.UpdateFirstUser("firstadmin", "firstpass")
if err != nil {
t.Fatalf("UpdateFirstUser should create user when table is empty: %v", err)
}
user, err := svc.GetFirstUser()
if err != nil {
t.Fatalf("GetFirstUser error: %v", err)
}
if user.Username != "firstadmin" {
t.Errorf("expected username 'firstadmin', got %q", user.Username)
}
if !crypto.CheckPasswordHash(user.Password, "firstpass") {
t.Error("password hash should match 'firstpass'")
}
}

172
xray/api_test.go Normal file
View file

@ -0,0 +1,172 @@
package xray
import (
"testing"
)
func TestProcessTraffic_Inbound(t *testing.T) {
matches := []string{
"inbound>>>vmess-in>>>traffic>>>uplink",
"inbound",
"vmess-in",
"uplink",
}
trafficMap := make(map[string]*Traffic)
processTraffic(matches, 1024, trafficMap)
tr, ok := trafficMap["vmess-in"]
if !ok {
t.Fatal("should have vmess-in entry")
}
if !tr.IsInbound {
t.Error("should be inbound")
}
if tr.IsOutbound {
t.Error("should not be outbound")
}
if tr.Tag != "vmess-in" {
t.Errorf("tag should be vmess-in, got %q", tr.Tag)
}
if tr.Up != 1024 {
t.Errorf("up should be 1024, got %d", tr.Up)
}
if tr.Down != 0 {
t.Errorf("down should be 0, got %d", tr.Down)
}
}
func TestProcessTraffic_Outbound(t *testing.T) {
matches := []string{
"outbound>>>direct>>>traffic>>>downlink",
"outbound",
"direct",
"downlink",
}
trafficMap := make(map[string]*Traffic)
processTraffic(matches, 2048, trafficMap)
tr, ok := trafficMap["direct"]
if !ok {
t.Fatal("should have direct entry")
}
if tr.IsOutbound != true {
t.Error("should be outbound")
}
if tr.IsInbound != false {
t.Error("should not be inbound")
}
if tr.Down != 2048 {
t.Errorf("down should be 2048, got %d", tr.Down)
}
}
func TestProcessTraffic_ApiTagSkipped(t *testing.T) {
matches := []string{
"inbound>>>api>>>traffic>>>uplink",
"inbound",
"api",
"uplink",
}
trafficMap := make(map[string]*Traffic)
processTraffic(matches, 1024, trafficMap)
if _, ok := trafficMap["api"]; ok {
t.Error("api tag should be skipped")
}
}
func TestProcessTraffic_Aggregates(t *testing.T) {
trafficMap := make(map[string]*Traffic)
// First: uplink
processTraffic([]string{"", "inbound", "test-tag", "uplink"}, 100, trafficMap)
// Second: downlink on same tag
processTraffic([]string{"", "inbound", "test-tag", "downlink"}, 200, trafficMap)
tr := trafficMap["test-tag"]
if tr.Up != 100 {
t.Errorf("expected up=100, got %d", tr.Up)
}
if tr.Down != 200 {
t.Errorf("expected down=200, got %d", tr.Down)
}
}
func TestProcessClientTraffic(t *testing.T) {
clientMap := make(map[string]*ClientTraffic)
processClientTraffic([]string{"", "user@example.com", "uplink"}, 500, clientMap)
processClientTraffic([]string{"", "user@example.com", "downlink"}, 1500, clientMap)
ct, ok := clientMap["user@example.com"]
if !ok {
t.Fatal("should have client entry")
}
if ct.Email != "user@example.com" {
t.Errorf("email should be user@example.com, got %q", ct.Email)
}
if ct.Up != 500 {
t.Errorf("up should be 500, got %d", ct.Up)
}
if ct.Down != 1500 {
t.Errorf("down should be 1500, got %d", ct.Down)
}
}
func TestProcessClientTraffic_MultipleClients(t *testing.T) {
clientMap := make(map[string]*ClientTraffic)
processClientTraffic([]string{"", "user1@test.com", "uplink"}, 100, clientMap)
processClientTraffic([]string{"", "user2@test.com", "uplink"}, 200, clientMap)
if len(clientMap) != 2 {
t.Errorf("expected 2 clients, got %d", len(clientMap))
}
if clientMap["user1@test.com"].Up != 100 {
t.Error("user1 up mismatch")
}
if clientMap["user2@test.com"].Up != 200 {
t.Error("user2 up mismatch")
}
}
func TestMapToSlice_Empty(t *testing.T) {
m := make(map[string]*Traffic)
result := mapToSlice(m)
if len(result) != 0 {
t.Errorf("expected empty slice, got length %d", len(result))
}
}
func TestMapToSlice_Nil(t *testing.T) {
var m map[string]*Traffic
result := mapToSlice(m)
if len(result) != 0 {
t.Errorf("expected empty slice for nil map, got length %d", len(result))
}
}
func TestMapToSlice_Multiple(t *testing.T) {
m := map[string]*Traffic{
"a": {Tag: "a", Up: 1},
"b": {Tag: "b", Up: 2},
"c": {Tag: "c", Up: 3},
}
result := mapToSlice(m)
if len(result) != 3 {
t.Errorf("expected 3 elements, got %d", len(result))
}
}
func TestXrayAPI_Init_InvalidPort(t *testing.T) {
api := &XrayAPI{}
if err := api.Init(0); err == nil {
t.Error("Init with port 0 should return error")
}
if err := api.Init(-1); err == nil {
t.Error("Init with negative port should return error")
}
if err := api.Init(70000); err == nil {
t.Error("Init with port > 65535 should return error")
}
}

125
xray/config_test.go Normal file
View file

@ -0,0 +1,125 @@
package xray
import (
"testing"
)
func TestInboundConfig_Equals_Equal(t *testing.T) {
a := &InboundConfig{
Listen: []byte(`"0.0.0.0"`),
Port: 443,
Protocol: "vless",
Settings: []byte(`{}`),
StreamSettings: []byte(`{"network":"tcp"}`),
Tag: "inbound-443",
Sniffing: []byte(`{"enabled":true}`),
}
b := &InboundConfig{
Listen: []byte(`"0.0.0.0"`),
Port: 443,
Protocol: "vless",
Settings: []byte(`{}`),
StreamSettings: []byte(`{"network":"tcp"}`),
Tag: "inbound-443",
Sniffing: []byte(`{"enabled":true}`),
}
if !a.Equals(b) {
t.Error("identical InboundConfigs should be equal")
}
}
func TestInboundConfig_Equals_DifferentPort(t *testing.T) {
a := &InboundConfig{Port: 443, Protocol: "vless"}
b := &InboundConfig{Port: 8443, Protocol: "vless"}
if a.Equals(b) {
t.Error("InboundConfigs with different ports should not be equal")
}
}
func TestInboundConfig_Equals_DifferentProtocol(t *testing.T) {
a := &InboundConfig{Port: 443, Protocol: "vless"}
b := &InboundConfig{Port: 443, Protocol: "trojan"}
if a.Equals(b) {
t.Error("InboundConfigs with different protocols should not be equal")
}
}
func TestInboundConfig_Equals_DifferentTag(t *testing.T) {
a := &InboundConfig{Port: 443, Protocol: "vless", Tag: "tag-a"}
b := &InboundConfig{Port: 443, Protocol: "vless", Tag: "tag-b"}
if a.Equals(b) {
t.Error("InboundConfigs with different tags should not be equal")
}
}
func TestInboundConfig_Equals_NilRawMessages(t *testing.T) {
a := &InboundConfig{Port: 443, Protocol: "vless", Listen: nil, Settings: nil}
b := &InboundConfig{Port: 443, Protocol: "vless", Listen: nil, Settings: nil}
if !a.Equals(b) {
t.Error("InboundConfigs with nil RawMessages should be equal")
}
}
func TestInboundConfig_Equals_DifferentListen(t *testing.T) {
a := &InboundConfig{Listen: []byte(`"0.0.0.0"`), Port: 443}
b := &InboundConfig{Listen: []byte(`"127.0.0.1"`), Port: 443}
if a.Equals(b) {
t.Error("InboundConfigs with different Listen should not be equal")
}
}
func TestConfig_Equals_Equal(t *testing.T) {
a := &Config{
LogConfig: []byte(`{"loglevel":"info"}`),
RouterConfig: []byte(`{}`),
InboundConfigs: []InboundConfig{
{Port: 443, Protocol: "vless"},
},
}
b := &Config{
LogConfig: []byte(`{"loglevel":"info"}`),
RouterConfig: []byte(`{}`),
InboundConfigs: []InboundConfig{
{Port: 443, Protocol: "vless"},
},
}
if !a.Equals(b) {
t.Error("identical Configs should be equal")
}
}
func TestConfig_Equals_DifferentInboundCount(t *testing.T) {
a := &Config{
InboundConfigs: []InboundConfig{{Port: 443}},
}
b := &Config{
InboundConfigs: []InboundConfig{},
}
if a.Equals(b) {
t.Error("Configs with different inbound counts should not be equal")
}
}
func TestConfig_Equals_DifferentLogConfig(t *testing.T) {
a := &Config{LogConfig: []byte(`{"loglevel":"info"}`)}
b := &Config{LogConfig: []byte(`{"loglevel":"debug"}`)}
if a.Equals(b) {
t.Error("Configs with different LogConfig should not be equal")
}
}
func TestConfig_Equals_EmptyConfigs(t *testing.T) {
a := &Config{}
b := &Config{}
if !a.Equals(b) {
t.Error("two empty Configs should be equal")
}
}
func TestConfig_Equals_NilVsEmpty(t *testing.T) {
a := &Config{}
b := &Config{InboundConfigs: []InboundConfig{}}
if !a.Equals(b) {
t.Error("nil and empty slice InboundConfigs should be equal")
}
}

91
xray/log_writer_test.go Normal file
View file

@ -0,0 +1,91 @@
package xray
import (
"testing"
"github.com/mhsanaei/3x-ui/v2/logger"
"github.com/op/go-logging"
)
func init() {
// Initialize logger for tests that use LogWriter (which calls logger.Debug/Error/etc.)
logger.InitLogger(logging.DEBUG)
}
func TestNewLogWriter(t *testing.T) {
lw := NewLogWriter()
if lw == nil {
t.Fatal("NewLogWriter should not return nil")
}
}
func TestLogWriter_Write_CrashDetection(t *testing.T) {
lw := NewLogWriter()
msg := []byte("panic: runtime error: index out of range")
n, err := lw.Write(msg)
if err != nil {
t.Fatalf("Write should not return error: %v", err)
}
if n != len(msg) {
t.Errorf("Write returned %d, expected %d", n, len(msg))
}
}
func TestLogWriter_Write_FatalError(t *testing.T) {
lw := NewLogWriter()
msg := []byte("fatal error: concurrent map writes")
n, err := lw.Write(msg)
if err != nil {
t.Fatalf("Write should not return error: %v", err)
}
if n != len(msg) {
t.Errorf("Write returned %d, expected %d", n, len(msg))
}
}
func TestLogWriter_Write_Exception(t *testing.T) {
lw := NewLogWriter()
msg := []byte("unhandled exception occurred")
n, err := lw.Write(msg)
if err != nil {
t.Fatalf("Write should not return error: %v", err)
}
if n != len(msg) {
t.Errorf("Write returned %d, expected %d", n, len(msg))
}
}
func TestLogWriter_Write_EmptyMessage(t *testing.T) {
lw := NewLogWriter()
n, err := lw.Write([]byte(""))
if err != nil {
t.Fatalf("Write should not error: %v", err)
}
if n != 0 {
t.Errorf("Write returned %d, expected 0", n)
}
}
func TestLogWriter_Write_TLSErrorSuppressed(t *testing.T) {
lw := NewLogWriter()
msg := []byte("some tls handshake error occurred")
n, err := lw.Write(msg)
if err != nil {
t.Fatalf("Write should not return error: %v", err)
}
if n != len(msg) {
t.Errorf("Write returned %d, expected %d", n, len(msg))
}
}
func TestLogWriter_Write_FailedKeyword(t *testing.T) {
lw := NewLogWriter()
msg := []byte("connection failed to remote")
n, err := lw.Write(msg)
if err != nil {
t.Fatalf("Write should not return error: %v", err)
}
if n != len(msg) {
t.Errorf("Write returned %d, expected %d", n, len(msg))
}
}