mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-06 13:14:11 +00:00
test: add unit tests for critical modules and fix flaky redirect middleware
- Add tests for config, database, model, util/common, util/crypto, util/random, web/middleware, web/service, and xray packages - Fix redirect middleware using slice instead of map to guarantee deterministic longest-prefix-first matching order
This commit is contained in:
parent
d43f8683b0
commit
09f84782b0
13 changed files with 1674 additions and 9 deletions
131
config/config_test.go
Normal file
131
config/config_test.go
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetVersion(t *testing.T) {
|
||||
v := GetVersion()
|
||||
if v == "" {
|
||||
// version file might be empty in test, that's ok
|
||||
t.Log("version is empty (expected in test environment)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetName(t *testing.T) {
|
||||
n := GetName()
|
||||
if n == "" {
|
||||
t.Fatal("name should not be empty")
|
||||
}
|
||||
if strings.TrimSpace(n) != n {
|
||||
t.Error("name should be trimmed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDebugDefault(t *testing.T) {
|
||||
t.Setenv("XUI_DEBUG", "")
|
||||
if IsDebug() {
|
||||
t.Error("IsDebug should return false by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDebugTrue(t *testing.T) {
|
||||
t.Setenv("XUI_DEBUG", "true")
|
||||
if !IsDebug() {
|
||||
t.Error("IsDebug should return true when XUI_DEBUG=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDebugFalse(t *testing.T) {
|
||||
t.Setenv("XUI_DEBUG", "false")
|
||||
if IsDebug() {
|
||||
t.Error("IsDebug should return false when XUI_DEBUG=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLogLevelDefault(t *testing.T) {
|
||||
t.Setenv("XUI_DEBUG", "")
|
||||
t.Setenv("XUI_LOG_LEVEL", "")
|
||||
if GetLogLevel() != Info {
|
||||
t.Errorf("default log level should be Info, got %s", GetLogLevel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLogLevelDebug(t *testing.T) {
|
||||
t.Setenv("XUI_DEBUG", "true")
|
||||
if GetLogLevel() != Debug {
|
||||
t.Errorf("log level should be Debug when XUI_DEBUG=true, got %s", GetLogLevel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLogLevelCustom(t *testing.T) {
|
||||
t.Setenv("XUI_DEBUG", "")
|
||||
t.Setenv("XUI_LOG_LEVEL", "warning")
|
||||
if GetLogLevel() != Warning {
|
||||
t.Errorf("log level should be Warning, got %s", GetLogLevel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBinFolderPathDefault(t *testing.T) {
|
||||
t.Setenv("XUI_BIN_FOLDER", "")
|
||||
if GetBinFolderPath() != "bin" {
|
||||
t.Errorf("default bin folder should be 'bin', got %s", GetBinFolderPath())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBinFolderPathCustom(t *testing.T) {
|
||||
t.Setenv("XUI_BIN_FOLDER", "/custom/bin")
|
||||
if GetBinFolderPath() != "/custom/bin" {
|
||||
t.Errorf("bin folder should be '/custom/bin', got %s", GetBinFolderPath())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDBFolderPathDefault(t *testing.T) {
|
||||
t.Setenv("XUI_DB_FOLDER", "")
|
||||
folder := GetDBFolderPath()
|
||||
// On Linux without env var, should be "/etc/x-ui"
|
||||
if folder != "/etc/x-ui" {
|
||||
t.Errorf("default DB folder should be '/etc/x-ui', got %s", folder)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDBFolderPathCustom(t *testing.T) {
|
||||
t.Setenv("XUI_DB_FOLDER", "/tmp/test-db")
|
||||
if GetDBFolderPath() != "/tmp/test-db" {
|
||||
t.Errorf("DB folder should be '/tmp/test-db', got %s", GetDBFolderPath())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDBPath(t *testing.T) {
|
||||
t.Setenv("XUI_DB_FOLDER", "/tmp/test")
|
||||
dbPath := GetDBPath()
|
||||
expected := "/tmp/test/" + GetName() + ".db"
|
||||
if dbPath != expected {
|
||||
t.Errorf("GetDBPath() = %q, want %q", dbPath, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSettingPath(t *testing.T) {
|
||||
t.Setenv("XUI_DB_FOLDER", "/tmp/test")
|
||||
settingPath := GetSettingPath()
|
||||
expected := "/tmp/test/" + GetName() + ".json"
|
||||
if settingPath != expected {
|
||||
t.Errorf("GetSettingPath() = %q, want %q", settingPath, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLogFolderDefault(t *testing.T) {
|
||||
t.Setenv("XUI_LOG_FOLDER", "")
|
||||
folder := GetLogFolder()
|
||||
if folder != "/var/log/x-ui" {
|
||||
t.Errorf("default log folder should be '/var/log/x-ui', got %s", folder)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLogFolderCustom(t *testing.T) {
|
||||
t.Setenv("XUI_LOG_FOLDER", "/custom/logs")
|
||||
if GetLogFolder() != "/custom/logs" {
|
||||
t.Errorf("log folder should be '/custom/logs', got %s", GetLogFolder())
|
||||
}
|
||||
}
|
||||
214
database/db_test.go
Normal file
214
database/db_test.go
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v2/database/model"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("XUI_DEBUG", "")
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
if err := InitDB(dbPath); err != nil {
|
||||
t.Fatalf("InitDB failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
CloseDB()
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsSQLiteDB_ValidFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "valid.db")
|
||||
if err := InitDB(dbPath); err != nil {
|
||||
t.Fatalf("InitDB failed: %v", err)
|
||||
}
|
||||
defer CloseDB()
|
||||
|
||||
f, err := os.Open(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Open failed: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
ok, err := IsSQLiteDB(f)
|
||||
if err != nil {
|
||||
t.Fatalf("IsSQLiteDB error: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Error("IsSQLiteDB should return true for a valid SQLite file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSQLiteDB_InvalidFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
notADB := filepath.Join(tmpDir, "notdb.txt")
|
||||
if err := os.WriteFile(notADB, []byte("this is not a database"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f, err := os.Open(notADB)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
ok, err := IsSQLiteDB(f)
|
||||
if err != nil {
|
||||
t.Fatalf("IsSQLiteDB error: %v", err)
|
||||
}
|
||||
if ok {
|
||||
t.Error("IsSQLiteDB should return false for a non-SQLite file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSQLiteDB_EmptyFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
empty := filepath.Join(tmpDir, "empty.db")
|
||||
if err := os.WriteFile(empty, []byte{}, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f, err := os.Open(empty)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
ok, err := IsSQLiteDB(f)
|
||||
// Empty file returns EOF since there aren't enough bytes to read
|
||||
if err == nil && ok {
|
||||
t.Error("IsSQLiteDB should return false for an empty file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSQLiteDB_WrongContent(t *testing.T) {
|
||||
// File with 16 bytes (matching SQLite header length) but wrong signature content
|
||||
r := bytes.NewReader([]byte("SQLite for !!"))
|
||||
ok, err := IsSQLiteDB(r)
|
||||
if err != nil {
|
||||
t.Fatalf("IsSQLiteDB error: %v", err)
|
||||
}
|
||||
if ok {
|
||||
t.Error("IsSQLiteDB should return false for wrong signature content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitDB_CreatesTables(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
// Verify all tables exist by querying them
|
||||
tables := []string{"users", "inbounds", "outbound_traffics", "settings", "inbound_client_ips", "client_traffics", "history_of_seeders"}
|
||||
for _, table := range tables {
|
||||
var count int64
|
||||
if err := db.Table(table).Count(&count).Error; err != nil {
|
||||
t.Errorf("table %q should exist but got error: %v", table, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitDB_CreatesDefaultUser(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
var user model.User
|
||||
if err := db.First(&user).Error; err != nil {
|
||||
t.Fatalf("should have a default user: %v", err)
|
||||
}
|
||||
if user.Username != "admin" {
|
||||
t.Errorf("default username should be 'admin', got %q", user.Username)
|
||||
}
|
||||
if user.Role != "admin" {
|
||||
t.Errorf("default role should be 'admin', got %q", user.Role)
|
||||
}
|
||||
// Password should be a bcrypt hash, not plaintext
|
||||
if user.Password == "admin" {
|
||||
t.Error("default password should be hashed, not plaintext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitDB_Idempotent(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("XUI_DEBUG", "")
|
||||
dbPath := filepath.Join(tmpDir, "idempotent.db")
|
||||
|
||||
// First init
|
||||
if err := InitDB(dbPath); err != nil {
|
||||
t.Fatalf("first InitDB failed: %v", err)
|
||||
}
|
||||
CloseDB()
|
||||
|
||||
// Second init on the same file should not fail
|
||||
if err := InitDB(dbPath); err != nil {
|
||||
t.Fatalf("second InitDB failed: %v", err)
|
||||
}
|
||||
defer CloseDB()
|
||||
|
||||
// Should still have exactly one default user
|
||||
var count int64
|
||||
db.Model(&model.User{}).Count(&count)
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 user after second init, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLiteDB_ValidDB(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "valid.db")
|
||||
if err := InitDB(dbPath); err != nil {
|
||||
t.Fatalf("InitDB failed: %v", err)
|
||||
}
|
||||
CloseDB()
|
||||
|
||||
if err := ValidateSQLiteDB(dbPath); err != nil {
|
||||
t.Errorf("ValidateSQLiteDB should pass for valid DB: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLiteDB_NonExistent(t *testing.T) {
|
||||
err := ValidateSQLiteDB("/tmp/does-not-exist-12345.db")
|
||||
if err == nil {
|
||||
t.Error("ValidateSQLiteDB should fail for non-existent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLiteDB_CorruptDB(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "corrupt.db")
|
||||
// Write garbage that looks like SQLite header but is corrupt
|
||||
garbage := make([]byte, 4096)
|
||||
copy(garbage, []byte("SQLite format 3\x00"))
|
||||
if err := os.WriteFile(dbPath, garbage, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := ValidateSQLiteDB(dbPath)
|
||||
if err == nil {
|
||||
t.Error("ValidateSQLiteDB should fail for corrupt DB")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNotFound(t *testing.T) {
|
||||
if IsNotFound(nil) {
|
||||
t.Error("IsNotFound should return false for nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitUser_OnlyOnce(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
// initUser should not create a second user when table is not empty
|
||||
if err := initUser(); err != nil {
|
||||
t.Fatalf("initUser error: %v", err)
|
||||
}
|
||||
|
||||
var count int64
|
||||
db.Model(&model.User{}).Count(&count)
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 user, got %d", count)
|
||||
}
|
||||
}
|
||||
61
database/model/model_test.go
Normal file
61
database/model/model_test.go
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenXrayInboundConfig_EmptyListen(t *testing.T) {
|
||||
in := &Inbound{
|
||||
Listen: "",
|
||||
Port: 443,
|
||||
Protocol: VLESS,
|
||||
Settings: `{"clients":[]}`,
|
||||
Tag: "test-inbound",
|
||||
Sniffing: `{"enabled":true}`,
|
||||
}
|
||||
cfg := in.GenXrayInboundConfig()
|
||||
if cfg == nil {
|
||||
t.Fatal("GenXrayInboundConfig should not return nil")
|
||||
}
|
||||
// Empty listen should default to 0.0.0.0
|
||||
expected := `"0.0.0.0"`
|
||||
if string(cfg.Listen) != expected {
|
||||
t.Errorf("Listen should default to %s, got %s", expected, string(cfg.Listen))
|
||||
}
|
||||
if cfg.Port != 443 {
|
||||
t.Errorf("Port should be 443, got %d", cfg.Port)
|
||||
}
|
||||
if cfg.Protocol != "vless" {
|
||||
t.Errorf("Protocol should be vless, got %q", cfg.Protocol)
|
||||
}
|
||||
if cfg.Tag != "test-inbound" {
|
||||
t.Errorf("Tag should be test-inbound, got %q", cfg.Tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenXrayInboundConfig_CustomListen(t *testing.T) {
|
||||
in := &Inbound{
|
||||
Listen: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Protocol: VMESS,
|
||||
Tag: "custom",
|
||||
}
|
||||
cfg := in.GenXrayInboundConfig()
|
||||
expected := `"127.0.0.1"`
|
||||
if string(cfg.Listen) != expected {
|
||||
t.Errorf("Listen should be %s, got %s", expected, string(cfg.Listen))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenXrayInboundConfig_EmptySettings(t *testing.T) {
|
||||
in := &Inbound{
|
||||
Port: 443,
|
||||
Protocol: Trojan,
|
||||
}
|
||||
cfg := in.GenXrayInboundConfig()
|
||||
if cfg == nil {
|
||||
t.Fatal("GenXrayInboundConfig should not return nil")
|
||||
}
|
||||
// Empty string Settings produces a nil RawMessage since json_util.RawMessage("") may be nil
|
||||
// Just verify no panic occurred
|
||||
}
|
||||
162
util/common/common_test.go
Normal file
162
util/common/common_test.go
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewErrorf(t *testing.T) {
|
||||
err := NewErrorf("invalid port: %d", 8080)
|
||||
if err == nil {
|
||||
t.Fatal("NewErrorf should return non-nil error")
|
||||
}
|
||||
expected := "invalid port: 8080"
|
||||
if err.Error() != expected {
|
||||
t.Errorf("NewErrorf returned %q, expected %q", err.Error(), expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewError(t *testing.T) {
|
||||
err := NewError("something", " went wrong")
|
||||
if err == nil {
|
||||
t.Fatal("NewError should return non-nil error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "something") {
|
||||
t.Errorf("NewError should contain 'something', got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoverWithoutPanic(t *testing.T) {
|
||||
recovered := Recover("")
|
||||
if recovered != nil {
|
||||
t.Errorf("Recover should return nil when no panic occurred, got %v", recovered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTrafficBytes(t *testing.T) {
|
||||
tests := []struct {
|
||||
input int64
|
||||
expected string
|
||||
}{
|
||||
{0, "0.00B"},
|
||||
{1, "1.00B"},
|
||||
{512, "512.00B"},
|
||||
{1023, "1023.00B"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := FormatTraffic(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("FormatTraffic(%d) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTrafficKB(t *testing.T) {
|
||||
tests := []struct {
|
||||
input int64
|
||||
expected string
|
||||
}{
|
||||
{1024, "1.00KB"},
|
||||
{1536, "1.50KB"},
|
||||
{2048, "2.00KB"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := FormatTraffic(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("FormatTraffic(%d) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTrafficMB(t *testing.T) {
|
||||
result := FormatTraffic(1048576) // 1 MB
|
||||
expected := "1.00MB"
|
||||
if result != expected {
|
||||
t.Errorf("FormatTraffic(1048576) = %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTrafficGB(t *testing.T) {
|
||||
result := FormatTraffic(1073741824) // 1 GB
|
||||
expected := "1.00GB"
|
||||
if result != expected {
|
||||
t.Errorf("FormatTraffic(1073741824) = %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTrafficTB(t *testing.T) {
|
||||
result := FormatTraffic(1099511627776) // 1 TB
|
||||
expected := "1.00TB"
|
||||
if result != expected {
|
||||
t.Errorf("FormatTraffic(1099511627776) = %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTrafficPB(t *testing.T) {
|
||||
result := FormatTraffic(1125899906842624) // 1 PB
|
||||
expected := "1.00PB"
|
||||
if result != expected {
|
||||
t.Errorf("FormatTraffic(1125899906842624) = %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTrafficLargePB(t *testing.T) {
|
||||
// Value exceeding PB should stay in PB
|
||||
result := FormatTraffic(11258999068426240) // 10 PB
|
||||
if !strings.HasSuffix(result, "PB") {
|
||||
t.Errorf("FormatTraffic should cap at PB, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineAllNil(t *testing.T) {
|
||||
err := Combine(nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Combine(nil, nil, nil) should return nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineNoArgs(t *testing.T) {
|
||||
err := Combine()
|
||||
if err != nil {
|
||||
t.Errorf("Combine() should return nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineSingleError(t *testing.T) {
|
||||
input := errors.New("test error")
|
||||
err := Combine(input)
|
||||
if err == nil {
|
||||
t.Fatal("Combine should return non-nil when an error is present")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "test error") {
|
||||
t.Errorf("Combine should contain 'test error', got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineMultipleErrors(t *testing.T) {
|
||||
err1 := errors.New("error 1")
|
||||
err2 := errors.New("error 2")
|
||||
combined := Combine(err1, nil, err2)
|
||||
if combined == nil {
|
||||
t.Fatal("Combine should return non-nil when errors are present")
|
||||
}
|
||||
s := combined.Error()
|
||||
if !strings.Contains(s, "error 1") {
|
||||
t.Errorf("Combined error should contain 'error 1', got %q", s)
|
||||
}
|
||||
if !strings.Contains(s, "error 2") {
|
||||
t.Errorf("Combined error should contain 'error 2', got %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineFiltersNils(t *testing.T) {
|
||||
err1 := errors.New("real error")
|
||||
combined := Combine(nil, err1, nil)
|
||||
if combined == nil {
|
||||
t.Fatal("Combine should return non-nil when at least one error is present")
|
||||
}
|
||||
if !strings.Contains(combined.Error(), "real error") {
|
||||
t.Errorf("Combined error should contain 'real error', got %q", combined.Error())
|
||||
}
|
||||
}
|
||||
72
util/crypto/crypto_test.go
Normal file
72
util/crypto/crypto_test.go
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHashPasswordAsBcrypt(t *testing.T) {
|
||||
hash, err := HashPasswordAsBcrypt("password123")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPasswordAsBcrypt failed: %v", err)
|
||||
}
|
||||
if hash == "" {
|
||||
t.Fatal("hash should not be empty")
|
||||
}
|
||||
if hash == "password123" {
|
||||
t.Fatal("hash should not equal the plaintext password")
|
||||
}
|
||||
if !strings.HasPrefix(hash, "$2a$") && !strings.HasPrefix(hash, "$2b$") {
|
||||
t.Fatalf("hash should have bcrypt prefix, got: %s", hash[:4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashAndCheckRoundTrip(t *testing.T) {
|
||||
passwords := []string{
|
||||
"password123",
|
||||
"",
|
||||
"very-long-password-with-special-chars-!@#$%^&*()",
|
||||
"unicode-密码-test",
|
||||
}
|
||||
for _, pw := range passwords {
|
||||
hash, err := HashPasswordAsBcrypt(pw)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPasswordAsBcrypt(%q) failed: %v", pw, err)
|
||||
}
|
||||
if !CheckPasswordHash(hash, pw) {
|
||||
t.Errorf("CheckPasswordHash should return true for correct password %q", pw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPasswordHashWrongPassword(t *testing.T) {
|
||||
hash, err := HashPasswordAsBcrypt("correct-password")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPasswordAsBcrypt failed: %v", err)
|
||||
}
|
||||
if CheckPasswordHash(hash, "wrong-password") {
|
||||
t.Error("CheckPasswordHash should return false for wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPasswordHashInvalidHash(t *testing.T) {
|
||||
if CheckPasswordHash("not-a-valid-hash", "password") {
|
||||
t.Error("CheckPasswordHash should return false for invalid hash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentPasswordsProduceDifferentHashes(t *testing.T) {
|
||||
hash1, _ := HashPasswordAsBcrypt("password1")
|
||||
hash2, _ := HashPasswordAsBcrypt("password2")
|
||||
if hash1 == hash2 {
|
||||
t.Error("different passwords should produce different hashes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSamePasswordProducesDifferentHashes(t *testing.T) {
|
||||
hash1, _ := HashPasswordAsBcrypt("same-password")
|
||||
hash2, _ := HashPasswordAsBcrypt("same-password")
|
||||
if hash1 == hash2 {
|
||||
t.Error("bcrypt should use different salts, producing different hashes for same password")
|
||||
}
|
||||
}
|
||||
67
util/random/random_test.go
Normal file
67
util/random/random_test.go
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
package random
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSeqLength(t *testing.T) {
|
||||
lengths := []int{0, 1, 10, 32, 64, 128}
|
||||
for _, n := range lengths {
|
||||
s := Seq(n)
|
||||
if len([]rune(s)) != n {
|
||||
t.Errorf("Seq(%d) returned string of length %d", n, len([]rune(s)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeqCharset(t *testing.T) {
|
||||
allChars := make(map[rune]bool)
|
||||
for _, r := range allSeq {
|
||||
allChars[r] = true
|
||||
}
|
||||
|
||||
s := Seq(1000)
|
||||
for i, r := range s {
|
||||
if !allChars[r] {
|
||||
t.Errorf("Seq produced invalid character %q at index %d", r, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeqUniqueness(t *testing.T) {
|
||||
// Generate several strings and check they're not all identical
|
||||
seen := make(map[string]bool)
|
||||
for range 10 {
|
||||
seen[Seq(32)] = true
|
||||
}
|
||||
if len(seen) < 2 {
|
||||
t.Error("Seq(32) produced identical strings across 10 calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeqEmpty(t *testing.T) {
|
||||
s := Seq(0)
|
||||
if s != "" {
|
||||
t.Errorf("Seq(0) should return empty string, got %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumRange(t *testing.T) {
|
||||
for _, n := range []int{1, 5, 10, 100, 1000} {
|
||||
for range 100 {
|
||||
r := Num(n)
|
||||
if r < 0 || r >= n {
|
||||
t.Errorf("Num(%d) returned %d, expected [0, %d)", n, r, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumOne(t *testing.T) {
|
||||
for range 50 {
|
||||
r := Num(1)
|
||||
if r != 0 {
|
||||
t.Errorf("Num(1) should always return 0, got %d", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
320
web/middleware/middleware_test.go
Normal file
320
web/middleware/middleware_test.go
Normal file
|
|
@ -0,0 +1,320 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestRedirectMiddleware_XUIToPanel(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RedirectMiddleware("/"))
|
||||
r.GET("/panel/*path", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "panel")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/xui/settings", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusMovedPermanently {
|
||||
t.Errorf("expected 301, got %d", w.Code)
|
||||
}
|
||||
loc := w.Header().Get("Location")
|
||||
if loc != "/panel/settings" {
|
||||
t.Errorf("expected redirect to /panel/settings, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectMiddleware_XUIAPIToPanelAPI(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RedirectMiddleware("/"))
|
||||
r.GET("/panel/api/*path", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "api")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/xui/API/inbounds", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusMovedPermanently {
|
||||
t.Errorf("expected 301, got %d", w.Code)
|
||||
}
|
||||
loc := w.Header().Get("Location")
|
||||
if loc != "/panel/api/inbounds" {
|
||||
t.Errorf("expected redirect to /panel/api/inbounds, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectMiddleware_PanelAPICase(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RedirectMiddleware("/"))
|
||||
r.GET("/panel/api/*path", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "api")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/panel/API/list", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusMovedPermanently {
|
||||
t.Errorf("expected 301, got %d", w.Code)
|
||||
}
|
||||
loc := w.Header().Get("Location")
|
||||
if loc != "/panel/api/list" {
|
||||
t.Errorf("expected redirect to /panel/api/list, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectMiddleware_NoRedirect(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RedirectMiddleware("/"))
|
||||
r.GET("/panel/settings", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/panel/settings", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectMiddleware_WithBasePath(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RedirectMiddleware("/base/"))
|
||||
r.GET("/base/panel/*path", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/base/xui/settings", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusMovedPermanently {
|
||||
t.Errorf("expected 301, got %d", w.Code)
|
||||
}
|
||||
loc := w.Header().Get("Location")
|
||||
if loc != "/base/panel/settings" {
|
||||
t.Errorf("expected redirect to /base/panel/settings, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainValidatorMiddleware_MatchingDomain(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(DomainValidatorMiddleware("example.com"))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Host = "example.com"
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainValidatorMiddleware_MatchingDomainWithPort(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(DomainValidatorMiddleware("example.com"))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Host = "example.com:8443"
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 for matching domain with port, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainValidatorMiddleware_NonMatchingDomain(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(DomainValidatorMiddleware("example.com"))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Host = "evil.com"
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("expected 403, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainValidatorMiddleware_Subdomain(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(DomainValidatorMiddleware("example.com"))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Host = "sub.example.com"
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("expected 403 for subdomain, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_FirstRequest(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RateLimitMiddleware(5, time.Minute))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 for first request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_WithinLimit(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RateLimitMiddleware(3, time.Minute))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
for i := range 3 {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.2:12345"
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("request %d: expected 200, got %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_ExceedsLimit(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RateLimitMiddleware(2, time.Minute))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
// First 2 should pass
|
||||
for range 2 {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.3:12345"
|
||||
r.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// 3rd should be rate limited
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.3:12345"
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected 429, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_XRealIP(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RateLimitMiddleware(2, time.Minute))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
for range 2 {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "10.0.0.1")
|
||||
r.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "10.0.0.1")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected 429 with X-Real-IP, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_XForwardedFor(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RateLimitMiddleware(2, time.Minute))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
for range 2 {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Forwarded-For", "10.0.0.2, 10.0.0.3")
|
||||
r.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Forwarded-For", "10.0.0.2, 10.0.0.3")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected 429 with X-Forwarded-For, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_DifferentIPsIndependent(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RateLimitMiddleware(1, time.Minute))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
// Exhaust limit for IP 1
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "10.0.0.10")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "10.0.0.10")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("IP 1 second request should be 429, got %d", w.Code)
|
||||
}
|
||||
|
||||
// IP 2 should still be allowed
|
||||
w = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "10.0.0.20")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("different IP should get 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
|
@ -11,17 +11,20 @@ import (
|
|||
// It provides backward compatibility by redirecting old '/xui' paths to new '/panel' paths,
|
||||
// including API endpoints. The middleware performs permanent redirects (301) for SEO purposes.
|
||||
func RedirectMiddleware(basePath string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Redirect from old '/xui' path to '/panel'
|
||||
redirects := map[string]string{
|
||||
"panel/API": "panel/api",
|
||||
"xui/API": "panel/api",
|
||||
"xui": "panel",
|
||||
}
|
||||
// Use a slice to guarantee longest-prefix-first matching order.
|
||||
// A map would have nondeterministic iteration, causing "/xui/API" to
|
||||
// sometimes match the shorter "/xui" rule instead.
|
||||
redirects := []struct{ from, to string }{
|
||||
{"panel/API", "panel/api"},
|
||||
{"xui/API", "panel/api"},
|
||||
{"xui", "panel"},
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
for from, to := range redirects {
|
||||
from, to = basePath+from, basePath+to
|
||||
for _, r := range redirects {
|
||||
from := basePath + r.from
|
||||
to := basePath + r.to
|
||||
|
||||
if strings.HasPrefix(path, from) {
|
||||
newPath := to + path[len(from):]
|
||||
|
|
|
|||
100
web/service/server_test.go
Normal file
100
web/service/server_test.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsValidGeofileName_Valid(t *testing.T) {
|
||||
svc := &ServerService{}
|
||||
valid := []string{
|
||||
"geoip.dat",
|
||||
"geosite.dat",
|
||||
"geoip_IR.dat",
|
||||
"geoip_RU.dat",
|
||||
"geosite_IR.dat",
|
||||
"custom-file_v2.dat",
|
||||
}
|
||||
for _, name := range valid {
|
||||
if !svc.IsValidGeofileName(name) {
|
||||
t.Errorf("IsValidGeofileName(%q) should return true", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidGeofileName_PathTraversal(t *testing.T) {
|
||||
svc := &ServerService{}
|
||||
invalid := []string{
|
||||
"../geoip.dat",
|
||||
"../../etc/passwd",
|
||||
"subdir/geoip.dat",
|
||||
"geoip.dat/../../../etc",
|
||||
"..\\geoip.dat",
|
||||
}
|
||||
for _, name := range invalid {
|
||||
if svc.IsValidGeofileName(name) {
|
||||
t.Errorf("IsValidGeofileName(%q) should return false (path traversal)", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidGeofileName_Empty(t *testing.T) {
|
||||
svc := &ServerService{}
|
||||
if svc.IsValidGeofileName("") {
|
||||
t.Error("IsValidGeofileName(\"\") should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidGeofileName_NoDatExtension(t *testing.T) {
|
||||
svc := &ServerService{}
|
||||
invalid := []string{
|
||||
"geoip.txt",
|
||||
"geosite",
|
||||
"file.exe",
|
||||
"script.sh",
|
||||
}
|
||||
for _, name := range invalid {
|
||||
if svc.IsValidGeofileName(name) {
|
||||
t.Errorf("IsValidGeofileName(%q) should return false (no .dat extension)", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidGeofileName_SpecialChars(t *testing.T) {
|
||||
svc := &ServerService{}
|
||||
invalid := []string{
|
||||
"geoip$.dat",
|
||||
"geoip!.dat",
|
||||
"geoip;.dat",
|
||||
"geoip .dat",
|
||||
"geoip@attack.dat",
|
||||
}
|
||||
for _, name := range invalid {
|
||||
if svc.IsValidGeofileName(name) {
|
||||
t.Errorf("IsValidGeofileName(%q) should return false (special chars)", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogEntryContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
line string
|
||||
suffixes []string
|
||||
want bool
|
||||
}{
|
||||
// The implementation checks strings.Contains(line, sfx+"]")
|
||||
{"line with freedom]", []string{"freedom"}, true},
|
||||
{"line with blackhole]", []string{"blackhole"}, true},
|
||||
{"freedom outbound", []string{"freedom"}, false},
|
||||
{"blackhole outbound", []string{"blackhole"}, false},
|
||||
{"freedom outbound", []string{"blackhole"}, false},
|
||||
{"some log line", []string{}, false},
|
||||
{"line with freedom] and blackhole]", []string{"freedom", "blackhole"}, true},
|
||||
{"line with freedom] and blackhole]", []string{"other"}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := logEntryContains(tt.line, tt.suffixes)
|
||||
if got != tt.want {
|
||||
t.Errorf("logEntryContains(%q, %v) = %v, want %v", tt.line, tt.suffixes, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
147
web/service/user_test.go
Normal file
147
web/service/user_test.go
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v2/database"
|
||||
"github.com/mhsanaei/3x-ui/v2/database/model"
|
||||
"github.com/mhsanaei/3x-ui/v2/util/crypto"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("XUI_DEBUG", "")
|
||||
t.Setenv("XUI_DB_FOLDER", tmpDir)
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
if err := database.InitDB(dbPath); err != nil {
|
||||
t.Fatalf("InitDB failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
database.CloseDB()
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFirstUser(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
svc := &UserService{}
|
||||
user, err := svc.GetFirstUser()
|
||||
if err != nil {
|
||||
t.Fatalf("GetFirstUser error: %v", err)
|
||||
}
|
||||
if user.Username != "admin" {
|
||||
t.Errorf("expected username 'admin', got %q", user.Username)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUser_ValidCredentials(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
svc := &UserService{}
|
||||
user, err := svc.CheckUser("admin", "admin", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CheckUser error: %v", err)
|
||||
}
|
||||
if user.Username != "admin" {
|
||||
t.Errorf("expected username 'admin', got %q", user.Username)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUser_WrongPassword(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
svc := &UserService{}
|
||||
_, err := svc.CheckUser("admin", "wrongpassword", "")
|
||||
if err == nil {
|
||||
t.Error("CheckUser should fail with wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUser_NonExistentUser(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
svc := &UserService{}
|
||||
_, err := svc.CheckUser("nonexistent", "password", "")
|
||||
if err == nil {
|
||||
t.Error("CheckUser should fail for non-existent user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFirstUser(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
svc := &UserService{}
|
||||
err := svc.UpdateFirstUser("newadmin", "newpassword123")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateFirstUser error: %v", err)
|
||||
}
|
||||
|
||||
// Verify login with new credentials
|
||||
user, err := svc.CheckUser("newadmin", "newpassword123", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CheckUser with new credentials error: %v", err)
|
||||
}
|
||||
if user.Username != "newadmin" {
|
||||
t.Errorf("expected username 'newadmin', got %q", user.Username)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFirstUser_EmptyUsername(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
svc := &UserService{}
|
||||
err := svc.UpdateFirstUser("", "password")
|
||||
if err == nil {
|
||||
t.Error("UpdateFirstUser should fail with empty username")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFirstUser_EmptyPassword(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
svc := &UserService{}
|
||||
err := svc.UpdateFirstUser("admin", "")
|
||||
if err == nil {
|
||||
t.Error("UpdateFirstUser should fail with empty password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFirstUser_CreateWhenNone(t *testing.T) {
|
||||
// Use a fresh temp dir so no users table data exists
|
||||
tmpDir := t.TempDir()
|
||||
os.Setenv("XUI_DEBUG", "")
|
||||
os.Setenv("XUI_DB_FOLDER", tmpDir)
|
||||
defer func() {
|
||||
os.Unsetenv("XUI_DEBUG")
|
||||
os.Unsetenv("XUI_DB_FOLDER")
|
||||
}()
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "empty.db")
|
||||
if err := database.InitDB(dbPath); err != nil {
|
||||
t.Fatalf("InitDB failed: %v", err)
|
||||
}
|
||||
defer database.CloseDB()
|
||||
|
||||
// Delete all users to simulate empty table
|
||||
database.GetDB().Where("1 = 1").Delete(&model.User{})
|
||||
|
||||
svc := &UserService{}
|
||||
err := svc.UpdateFirstUser("firstadmin", "firstpass")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateFirstUser should create user when table is empty: %v", err)
|
||||
}
|
||||
|
||||
user, err := svc.GetFirstUser()
|
||||
if err != nil {
|
||||
t.Fatalf("GetFirstUser error: %v", err)
|
||||
}
|
||||
if user.Username != "firstadmin" {
|
||||
t.Errorf("expected username 'firstadmin', got %q", user.Username)
|
||||
}
|
||||
if !crypto.CheckPasswordHash(user.Password, "firstpass") {
|
||||
t.Error("password hash should match 'firstpass'")
|
||||
}
|
||||
}
|
||||
172
xray/api_test.go
Normal file
172
xray/api_test.go
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
package xray
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProcessTraffic_Inbound(t *testing.T) {
|
||||
matches := []string{
|
||||
"inbound>>>vmess-in>>>traffic>>>uplink",
|
||||
"inbound",
|
||||
"vmess-in",
|
||||
"uplink",
|
||||
}
|
||||
trafficMap := make(map[string]*Traffic)
|
||||
processTraffic(matches, 1024, trafficMap)
|
||||
|
||||
tr, ok := trafficMap["vmess-in"]
|
||||
if !ok {
|
||||
t.Fatal("should have vmess-in entry")
|
||||
}
|
||||
if !tr.IsInbound {
|
||||
t.Error("should be inbound")
|
||||
}
|
||||
if tr.IsOutbound {
|
||||
t.Error("should not be outbound")
|
||||
}
|
||||
if tr.Tag != "vmess-in" {
|
||||
t.Errorf("tag should be vmess-in, got %q", tr.Tag)
|
||||
}
|
||||
if tr.Up != 1024 {
|
||||
t.Errorf("up should be 1024, got %d", tr.Up)
|
||||
}
|
||||
if tr.Down != 0 {
|
||||
t.Errorf("down should be 0, got %d", tr.Down)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessTraffic_Outbound(t *testing.T) {
|
||||
matches := []string{
|
||||
"outbound>>>direct>>>traffic>>>downlink",
|
||||
"outbound",
|
||||
"direct",
|
||||
"downlink",
|
||||
}
|
||||
trafficMap := make(map[string]*Traffic)
|
||||
processTraffic(matches, 2048, trafficMap)
|
||||
|
||||
tr, ok := trafficMap["direct"]
|
||||
if !ok {
|
||||
t.Fatal("should have direct entry")
|
||||
}
|
||||
if tr.IsOutbound != true {
|
||||
t.Error("should be outbound")
|
||||
}
|
||||
if tr.IsInbound != false {
|
||||
t.Error("should not be inbound")
|
||||
}
|
||||
if tr.Down != 2048 {
|
||||
t.Errorf("down should be 2048, got %d", tr.Down)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessTraffic_ApiTagSkipped(t *testing.T) {
|
||||
matches := []string{
|
||||
"inbound>>>api>>>traffic>>>uplink",
|
||||
"inbound",
|
||||
"api",
|
||||
"uplink",
|
||||
}
|
||||
trafficMap := make(map[string]*Traffic)
|
||||
processTraffic(matches, 1024, trafficMap)
|
||||
|
||||
if _, ok := trafficMap["api"]; ok {
|
||||
t.Error("api tag should be skipped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessTraffic_Aggregates(t *testing.T) {
|
||||
trafficMap := make(map[string]*Traffic)
|
||||
|
||||
// First: uplink
|
||||
processTraffic([]string{"", "inbound", "test-tag", "uplink"}, 100, trafficMap)
|
||||
// Second: downlink on same tag
|
||||
processTraffic([]string{"", "inbound", "test-tag", "downlink"}, 200, trafficMap)
|
||||
|
||||
tr := trafficMap["test-tag"]
|
||||
if tr.Up != 100 {
|
||||
t.Errorf("expected up=100, got %d", tr.Up)
|
||||
}
|
||||
if tr.Down != 200 {
|
||||
t.Errorf("expected down=200, got %d", tr.Down)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessClientTraffic(t *testing.T) {
|
||||
clientMap := make(map[string]*ClientTraffic)
|
||||
|
||||
processClientTraffic([]string{"", "user@example.com", "uplink"}, 500, clientMap)
|
||||
processClientTraffic([]string{"", "user@example.com", "downlink"}, 1500, clientMap)
|
||||
|
||||
ct, ok := clientMap["user@example.com"]
|
||||
if !ok {
|
||||
t.Fatal("should have client entry")
|
||||
}
|
||||
if ct.Email != "user@example.com" {
|
||||
t.Errorf("email should be user@example.com, got %q", ct.Email)
|
||||
}
|
||||
if ct.Up != 500 {
|
||||
t.Errorf("up should be 500, got %d", ct.Up)
|
||||
}
|
||||
if ct.Down != 1500 {
|
||||
t.Errorf("down should be 1500, got %d", ct.Down)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessClientTraffic_MultipleClients(t *testing.T) {
|
||||
clientMap := make(map[string]*ClientTraffic)
|
||||
|
||||
processClientTraffic([]string{"", "user1@test.com", "uplink"}, 100, clientMap)
|
||||
processClientTraffic([]string{"", "user2@test.com", "uplink"}, 200, clientMap)
|
||||
|
||||
if len(clientMap) != 2 {
|
||||
t.Errorf("expected 2 clients, got %d", len(clientMap))
|
||||
}
|
||||
if clientMap["user1@test.com"].Up != 100 {
|
||||
t.Error("user1 up mismatch")
|
||||
}
|
||||
if clientMap["user2@test.com"].Up != 200 {
|
||||
t.Error("user2 up mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToSlice_Empty(t *testing.T) {
|
||||
m := make(map[string]*Traffic)
|
||||
result := mapToSlice(m)
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected empty slice, got length %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToSlice_Nil(t *testing.T) {
|
||||
var m map[string]*Traffic
|
||||
result := mapToSlice(m)
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected empty slice for nil map, got length %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToSlice_Multiple(t *testing.T) {
|
||||
m := map[string]*Traffic{
|
||||
"a": {Tag: "a", Up: 1},
|
||||
"b": {Tag: "b", Up: 2},
|
||||
"c": {Tag: "c", Up: 3},
|
||||
}
|
||||
result := mapToSlice(m)
|
||||
if len(result) != 3 {
|
||||
t.Errorf("expected 3 elements, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestXrayAPI_Init_InvalidPort(t *testing.T) {
|
||||
api := &XrayAPI{}
|
||||
if err := api.Init(0); err == nil {
|
||||
t.Error("Init with port 0 should return error")
|
||||
}
|
||||
if err := api.Init(-1); err == nil {
|
||||
t.Error("Init with negative port should return error")
|
||||
}
|
||||
if err := api.Init(70000); err == nil {
|
||||
t.Error("Init with port > 65535 should return error")
|
||||
}
|
||||
}
|
||||
125
xray/config_test.go
Normal file
125
xray/config_test.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
package xray
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInboundConfig_Equals_Equal(t *testing.T) {
|
||||
a := &InboundConfig{
|
||||
Listen: []byte(`"0.0.0.0"`),
|
||||
Port: 443,
|
||||
Protocol: "vless",
|
||||
Settings: []byte(`{}`),
|
||||
StreamSettings: []byte(`{"network":"tcp"}`),
|
||||
Tag: "inbound-443",
|
||||
Sniffing: []byte(`{"enabled":true}`),
|
||||
}
|
||||
b := &InboundConfig{
|
||||
Listen: []byte(`"0.0.0.0"`),
|
||||
Port: 443,
|
||||
Protocol: "vless",
|
||||
Settings: []byte(`{}`),
|
||||
StreamSettings: []byte(`{"network":"tcp"}`),
|
||||
Tag: "inbound-443",
|
||||
Sniffing: []byte(`{"enabled":true}`),
|
||||
}
|
||||
if !a.Equals(b) {
|
||||
t.Error("identical InboundConfigs should be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundConfig_Equals_DifferentPort(t *testing.T) {
|
||||
a := &InboundConfig{Port: 443, Protocol: "vless"}
|
||||
b := &InboundConfig{Port: 8443, Protocol: "vless"}
|
||||
if a.Equals(b) {
|
||||
t.Error("InboundConfigs with different ports should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundConfig_Equals_DifferentProtocol(t *testing.T) {
|
||||
a := &InboundConfig{Port: 443, Protocol: "vless"}
|
||||
b := &InboundConfig{Port: 443, Protocol: "trojan"}
|
||||
if a.Equals(b) {
|
||||
t.Error("InboundConfigs with different protocols should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundConfig_Equals_DifferentTag(t *testing.T) {
|
||||
a := &InboundConfig{Port: 443, Protocol: "vless", Tag: "tag-a"}
|
||||
b := &InboundConfig{Port: 443, Protocol: "vless", Tag: "tag-b"}
|
||||
if a.Equals(b) {
|
||||
t.Error("InboundConfigs with different tags should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundConfig_Equals_NilRawMessages(t *testing.T) {
|
||||
a := &InboundConfig{Port: 443, Protocol: "vless", Listen: nil, Settings: nil}
|
||||
b := &InboundConfig{Port: 443, Protocol: "vless", Listen: nil, Settings: nil}
|
||||
if !a.Equals(b) {
|
||||
t.Error("InboundConfigs with nil RawMessages should be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundConfig_Equals_DifferentListen(t *testing.T) {
|
||||
a := &InboundConfig{Listen: []byte(`"0.0.0.0"`), Port: 443}
|
||||
b := &InboundConfig{Listen: []byte(`"127.0.0.1"`), Port: 443}
|
||||
if a.Equals(b) {
|
||||
t.Error("InboundConfigs with different Listen should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Equals_Equal(t *testing.T) {
|
||||
a := &Config{
|
||||
LogConfig: []byte(`{"loglevel":"info"}`),
|
||||
RouterConfig: []byte(`{}`),
|
||||
InboundConfigs: []InboundConfig{
|
||||
{Port: 443, Protocol: "vless"},
|
||||
},
|
||||
}
|
||||
b := &Config{
|
||||
LogConfig: []byte(`{"loglevel":"info"}`),
|
||||
RouterConfig: []byte(`{}`),
|
||||
InboundConfigs: []InboundConfig{
|
||||
{Port: 443, Protocol: "vless"},
|
||||
},
|
||||
}
|
||||
if !a.Equals(b) {
|
||||
t.Error("identical Configs should be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Equals_DifferentInboundCount(t *testing.T) {
|
||||
a := &Config{
|
||||
InboundConfigs: []InboundConfig{{Port: 443}},
|
||||
}
|
||||
b := &Config{
|
||||
InboundConfigs: []InboundConfig{},
|
||||
}
|
||||
if a.Equals(b) {
|
||||
t.Error("Configs with different inbound counts should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Equals_DifferentLogConfig(t *testing.T) {
|
||||
a := &Config{LogConfig: []byte(`{"loglevel":"info"}`)}
|
||||
b := &Config{LogConfig: []byte(`{"loglevel":"debug"}`)}
|
||||
if a.Equals(b) {
|
||||
t.Error("Configs with different LogConfig should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Equals_EmptyConfigs(t *testing.T) {
|
||||
a := &Config{}
|
||||
b := &Config{}
|
||||
if !a.Equals(b) {
|
||||
t.Error("two empty Configs should be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Equals_NilVsEmpty(t *testing.T) {
|
||||
a := &Config{}
|
||||
b := &Config{InboundConfigs: []InboundConfig{}}
|
||||
if !a.Equals(b) {
|
||||
t.Error("nil and empty slice InboundConfigs should be equal")
|
||||
}
|
||||
}
|
||||
91
xray/log_writer_test.go
Normal file
91
xray/log_writer_test.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package xray
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v2/logger"
|
||||
"github.com/op/go-logging"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize logger for tests that use LogWriter (which calls logger.Debug/Error/etc.)
|
||||
logger.InitLogger(logging.DEBUG)
|
||||
}
|
||||
|
||||
func TestNewLogWriter(t *testing.T) {
|
||||
lw := NewLogWriter()
|
||||
if lw == nil {
|
||||
t.Fatal("NewLogWriter should not return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWriter_Write_CrashDetection(t *testing.T) {
|
||||
lw := NewLogWriter()
|
||||
msg := []byte("panic: runtime error: index out of range")
|
||||
n, err := lw.Write(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Write should not return error: %v", err)
|
||||
}
|
||||
if n != len(msg) {
|
||||
t.Errorf("Write returned %d, expected %d", n, len(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWriter_Write_FatalError(t *testing.T) {
|
||||
lw := NewLogWriter()
|
||||
msg := []byte("fatal error: concurrent map writes")
|
||||
n, err := lw.Write(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Write should not return error: %v", err)
|
||||
}
|
||||
if n != len(msg) {
|
||||
t.Errorf("Write returned %d, expected %d", n, len(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWriter_Write_Exception(t *testing.T) {
|
||||
lw := NewLogWriter()
|
||||
msg := []byte("unhandled exception occurred")
|
||||
n, err := lw.Write(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Write should not return error: %v", err)
|
||||
}
|
||||
if n != len(msg) {
|
||||
t.Errorf("Write returned %d, expected %d", n, len(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWriter_Write_EmptyMessage(t *testing.T) {
|
||||
lw := NewLogWriter()
|
||||
n, err := lw.Write([]byte(""))
|
||||
if err != nil {
|
||||
t.Fatalf("Write should not error: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("Write returned %d, expected 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWriter_Write_TLSErrorSuppressed(t *testing.T) {
|
||||
lw := NewLogWriter()
|
||||
msg := []byte("some tls handshake error occurred")
|
||||
n, err := lw.Write(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Write should not return error: %v", err)
|
||||
}
|
||||
if n != len(msg) {
|
||||
t.Errorf("Write returned %d, expected %d", n, len(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWriter_Write_FailedKeyword(t *testing.T) {
|
||||
lw := NewLogWriter()
|
||||
msg := []byte("connection failed to remote")
|
||||
n, err := lw.Write(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Write should not return error: %v", err)
|
||||
}
|
||||
if n != len(msg) {
|
||||
t.Errorf("Write returned %d, expected %d", n, len(msg))
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue