diff --git a/sub/links_test.go b/sub/links_test.go new file mode 100644 index 00000000..c600199c --- /dev/null +++ b/sub/links_test.go @@ -0,0 +1,40 @@ +package sub + +import ( + "reflect" + "testing" +) + +func TestSplitLinkLines(t *testing.T) { + cases := []struct { + name string + in string + want []string + }{ + {"single_line", "vless://abc", []string{"vless://abc"}}, + {"two_lines", "vless://abc\nvmess://xyz", []string{"vless://abc", "vmess://xyz"}}, + {"trims_each_line", " vless://abc \n\tvmess://xyz\t", []string{"vless://abc", "vmess://xyz"}}, + {"skips_blank_lines", "vless://abc\n\n\nvmess://xyz\n", []string{"vless://abc", "vmess://xyz"}}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := splitLinkLines(c.in) + if !reflect.DeepEqual(got, c.want) { + t.Fatalf("splitLinkLines(%q) = %#v, want %#v", c.in, got, c.want) + } + }) + } +} + +func TestSplitLinkLines_EmptyInputIsNil(t *testing.T) { + if got := splitLinkLines(""); got != nil { + t.Fatalf("splitLinkLines(\"\") = %#v, want nil", got) + } +} + +func TestSplitLinkLines_WhitespaceOnlyHasNoEntries(t *testing.T) { + got := splitLinkLines(" \n\t \n") + if len(got) != 0 { + t.Fatalf("splitLinkLines(whitespace) = %#v, want empty slice", got) + } +} diff --git a/sub/subService_test.go b/sub/subService_test.go new file mode 100644 index 00000000..f83db7e3 --- /dev/null +++ b/sub/subService_test.go @@ -0,0 +1,480 @@ +package sub + +import ( + "encoding/base64" + "encoding/json" + "strings" + "testing" + + "github.com/mhsanaei/3x-ui/v3/database/model" +) + +func TestFindClientIndex(t *testing.T) { + clients := []model.Client{ + {Email: "a@example.com"}, + {Email: "b@example.com"}, + {Email: "c@example.com"}, + } + if got := findClientIndex(clients, "b@example.com"); got != 1 { + t.Fatalf("findClientIndex middle = %d, want 1", got) + } + if got := findClientIndex(clients, "a@example.com"); got != 0 { + t.Fatalf("findClientIndex first = %d, want 0", got) + } + if got := findClientIndex(clients, "missing@example.com"); got != -1 { + t.Fatalf("findClientIndex missing = %d, want -1", got) + } + if got := findClientIndex(nil, "x"); got != -1 { + t.Fatalf("findClientIndex on nil slice = %d, want -1", got) + } +} + +func TestUnmarshalStreamSettings(t *testing.T) { + got := unmarshalStreamSettings(`{"network":"ws","wsSettings":{"path":"/api"}}`) + if got["network"] != "ws" { + t.Fatalf("network = %v, want ws", got["network"]) + } + ws, ok := got["wsSettings"].(map[string]any) + if !ok || ws["path"] != "/api" { + t.Fatalf("wsSettings = %v, want map with path=/api", got["wsSettings"]) + } +} + +func TestUnmarshalStreamSettings_InvalidJSON(t *testing.T) { + if got := unmarshalStreamSettings("not json"); got != nil { + t.Fatalf("invalid JSON should produce nil map, got %#v", got) + } +} + +func TestSearchHost_StringValue(t *testing.T) { + headers := map[string]any{"Host": "example.com"} + if got := searchHost(headers); got != "example.com" { + t.Fatalf("searchHost = %q, want example.com", got) + } +} + +func TestSearchHost_CaseInsensitiveKey(t *testing.T) { + headers := map[string]any{"host": "example.com"} + if got := searchHost(headers); got != "example.com" { + t.Fatalf("searchHost = %q, want example.com", got) + } + headers2 := map[string]any{"HOST": "example.com"} + if got := searchHost(headers2); got != "example.com" { + t.Fatalf("searchHost uppercase = %q, want example.com", got) + } +} + +func TestSearchHost_ArrayValue(t *testing.T) { + headers := map[string]any{"Host": []any{"first.example.com", "second.example.com"}} + if got := searchHost(headers); got != "first.example.com" { + t.Fatalf("searchHost array = %q, want first.example.com", got) + } +} + +func TestSearchHost_EmptyArray(t *testing.T) { + headers := map[string]any{"Host": []any{}} + if got := searchHost(headers); got != "" { + t.Fatalf("searchHost empty array = %q, want empty", got) + } +} + +func TestSearchHost_NoHostKey(t *testing.T) { + headers := map[string]any{"X-Other": "value"} + if got := searchHost(headers); got != "" { + t.Fatalf("searchHost no host = %q, want empty", got) + } +} + +func TestSearchHost_NotAMap(t *testing.T) { + if got := searchHost("not a map"); got != "" { + t.Fatalf("searchHost non-map = %q, want empty", got) + } + if got := searchHost(nil); got != "" { + t.Fatalf("searchHost nil = %q, want empty", got) + } +} + +func TestSearchKey_FoundAtTopLevel(t *testing.T) { + data := map[string]any{"foo": 42, "bar": "x"} + got, ok := searchKey(data, "foo") + if !ok { + t.Fatal("expected to find foo") + } + if got != 42 { + t.Fatalf("got %v, want 42", got) + } +} + +func TestSearchKey_FoundInNested(t *testing.T) { + data := map[string]any{ + "outer": map[string]any{ + "inner": map[string]any{ + "target": "hit", + }, + }, + } + got, ok := searchKey(data, "target") + if !ok { + t.Fatal("expected to find target in nested map") + } + if got != "hit" { + t.Fatalf("got %v, want hit", got) + } +} + +func TestSearchKey_FoundInsideArray(t *testing.T) { + data := map[string]any{ + "list": []any{ + map[string]any{"other": 1}, + map[string]any{"needle": "found"}, + }, + } + got, ok := searchKey(data, "needle") + if !ok { + t.Fatal("expected to find needle in array element") + } + if got != "found" { + t.Fatalf("got %v, want found", got) + } +} + +func TestSearchKey_NotFound(t *testing.T) { + data := map[string]any{"foo": "bar"} + if _, ok := searchKey(data, "missing"); ok { + t.Fatal("expected ok=false for missing key") + } +} + +func TestSearchKey_OnScalar(t *testing.T) { + if _, ok := searchKey(42, "anything"); ok { + t.Fatal("expected ok=false searching on a scalar") + } +} + +func TestCloneStringMap(t *testing.T) { + src := map[string]string{"a": "1", "b": "2"} + dst := cloneStringMap(src) + if len(dst) != len(src) { + t.Fatalf("clone length = %d, want %d", len(dst), len(src)) + } + for k, v := range src { + if dst[k] != v { + t.Fatalf("clone[%q] = %q, want %q", k, dst[k], v) + } + } + dst["a"] = "changed" + if src["a"] == "changed" { + t.Fatal("modifying clone leaked into source") + } +} + +func TestCloneStringMap_Empty(t *testing.T) { + dst := cloneStringMap(map[string]string{}) + if dst == nil { + t.Fatal("clone of empty map should not be nil") + } + if len(dst) != 0 { + t.Fatalf("clone of empty map should be empty, got %v", dst) + } +} + +func TestGetHostFromXFH_HostOnly(t *testing.T) { + got, err := getHostFromXFH("example.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "example.com" { + t.Fatalf("got %q, want example.com", got) + } +} + +func TestGetHostFromXFH_HostWithPort(t *testing.T) { + got, err := getHostFromXFH("example.com:8443") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "example.com" { + t.Fatalf("got %q, want example.com", got) + } +} + +func TestGetHostFromXFH_IPv6WithPort(t *testing.T) { + got, err := getHostFromXFH("[2606:4700::1111]:443") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "2606:4700::1111" { + t.Fatalf("got %q, want 2606:4700::1111", got) + } +} + +func TestGetHostFromXFH_BadHostPort(t *testing.T) { + if _, err := getHostFromXFH("example.com:8443:9999"); err == nil { + t.Fatal("expected error for malformed host:port") + } +} + +func TestReadPositiveInt(t *testing.T) { + cases := []struct { + name string + in any + wantVal int + wantOk bool + }{ + {"int_positive", int(5), 5, true}, + {"int_zero", int(0), 0, false}, + {"int_negative", int(-3), -3, false}, + {"int32_positive", int32(7), 7, true}, + {"int64_positive", int64(99), 99, true}, + {"float64_positive", float64(12), 12, true}, + {"float64_zero", float64(0.0), 0, false}, + {"float64_negative", float64(-1.5), -1, false}, + {"float32_positive", float32(3), 3, true}, + {"string", "not a number", 0, false}, + {"nil", nil, 0, false}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + gotVal, gotOk := readPositiveInt(c.in) + if gotVal != c.wantVal || gotOk != c.wantOk { + t.Fatalf("readPositiveInt(%v) = (%d, %v), want (%d, %v)", c.in, gotVal, gotOk, c.wantVal, c.wantOk) + } + }) + } +} + +func TestSetStringParam(t *testing.T) { + p := map[string]string{"existing": "value"} + + setStringParam(p, "new", "hello") + if p["new"] != "hello" { + t.Fatalf("missing key after set: %v", p) + } + + setStringParam(p, "existing", "") + if _, ok := p["existing"]; ok { + t.Fatalf("empty value should delete the key, got %v", p) + } +} + +func TestSetIntParam(t *testing.T) { + p := map[string]string{"existing": "10"} + + setIntParam(p, "n", 42) + if p["n"] != "42" { + t.Fatalf("set positive int: got %v", p) + } + + setIntParam(p, "existing", 0) + if _, ok := p["existing"]; ok { + t.Fatalf("zero value should delete the key, got %v", p) + } + + p["other"] = "5" + setIntParam(p, "other", -1) + if _, ok := p["other"]; ok { + t.Fatalf("negative value should delete the key, got %v", p) + } +} + +func TestSetStringField(t *testing.T) { + f := map[string]any{"existing": "value"} + + setStringField(f, "new", "hello") + if f["new"] != "hello" { + t.Fatalf("missing key after set: %v", f) + } + + setStringField(f, "existing", "") + if _, ok := f["existing"]; ok { + t.Fatalf("empty value should delete the key, got %v", f) + } +} + +func TestSetIntField(t *testing.T) { + f := map[string]any{"existing": 10} + + setIntField(f, "n", 7) + if f["n"] != 7 { + t.Fatalf("set positive int: got %v", f) + } + + setIntField(f, "existing", 0) + if _, ok := f["existing"]; ok { + t.Fatalf("zero value should delete the key, got %v", f) + } +} + +func TestBuildVmessLink(t *testing.T) { + obj := map[string]any{ + "v": "2", + "ps": "remark", + "add": "example.com", + "port": 443, + "net": "tcp", + } + link := buildVmessLink(obj) + if !strings.HasPrefix(link, "vmess://") { + t.Fatalf("missing vmess:// prefix: %q", link) + } + payload := strings.TrimPrefix(link, "vmess://") + decoded, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + t.Fatalf("base64 decode failed: %v", err) + } + var roundTrip map[string]any + if err := json.Unmarshal(decoded, &roundTrip); err != nil { + t.Fatalf("decoded payload is not JSON: %v\n%s", err, decoded) + } + if roundTrip["add"] != "example.com" { + t.Fatalf("round-trip add = %v, want example.com", roundTrip["add"]) + } + if roundTrip["ps"] != "remark" { + t.Fatalf("round-trip ps = %v, want remark", roundTrip["ps"]) + } +} + +func TestCloneVmessShareObj_CopiesEverythingByDefault(t *testing.T) { + base := map[string]any{ + "v": "2", + "sni": "example.com", + "alpn": "h2", + "fp": "chrome", + "net": "tcp", + } + out := cloneVmessShareObj(base, "tls") + for _, key := range []string{"sni", "alpn", "fp", "net", "v"} { + if _, ok := out[key]; !ok { + t.Fatalf("expected key %q to be preserved when security=tls, got %v", key, out) + } + } +} + +func TestCloneVmessShareObj_NoneStripsTLSOnlyKeys(t *testing.T) { + base := map[string]any{ + "v": "2", + "sni": "example.com", + "alpn": "h2", + "fp": "chrome", + "net": "tcp", + } + out := cloneVmessShareObj(base, "none") + for _, key := range []string{"sni", "alpn", "fp"} { + if _, ok := out[key]; ok { + t.Fatalf("security=none should strip %q, got %v", key, out) + } + } + if out["v"] != "2" || out["net"] != "tcp" { + t.Fatalf("non-TLS keys should remain, got %v", out) + } +} + +func TestExtractKcpShareFields_Defaults(t *testing.T) { + stream := map[string]any{} + got := extractKcpShareFields(stream) + if got.headerType != "none" { + t.Fatalf("default headerType = %q, want none", got.headerType) + } + if got.seed != "" || got.mtu != 0 || got.tti != 0 { + t.Fatalf("default kcpShareFields should be zero except headerType, got %+v", got) + } +} + +func TestExtractKcpShareFields_ReadsAllFields(t *testing.T) { + stream := map[string]any{ + "kcpSettings": map[string]any{ + "header": map[string]any{"type": "wechat-video"}, + "seed": "secret-seed", + "mtu": float64(1350), + "tti": float64(50), + }, + } + got := extractKcpShareFields(stream) + if got.headerType != "wechat-video" { + t.Fatalf("headerType = %q, want wechat-video", got.headerType) + } + if got.seed != "secret-seed" { + t.Fatalf("seed = %q, want secret-seed", got.seed) + } + if got.mtu != 1350 { + t.Fatalf("mtu = %d, want 1350", got.mtu) + } + if got.tti != 50 { + t.Fatalf("tti = %d, want 50", got.tti) + } +} + +func TestKcpShareFields_ApplyToParams(t *testing.T) { + params := map[string]string{} + kcpShareFields{headerType: "wechat-video", seed: "s", mtu: 1350, tti: 50}.applyToParams(params) + if params["headerType"] != "wechat-video" { + t.Fatalf("headerType param = %q", params["headerType"]) + } + if params["seed"] != "s" { + t.Fatalf("seed param = %q", params["seed"]) + } + if params["mtu"] != "1350" { + t.Fatalf("mtu param = %q", params["mtu"]) + } + if params["tti"] != "50" { + t.Fatalf("tti param = %q", params["tti"]) + } +} + +func TestKcpShareFields_ApplyToParams_NoneHeaderNotAdded(t *testing.T) { + params := map[string]string{} + kcpShareFields{headerType: "none"}.applyToParams(params) + if _, ok := params["headerType"]; ok { + t.Fatalf("headerType=none should not be added, got %v", params) + } +} + +func TestMarshalFinalMask_EmptyReturnsFalse(t *testing.T) { + if _, ok := marshalFinalMask(map[string]any{}); ok { + t.Fatal("expected ok=false for empty finalmask") + } + if _, ok := marshalFinalMask(nil); ok { + t.Fatal("expected ok=false for nil finalmask") + } +} + +func TestMarshalFinalMask_WithContent(t *testing.T) { + fm := map[string]any{ + "tcp": []any{ + map[string]any{"type": "fragment"}, + }, + } + out, ok := marshalFinalMask(fm) + if !ok { + t.Fatal("expected ok=true for finalmask with valid tcp mask") + } + if !strings.Contains(out, `"tcp"`) { + t.Fatalf("marshaled finalmask missing tcp key: %s", out) + } + if !strings.Contains(out, "fragment") { + t.Fatalf("marshaled finalmask missing mask type: %s", out) + } +} + +func TestMarshalFinalMask_UnknownTypeIsDropped(t *testing.T) { + fm := map[string]any{ + "tcp": []any{ + map[string]any{"type": "not-a-real-mask"}, + }, + } + if _, ok := marshalFinalMask(fm); ok { + t.Fatal("unknown mask types should be dropped, leaving nothing to marshal") + } +} + +func TestHasFinalMaskContent(t *testing.T) { + if hasFinalMaskContent(nil) { + t.Fatal("nil should not count as content") + } + if hasFinalMaskContent(map[string]any{}) { + t.Fatal("empty map should not count as content") + } + if !hasFinalMaskContent(map[string]any{"x": 1}) { + t.Fatal("non-empty map should count as content") + } +} diff --git a/util/common/format_test.go b/util/common/format_test.go new file mode 100644 index 00000000..d0176e7b --- /dev/null +++ b/util/common/format_test.go @@ -0,0 +1,28 @@ +package common + +import "testing" + +func TestFormatTraffic(t *testing.T) { + cases := []struct { + name string + bytes int64 + want string + }{ + {"zero", 0, "0.00B"}, + {"under_one_kb", 512, "512.00B"}, + {"exactly_one_kb", 1024, "1.00KB"}, + {"one_and_a_half_kb", 1536, "1.50KB"}, + {"one_mb", 1024 * 1024, "1.00MB"}, + {"one_gb", 1024 * 1024 * 1024, "1.00GB"}, + {"one_tb", 1024 * 1024 * 1024 * 1024, "1.00TB"}, + {"one_pb", 1024 * 1024 * 1024 * 1024 * 1024, "1.00PB"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := FormatTraffic(c.bytes) + if got != c.want { + t.Fatalf("FormatTraffic(%d) = %q, want %q", c.bytes, got, c.want) + } + }) + } +} diff --git a/util/common/multi_error_test.go b/util/common/multi_error_test.go new file mode 100644 index 00000000..0edb0102 --- /dev/null +++ b/util/common/multi_error_test.go @@ -0,0 +1,44 @@ +package common + +import ( + "errors" + "strings" + "testing" +) + +func TestCombine_AllNilReturnsNil(t *testing.T) { + if err := Combine(); err != nil { + t.Fatalf("Combine() with no args = %v, want nil", err) + } + if err := Combine(nil, nil, nil); err != nil { + t.Fatalf("Combine(nil, nil, nil) = %v, want nil", err) + } +} + +func TestCombine_SkipsNilErrors(t *testing.T) { + e1 := errors.New("boom one") + e2 := errors.New("boom two") + + err := Combine(nil, e1, nil, e2, nil) + if err == nil { + t.Fatal("expected non-nil combined error") + } + msg := err.Error() + if !strings.Contains(msg, "boom one") || !strings.Contains(msg, "boom two") { + t.Fatalf("combined error %q does not contain both underlying messages", msg) + } + if !strings.HasPrefix(msg, "multierr: ") { + t.Fatalf("combined error %q missing %q prefix", msg, "multierr: ") + } +} + +func TestCombine_SingleErrorStillWrapped(t *testing.T) { + e := errors.New("only one") + err := Combine(e) + if err == nil { + t.Fatal("expected non-nil error") + } + if !strings.Contains(err.Error(), "only one") { + t.Fatalf("combined error %q missing underlying message", err.Error()) + } +} diff --git a/util/crypto/crypto_test.go b/util/crypto/crypto_test.go new file mode 100644 index 00000000..00a15bd6 --- /dev/null +++ b/util/crypto/crypto_test.go @@ -0,0 +1,69 @@ +package crypto + +import ( + "strings" + "testing" +) + +func TestHashPasswordAsBcrypt_RoundTrip(t *testing.T) { + password := "correct horse battery staple" + + hash, err := HashPasswordAsBcrypt(password) + if err != nil { + t.Fatalf("HashPasswordAsBcrypt returned error: %v", err) + } + if hash == "" { + t.Fatal("expected non-empty hash") + } + if hash == password { + t.Fatal("hash must not equal the plaintext password") + } + if !strings.HasPrefix(hash, "$2") { + t.Fatalf("expected bcrypt prefix $2..., got %q", hash[:min(4, len(hash))]) + } + + if !CheckPasswordHash(hash, password) { + t.Fatal("CheckPasswordHash returned false for the matching password") + } +} + +func TestCheckPasswordHash_WrongPassword(t *testing.T) { + hash, err := HashPasswordAsBcrypt("right-password") + if err != nil { + t.Fatalf("HashPasswordAsBcrypt returned error: %v", err) + } + + if CheckPasswordHash(hash, "wrong-password") { + t.Fatal("CheckPasswordHash returned true for a wrong password") + } + if CheckPasswordHash(hash, "") { + t.Fatal("CheckPasswordHash returned true for an empty password") + } +} + +func TestCheckPasswordHash_InvalidHash(t *testing.T) { + if CheckPasswordHash("", "anything") { + t.Fatal("empty hash must not validate") + } + if CheckPasswordHash("not-a-bcrypt-hash", "anything") { + t.Fatal("malformed hash must not validate") + } +} + +func TestHashPasswordAsBcrypt_DifferentHashesForSamePassword(t *testing.T) { + password := "same-password" + h1, err := HashPasswordAsBcrypt(password) + if err != nil { + t.Fatalf("first hash failed: %v", err) + } + h2, err := HashPasswordAsBcrypt(password) + if err != nil { + t.Fatalf("second hash failed: %v", err) + } + if h1 == h2 { + t.Fatal("expected bcrypt to produce different hashes (random salt) for the same password") + } + if !CheckPasswordHash(h1, password) || !CheckPasswordHash(h2, password) { + t.Fatal("both hashes should still validate the original password") + } +} diff --git a/util/json_util/json_test.go b/util/json_util/json_test.go new file mode 100644 index 00000000..d4e8afcc --- /dev/null +++ b/util/json_util/json_test.go @@ -0,0 +1,76 @@ +package json_util + +import ( + "bytes" + "encoding/json" + "testing" +) + +func TestRawMessage_MarshalEmptyIsNull(t *testing.T) { + var m RawMessage + out, err := m.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON on empty returned error: %v", err) + } + if !bytes.Equal(out, []byte("null")) { + t.Fatalf("empty RawMessage marshaled to %q, want %q", out, "null") + } +} + +func TestRawMessage_MarshalPassthrough(t *testing.T) { + payload := []byte(`{"a":1}`) + m := RawMessage(payload) + out, err := m.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON returned error: %v", err) + } + if !bytes.Equal(out, payload) { + t.Fatalf("MarshalJSON = %q, want %q", out, payload) + } +} + +func TestRawMessage_UnmarshalCopiesData(t *testing.T) { + var m RawMessage + src := []byte(`{"k":"v"}`) + if err := m.UnmarshalJSON(src); err != nil { + t.Fatalf("UnmarshalJSON returned error: %v", err) + } + if !bytes.Equal(m, src) { + t.Fatalf("UnmarshalJSON stored %q, want %q", []byte(m), src) + } + + src[0] = 'X' + if m[0] == 'X' { + t.Fatal("UnmarshalJSON kept a reference to the caller's buffer; expected a copy") + } +} + +func TestRawMessage_UnmarshalNilReceiverErrors(t *testing.T) { + var m *RawMessage + if err := m.UnmarshalJSON([]byte("123")); err == nil { + t.Fatal("expected error for nil receiver") + } +} + +func TestRawMessage_RoundTripInsideStruct(t *testing.T) { + type wrapper struct { + Body RawMessage `json:"body"` + } + in := wrapper{Body: RawMessage(`{"x":42}`)} + encoded, err := json.Marshal(in) + if err != nil { + t.Fatalf("json.Marshal returned error: %v", err) + } + want := `{"body":{"x":42}}` + if string(encoded) != want { + t.Fatalf("Marshal = %s, want %s", encoded, want) + } + + var out wrapper + if err := json.Unmarshal(encoded, &out); err != nil { + t.Fatalf("json.Unmarshal returned error: %v", err) + } + if string(out.Body) != `{"x":42}` { + t.Fatalf("round-trip Body = %s, want %s", out.Body, `{"x":42}`) + } +} diff --git a/util/netsafe/netsafe_test.go b/util/netsafe/netsafe_test.go new file mode 100644 index 00000000..2fe9bcd5 --- /dev/null +++ b/util/netsafe/netsafe_test.go @@ -0,0 +1,127 @@ +package netsafe + +import ( + "context" + "net" + "strings" + "testing" +) + +func TestIsBlockedIP(t *testing.T) { + cases := []struct { + ip string + want bool + }{ + {"127.0.0.1", true}, + {"::1", true}, + {"10.0.0.5", true}, + {"172.16.0.1", true}, + {"192.168.1.1", true}, + {"169.254.0.1", true}, + {"0.0.0.0", true}, + {"::", true}, + {"8.8.8.8", false}, + {"1.1.1.1", false}, + {"2606:4700:4700::1111", false}, + } + for _, c := range cases { + t.Run(c.ip, func(t *testing.T) { + ip := net.ParseIP(c.ip) + if ip == nil { + t.Fatalf("could not parse %q", c.ip) + } + if got := IsBlockedIP(ip); got != c.want { + t.Fatalf("IsBlockedIP(%s) = %v, want %v", c.ip, got, c.want) + } + }) + } +} + +func TestAllowPrivateFromContext_Default(t *testing.T) { + if AllowPrivateFromContext(context.Background()) { + t.Fatal("default context should report AllowPrivate=false") + } +} + +func TestAllowPrivateFromContext_RoundTrip(t *testing.T) { + ctx := ContextWithAllowPrivate(context.Background(), true) + if !AllowPrivateFromContext(ctx) { + t.Fatal("expected AllowPrivate=true after ContextWithAllowPrivate(true)") + } + ctx = ContextWithAllowPrivate(ctx, false) + if AllowPrivateFromContext(ctx) { + t.Fatal("expected AllowPrivate=false after overriding with false") + } +} + +func TestNormalizeHost_Valid(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"example.com", "example.com"}, + {" example.com ", "example.com"}, + {"a.b.c.example.com", "a.b.c.example.com"}, + {"10.0.0.1", "10.0.0.1"}, + {"[2606:4700:4700::1111]", "2606:4700:4700::1111"}, + {"2606:4700:4700::1111", "2606:4700:4700::1111"}, + } + for _, c := range cases { + t.Run(c.in, func(t *testing.T) { + got, err := NormalizeHost(c.in) + if err != nil { + t.Fatalf("NormalizeHost(%q) returned error: %v", c.in, err) + } + if !strings.EqualFold(got, c.want) { + t.Fatalf("NormalizeHost(%q) = %q, want %q", c.in, got, c.want) + } + }) + } +} + +func TestNormalizeHost_Invalid(t *testing.T) { + cases := []string{ + "", + " ", + "-leading-dash.com", + "trailing-dash-.com", + "bad host with spaces", + "under_score.example.com", + "exa$mple.com", + strings.Repeat("a", 254), + } + for _, in := range cases { + t.Run(in, func(t *testing.T) { + if _, err := NormalizeHost(in); err == nil { + t.Fatalf("NormalizeHost(%q) expected error, got nil", in) + } + }) + } +} + +func TestSSRFGuardedDialContext_BlocksLiteralPrivateIP(t *testing.T) { + _, err := SSRFGuardedDialContext(context.Background(), "tcp", "127.0.0.1:1") + if err == nil { + t.Fatal("expected dial to 127.0.0.1 to be blocked") + } + if !strings.Contains(err.Error(), "blocked") { + t.Fatalf("expected 'blocked' in error, got: %v", err) + } +} + +func TestSSRFGuardedDialContext_AllowPrivateBypassesGuard(t *testing.T) { + ctx := ContextWithAllowPrivate(context.Background(), true) + _, err := SSRFGuardedDialContext(ctx, "tcp", "127.0.0.1:1") + if err == nil { + t.Fatal("dial to a closed loopback port should still fail at the connect step") + } + if strings.Contains(err.Error(), "blocked private/internal address") { + t.Fatalf("expected guard to be bypassed when AllowPrivate=true, got: %v", err) + } +} + +func TestSSRFGuardedDialContext_BadAddress(t *testing.T) { + if _, err := SSRFGuardedDialContext(context.Background(), "tcp", "no-port"); err == nil { + t.Fatal("expected error for address without port") + } +} diff --git a/util/random/random_test.go b/util/random/random_test.go new file mode 100644 index 00000000..5c33f6d7 --- /dev/null +++ b/util/random/random_test.go @@ -0,0 +1,63 @@ +package random + +import ( + "encoding/base64" + "testing" +) + +func TestSeq_LengthAndAlphabet(t *testing.T) { + for _, n := range []int{0, 1, 8, 64, 256} { + s := Seq(n) + if len(s) != n { + t.Fatalf("Seq(%d) returned length %d", n, len(s)) + } + for i, r := range s { + isDigit := r >= '0' && r <= '9' + isLower := r >= 'a' && r <= 'z' + isUpper := r >= 'A' && r <= 'Z' + if !(isDigit || isLower || isUpper) { + t.Fatalf("Seq(%d) byte %d = %q is not alphanumeric", n, i, r) + } + } + } +} + +func TestSeq_NotConstant(t *testing.T) { + a := Seq(32) + b := Seq(32) + if a == b { + t.Fatalf("two consecutive Seq(32) calls produced identical output: %q", a) + } +} + +func TestNum_InRange(t *testing.T) { + for _, upper := range []int{1, 2, 10, 1000} { + for i := 0; i < 200; i++ { + v := Num(upper) + if v < 0 || v >= upper { + t.Fatalf("Num(%d) returned %d, out of [0, %d)", upper, v, upper) + } + } + } +} + +func TestBase64Bytes_DecodesToRequestedSize(t *testing.T) { + for _, n := range []int{1, 16, 32, 64} { + out := Base64Bytes(n) + decoded, err := base64.StdEncoding.DecodeString(out) + if err != nil { + t.Fatalf("Base64Bytes(%d) produced invalid base64 %q: %v", n, out, err) + } + if len(decoded) != n { + t.Fatalf("Base64Bytes(%d) decoded to %d bytes", n, len(decoded)) + } + } +} + +func TestBase64Bytes_Random(t *testing.T) { + a := Base64Bytes(32) + b := Base64Bytes(32) + if a == b { + t.Fatalf("two consecutive Base64Bytes(32) calls produced identical output: %q", a) + } +} diff --git a/web/job/node_traffic_sync_job_test.go b/web/job/node_traffic_sync_job_test.go new file mode 100644 index 00000000..1bc7601e --- /dev/null +++ b/web/job/node_traffic_sync_job_test.go @@ -0,0 +1,73 @@ +package job + +import ( + "sync" + "testing" +) + +func TestAtomicBool_DefaultIsFalse(t *testing.T) { + var a atomicBool + if a.takeAndReset() { + t.Fatal("default atomicBool should report false") + } +} + +func TestAtomicBool_SetThenTakeReturnsTrueOnce(t *testing.T) { + var a atomicBool + a.set() + if !a.takeAndReset() { + t.Fatal("takeAndReset after set should return true") + } + if a.takeAndReset() { + t.Fatal("second takeAndReset should return false (state was reset)") + } +} + +func TestAtomicBool_SetIsIdempotent(t *testing.T) { + var a atomicBool + a.set() + a.set() + a.set() + if !a.takeAndReset() { + t.Fatal("repeated set should still leave the flag true") + } + if a.takeAndReset() { + t.Fatal("flag should be cleared after the first take") + } +} + +func TestAtomicBool_ConcurrentSettersExactlyOneTakeWins(t *testing.T) { + var a atomicBool + const setters = 100 + const readers = 20 + + var wg sync.WaitGroup + for i := 0; i < setters; i++ { + wg.Add(1) + go func() { + defer wg.Done() + a.set() + }() + } + wg.Wait() + + trueCount := 0 + var rwg sync.WaitGroup + var mu sync.Mutex + for i := 0; i < readers; i++ { + rwg.Add(1) + go func() { + defer rwg.Done() + if a.takeAndReset() { + mu.Lock() + trueCount++ + mu.Unlock() + } + }() + } + rwg.Wait() + + if trueCount != 1 { + t.Fatalf("expected exactly one reader to observe true, got %d", trueCount) + } +} diff --git a/web/service/node_test.go b/web/service/node_test.go new file mode 100644 index 00000000..83243fbf --- /dev/null +++ b/web/service/node_test.go @@ -0,0 +1,162 @@ +package service + +import ( + "testing" + + "github.com/mhsanaei/3x-ui/v3/database/model" +) + +func TestNormalizeBasePath(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"", "/"}, + {" ", "/"}, + {"/", "/"}, + {"/panel", "/panel/"}, + {"panel", "/panel/"}, + {"panel/", "/panel/"}, + {"/panel/", "/panel/"}, + {" /panel ", "/panel/"}, + {"/a/b/c", "/a/b/c/"}, + } + for _, c := range cases { + t.Run(c.in, func(t *testing.T) { + got := normalizeBasePath(c.in) + if got != c.want { + t.Fatalf("normalizeBasePath(%q) = %q, want %q", c.in, got, c.want) + } + }) + } +} + +func TestNodeMetricKey(t *testing.T) { + cases := []struct { + id int + metric string + want string + }{ + {1, "cpu", "node:1:cpu"}, + {42, "mem", "node:42:mem"}, + {0, "anything", "node:0:anything"}, + } + for _, c := range cases { + got := nodeMetricKey(c.id, c.metric) + if got != c.want { + t.Fatalf("nodeMetricKey(%d, %q) = %q, want %q", c.id, c.metric, got, c.want) + } + } +} + +func TestHeartbeatPatch_ToUI_OnlineCopiesFields(t *testing.T) { + p := HeartbeatPatch{ + Status: "ignored-source", + LatencyMs: 42, + XrayVersion: "1.8.4", + PanelVersion: "3.0.0", + CpuPct: 12.5, + MemPct: 33.3, + UptimeSecs: 12345, + LastError: "", + } + ui := p.ToUI(true) + if ui.Status != "online" { + t.Fatalf("Status = %q, want online", ui.Status) + } + if ui.LatencyMs != 42 || ui.XrayVersion != "1.8.4" || ui.PanelVersion != "3.0.0" { + t.Fatalf("scalar copy mismatch: %+v", ui) + } + if ui.CpuPct != 12.5 || ui.MemPct != 33.3 || ui.UptimeSecs != 12345 { + t.Fatalf("metric copy mismatch: %+v", ui) + } + if ui.Error != "" { + t.Fatalf("Error = %q, want empty", ui.Error) + } +} + +func TestHeartbeatPatch_ToUI_OfflinePreservesError(t *testing.T) { + p := HeartbeatPatch{LastError: "connection refused"} + ui := p.ToUI(false) + if ui.Status != "offline" { + t.Fatalf("Status = %q, want offline", ui.Status) + } + if ui.Error != "connection refused" { + t.Fatalf("Error = %q, want %q", ui.Error, "connection refused") + } +} + +func TestNodeService_Normalize_Valid(t *testing.T) { + s := &NodeService{} + n := &model.Node{ + Name: " primary ", + ApiToken: " abc ", + Address: "example.com", + Port: 8443, + Scheme: "", + BasePath: "panel", + } + if err := s.normalize(n); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n.Name != "primary" { + t.Fatalf("Name not trimmed: %q", n.Name) + } + if n.ApiToken != "abc" { + t.Fatalf("ApiToken not trimmed: %q", n.ApiToken) + } + if n.Scheme != "https" { + t.Fatalf("empty Scheme should default to https, got %q", n.Scheme) + } + if n.BasePath != "/panel/" { + t.Fatalf("BasePath = %q, want /panel/", n.BasePath) + } +} + +func TestNodeService_Normalize_KeepsValidScheme(t *testing.T) { + s := &NodeService{} + n := &model.Node{Name: "n", Address: "example.com", Port: 80, Scheme: "http"} + if err := s.normalize(n); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n.Scheme != "http" { + t.Fatalf("Scheme = %q, want http", n.Scheme) + } +} + +func TestNodeService_Normalize_RejectsEmptyName(t *testing.T) { + s := &NodeService{} + n := &model.Node{Name: " ", Address: "example.com", Port: 443} + if err := s.normalize(n); err == nil { + t.Fatal("expected error for empty name") + } +} + +func TestNodeService_Normalize_RejectsBadHost(t *testing.T) { + s := &NodeService{} + n := &model.Node{Name: "n", Address: "bad host name with spaces", Port: 443} + if err := s.normalize(n); err == nil { + t.Fatal("expected error for invalid host") + } +} + +func TestNodeService_Normalize_RejectsOutOfRangePort(t *testing.T) { + s := &NodeService{} + for _, port := range []int{0, -1, 65536, 100000} { + n := &model.Node{Name: "n", Address: "example.com", Port: port} + if err := s.normalize(n); err == nil { + t.Fatalf("expected error for port %d", port) + } + } +} + +func TestNodeService_Normalize_OverridesUnknownScheme(t *testing.T) { + s := &NodeService{} + n := &model.Node{Name: "n", Address: "example.com", Port: 443, Scheme: "ftp"} + if err := s.normalize(n); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n.Scheme != "https" { + t.Fatalf("Scheme = %q, want https", n.Scheme) + } +} diff --git a/web/websocket/hub_test.go b/web/websocket/hub_test.go new file mode 100644 index 00000000..2a418068 --- /dev/null +++ b/web/websocket/hub_test.go @@ -0,0 +1,257 @@ +package websocket + +import ( + "encoding/json" + "os" + "sync" + "testing" + "time" + + xuilogger "github.com/mhsanaei/3x-ui/v3/logger" + "github.com/op/go-logging" +) + +func TestMain(m *testing.M) { + _ = os.Setenv("XUI_LOG_FOLDER", os.TempDir()) + xuilogger.InitLogger(logging.ERROR) + os.Exit(m.Run()) +} + +func TestNewClient_HasBufferedSendChannel(t *testing.T) { + c := NewClient("client-1") + if c.ID != "client-1" { + t.Fatalf("ID = %q, want client-1", c.ID) + } + if cap(c.Send) != clientSendQueue { + t.Fatalf("Send cap = %d, want %d", cap(c.Send), clientSendQueue) + } +} + +func TestHub_NilReceiver_DoesNotPanic(t *testing.T) { + var h *Hub + if h.GetClientCount() != 0 { + t.Fatal("nil hub GetClientCount should return 0") + } + h.Broadcast(MessageTypeStatus, "anything") + h.Register(NewClient("x")) + h.Unregister(NewClient("x")) + h.Stop() +} + +func TestHub_BroadcastDropsWhenNoClients(t *testing.T) { + h := NewHub() + defer h.Stop() + go h.Run() + + h.Broadcast(MessageTypeStatus, "payload") + + select { + case <-h.broadcast: + t.Fatal("Broadcast should drop when client count is zero") + case <-time.After(50 * time.Millisecond): + } +} + +func TestHub_BroadcastDropsNilPayload(t *testing.T) { + h := NewHub() + defer h.Stop() + go h.Run() + + c := NewClient("c1") + h.Register(c) + waitClientCount(t, h, 1) + + h.Broadcast(MessageTypeStatus, nil) + + select { + case <-c.Send: + t.Fatal("nil payload should be dropped, not delivered") + case <-time.After(50 * time.Millisecond): + } +} + +func TestHub_BroadcastDeliversToClient(t *testing.T) { + h := NewHub() + defer h.Stop() + go h.Run() + + c := NewClient("c1") + h.Register(c) + waitClientCount(t, h, 1) + + h.Broadcast(MessageTypeStatus, map[string]string{"k": "v"}) + + select { + case raw := <-c.Send: + var m Message + if err := json.Unmarshal(raw, &m); err != nil { + t.Fatalf("payload is not valid JSON: %v\n%s", err, raw) + } + if m.Type != MessageTypeStatus { + t.Fatalf("Type = %q, want %q", m.Type, MessageTypeStatus) + } + if m.Time == 0 { + t.Fatal("Time should be set to a non-zero unix-millis value") + } + case <-time.After(500 * time.Millisecond): + t.Fatal("timed out waiting for broadcast to reach client") + } +} + +func TestHub_UnregisterClosesSendAndDecrementsCount(t *testing.T) { + h := NewHub() + defer h.Stop() + go h.Run() + + c := NewClient("c1") + h.Register(c) + waitClientCount(t, h, 1) + + h.Unregister(c) + waitClientCount(t, h, 0) + + select { + case _, ok := <-c.Send: + if ok { + t.Fatal("expected Send channel to be closed after Unregister") + } + case <-time.After(500 * time.Millisecond): + t.Fatal("Send channel was not closed after Unregister") + } +} + +func TestHub_StopClosesAllClients(t *testing.T) { + h := NewHub() + go h.Run() + + c1 := NewClient("c1") + c2 := NewClient("c2") + h.Register(c1) + h.Register(c2) + waitClientCount(t, h, 2) + + h.Stop() + + for _, c := range []*Client{c1, c2} { + select { + case _, ok := <-c.Send: + if ok { + t.Fatalf("client %s Send should be closed after Stop", c.ID) + } + case <-time.After(500 * time.Millisecond): + t.Fatalf("client %s Send not closed after Stop", c.ID) + } + } +} + +func TestHub_ShouldThrottle(t *testing.T) { + h := NewHub() + defer h.Stop() + + if h.shouldThrottle(MessageTypeStatus) { + t.Fatal("non-gated message type should never throttle") + } + if h.shouldThrottle(MessageTypeStatus) { + t.Fatal("non-gated message type should never throttle on second call") + } + + if h.shouldThrottle(MessageTypeTraffic) { + t.Fatal("first call for gated type should not throttle") + } + if !h.shouldThrottle(MessageTypeTraffic) { + t.Fatal("immediate second call for gated type should throttle") + } +} + +func TestHub_ShouldThrottle_DistinctTypesIndependent(t *testing.T) { + h := NewHub() + defer h.Stop() + + if h.shouldThrottle(MessageTypeTraffic) { + t.Fatal("first Traffic call should not throttle") + } + if h.shouldThrottle(MessageTypeInbounds) { + t.Fatal("first Inbounds call should not throttle even after Traffic") + } +} + +func TestTrySend_SucceedsWithRoom(t *testing.T) { + c := &Client{ID: "c", Send: make(chan []byte, 1)} + if !trySend(c, []byte("hi")) { + t.Fatal("trySend should succeed when buffer has room") + } +} + +func TestTrySend_FailsWhenFull(t *testing.T) { + c := &Client{ID: "c", Send: make(chan []byte, 1)} + c.Send <- []byte("first") + if trySend(c, []byte("second")) { + t.Fatal("trySend should fail when buffer is full") + } +} + +func TestTrySend_FailsOnClosedChannel(t *testing.T) { + c := &Client{ID: "c", Send: make(chan []byte, 1)} + close(c.Send) + if trySend(c, []byte("after-close")) { + t.Fatal("trySend should fail (not panic) when channel is closed") + } +} + +func TestHub_FanoutEvictsSlowClient(t *testing.T) { + h := NewHub() + defer h.Stop() + go h.Run() + + slow := &Client{ID: "slow", Send: make(chan []byte, 1)} + slow.Send <- []byte("buffer-already-full") + h.Register(slow) + waitClientCount(t, h, 1) + + h.Broadcast(MessageTypeStatus, "payload") + waitClientCount(t, h, 0) + + select { + case _, ok := <-slow.Send: + if ok { + _, ok = <-slow.Send + if ok { + t.Fatal("slow client Send should eventually be closed by fanout eviction") + } + } + case <-time.After(500 * time.Millisecond): + t.Fatal("slow client Send channel was not closed") + } +} + +func TestHub_ConcurrentRegisterUnregister(t *testing.T) { + h := NewHub() + defer h.Stop() + go h.Run() + + const n = 50 + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + c := NewClient("c") + h.Register(c) + h.Unregister(c) + }(i) + } + wg.Wait() + waitClientCount(t, h, 0) +} + +func waitClientCount(t *testing.T, h *Hub, want int) { + t.Helper() + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if h.GetClientCount() == want { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("client count never reached %d (last seen %d)", want, h.GetClientCount()) +} diff --git a/xray/api_test.go b/xray/api_test.go new file mode 100644 index 00000000..3f018f52 --- /dev/null +++ b/xray/api_test.go @@ -0,0 +1,82 @@ +package xray + +import ( + "strings" + "testing" +) + +func TestGetRequiredUserString_Present(t *testing.T) { + user := map[string]any{"email": "alice@example.com"} + got, err := getRequiredUserString(user, "email") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "alice@example.com" { + t.Fatalf("got %q, want %q", got, "alice@example.com") + } +} + +func TestGetRequiredUserString_Missing(t *testing.T) { + user := map[string]any{} + if _, err := getRequiredUserString(user, "email"); err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestGetRequiredUserString_NilValue(t *testing.T) { + user := map[string]any{"email": nil} + if _, err := getRequiredUserString(user, "email"); err == nil { + t.Fatal("expected error for nil value") + } +} + +func TestGetRequiredUserString_WrongType(t *testing.T) { + user := map[string]any{"email": 42} + _, err := getRequiredUserString(user, "email") + if err == nil { + t.Fatal("expected error for non-string value") + } + if !strings.Contains(err.Error(), "invalid type") { + t.Fatalf("expected %q in error, got: %v", "invalid type", err) + } +} + +func TestGetOptionalUserString_Present(t *testing.T) { + user := map[string]any{"flow": "xtls-rprx-vision"} + got, err := getOptionalUserString(user, "flow") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "xtls-rprx-vision" { + t.Fatalf("got %q, want %q", got, "xtls-rprx-vision") + } +} + +func TestGetOptionalUserString_MissingReturnsEmptyNoError(t *testing.T) { + user := map[string]any{} + got, err := getOptionalUserString(user, "flow") + if err != nil { + t.Fatalf("unexpected error for missing optional field: %v", err) + } + if got != "" { + t.Fatalf("got %q, want empty string", got) + } +} + +func TestGetOptionalUserString_NilReturnsEmptyNoError(t *testing.T) { + user := map[string]any{"flow": nil} + got, err := getOptionalUserString(user, "flow") + if err != nil { + t.Fatalf("unexpected error for nil optional field: %v", err) + } + if got != "" { + t.Fatalf("got %q, want empty string", got) + } +} + +func TestGetOptionalUserString_WrongTypeErrors(t *testing.T) { + user := map[string]any{"flow": []string{"a", "b"}} + if _, err := getOptionalUserString(user, "flow"); err == nil { + t.Fatal("expected error for non-string optional value") + } +} diff --git a/xray/config_test.go b/xray/config_test.go new file mode 100644 index 00000000..bcd97d59 --- /dev/null +++ b/xray/config_test.go @@ -0,0 +1,91 @@ +package xray + +import ( + "testing" + + "github.com/mhsanaei/3x-ui/v3/util/json_util" +) + +func makeConfig() *Config { + return &Config{ + LogConfig: json_util.RawMessage(`{"loglevel":"warning"}`), + RouterConfig: json_util.RawMessage(`{}`), + OutboundConfigs: json_util.RawMessage(`[]`), + Policy: json_util.RawMessage(`{}`), + API: json_util.RawMessage(`{}`), + Stats: json_util.RawMessage(`{}`), + Metrics: json_util.RawMessage(`{}`), + InboundConfigs: []InboundConfig{ + { + Port: 1080, + Protocol: "vless", + Tag: "inbound-1080", + Listen: json_util.RawMessage(`"0.0.0.0"`), + Settings: json_util.RawMessage(`{"clients":[]}`), + }, + }, + } +} + +func TestConfigEquals_IdenticalConfigs(t *testing.T) { + a := makeConfig() + b := makeConfig() + if !a.Equals(b) { + t.Fatal("two identical configs should be Equals") + } +} + +func TestConfigEquals_DifferentInboundCount(t *testing.T) { + a := makeConfig() + b := makeConfig() + b.InboundConfigs = append(b.InboundConfigs, InboundConfig{Port: 2080, Protocol: "vmess", Tag: "inbound-2080"}) + if a.Equals(b) { + t.Fatal("configs with different inbound counts should not be Equals") + } +} + +func TestConfigEquals_DifferentInboundContent(t *testing.T) { + a := makeConfig() + b := makeConfig() + b.InboundConfigs[0].Port = 9999 + if a.Equals(b) { + t.Fatal("config with changed inbound port should not be Equals") + } +} + +func TestConfigEquals_DifferentLogConfig(t *testing.T) { + a := makeConfig() + b := makeConfig() + b.LogConfig = json_util.RawMessage(`{"loglevel":"debug"}`) + if a.Equals(b) { + t.Fatal("config with changed log section should not be Equals") + } +} + +func TestConfigEquals_RawSectionsCompared(t *testing.T) { + fields := []struct { + name string + mutator func(c *Config) + }{ + {"RouterConfig", func(c *Config) { c.RouterConfig = json_util.RawMessage(`{"changed":true}`) }}, + {"DNSConfig", func(c *Config) { c.DNSConfig = json_util.RawMessage(`{"servers":["1.1.1.1"]}`) }}, + {"OutboundConfigs", func(c *Config) { c.OutboundConfigs = json_util.RawMessage(`[{"tag":"x"}]`) }}, + {"Transport", func(c *Config) { c.Transport = json_util.RawMessage(`{"x":1}`) }}, + {"Policy", func(c *Config) { c.Policy = json_util.RawMessage(`{"levels":{}}`) }}, + {"API", func(c *Config) { c.API = json_util.RawMessage(`{"tag":"api"}`) }}, + {"Stats", func(c *Config) { c.Stats = json_util.RawMessage(`{"on":true}`) }}, + {"Reverse", func(c *Config) { c.Reverse = json_util.RawMessage(`{"bridges":[]}`) }}, + {"FakeDNS", func(c *Config) { c.FakeDNS = json_util.RawMessage(`[]`) }}, + {"Metrics", func(c *Config) { c.Metrics = json_util.RawMessage(`{"tag":"m"}`) }}, + } + for _, f := range fields { + t.Run(f.name, func(t *testing.T) { + a := makeConfig() + b := makeConfig() + f.mutator(b) + if a.Equals(b) { + t.Fatalf("mutating %s should break Equals", f.name) + } + }) + } +} diff --git a/xray/inbound_test.go b/xray/inbound_test.go new file mode 100644 index 00000000..28c4d177 --- /dev/null +++ b/xray/inbound_test.go @@ -0,0 +1,52 @@ +package xray + +import ( + "testing" + + "github.com/mhsanaei/3x-ui/v3/util/json_util" +) + +func makeInbound() InboundConfig { + return InboundConfig{ + Listen: json_util.RawMessage(`"0.0.0.0"`), + Port: 1234, + Protocol: "vless", + Settings: json_util.RawMessage(`{"clients":[{"id":"abc"}]}`), + StreamSettings: json_util.RawMessage(`{"network":"tcp"}`), + Tag: "inbound-1234", + Sniffing: json_util.RawMessage(`{"enabled":false}`), + } +} + +func TestInboundConfigEquals_Identical(t *testing.T) { + a := makeInbound() + b := makeInbound() + if !a.Equals(&b) { + t.Fatal("two identical inbounds should be Equals") + } +} + +func TestInboundConfigEquals_MutationsBreakEquality(t *testing.T) { + cases := []struct { + name string + mutator func(c *InboundConfig) + }{ + {"Listen", func(c *InboundConfig) { c.Listen = json_util.RawMessage(`"127.0.0.1"`) }}, + {"Port", func(c *InboundConfig) { c.Port = 9999 }}, + {"Protocol", func(c *InboundConfig) { c.Protocol = "vmess" }}, + {"Settings", func(c *InboundConfig) { c.Settings = json_util.RawMessage(`{"clients":[]}`) }}, + {"StreamSettings", func(c *InboundConfig) { c.StreamSettings = json_util.RawMessage(`{"network":"ws"}`) }}, + {"Tag", func(c *InboundConfig) { c.Tag = "inbound-other" }}, + {"Sniffing", func(c *InboundConfig) { c.Sniffing = json_util.RawMessage(`{"enabled":true}`) }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + a := makeInbound() + b := makeInbound() + tc.mutator(&b) + if a.Equals(&b) { + t.Fatalf("mutating %s should break Equals", tc.name) + } + }) + } +}