mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-05 20:54:14 +00:00
test: cover crypto, random, netsafe, sub helpers, xray equals, websocket hub, node service
Adds ~110 unit tests across previously untested packages. Focus on pure-logic and concurrency surfaces where regressions would silently affect users: - util/crypto, util/random: password hashing round-trip, ss2022 key generation, alphabet/length invariants. - util/netsafe: IsBlockedIP edge cases, NormalizeHost validation, SSRF guard with AllowPrivate context bypass. - util/common, util/json_util: traffic formatter, Combine nil-skip, RawMessage empty-as-null and copy-on-unmarshal. - sub: splitLinkLines, searchKey/searchHost, kcp share fields, finalmask normalization, buildVmessLink round-trip. - xray: Config.Equals and InboundConfig.Equals field-by-field, getRequiredUserString/getOptionalUserString type checks. - web/websocket: hub registration, throttling, slow-client eviction, nil-receiver safety, concurrent register/unregister. - web/service: NodeService.normalize validation, normalizeBasePath, HeartbeatPatch.ToUI mapping. - web/job: atomicBool concurrent set/takeAndReset semantics. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
93706702c5
commit
106adca414
14 changed files with 1644 additions and 0 deletions
40
sub/links_test.go
Normal file
40
sub/links_test.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package sub
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitLinkLines(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
want []string
|
||||
}{
|
||||
{"single_line", "vless://abc", []string{"vless://abc"}},
|
||||
{"two_lines", "vless://abc\nvmess://xyz", []string{"vless://abc", "vmess://xyz"}},
|
||||
{"trims_each_line", " vless://abc \n\tvmess://xyz\t", []string{"vless://abc", "vmess://xyz"}},
|
||||
{"skips_blank_lines", "vless://abc\n\n\nvmess://xyz\n", []string{"vless://abc", "vmess://xyz"}},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
got := splitLinkLines(c.in)
|
||||
if !reflect.DeepEqual(got, c.want) {
|
||||
t.Fatalf("splitLinkLines(%q) = %#v, want %#v", c.in, got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitLinkLines_EmptyInputIsNil(t *testing.T) {
|
||||
if got := splitLinkLines(""); got != nil {
|
||||
t.Fatalf("splitLinkLines(\"\") = %#v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitLinkLines_WhitespaceOnlyHasNoEntries(t *testing.T) {
|
||||
got := splitLinkLines(" \n\t \n")
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("splitLinkLines(whitespace) = %#v, want empty slice", got)
|
||||
}
|
||||
}
|
||||
480
sub/subService_test.go
Normal file
480
sub/subService_test.go
Normal file
|
|
@ -0,0 +1,480 @@
|
|||
package sub
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v3/database/model"
|
||||
)
|
||||
|
||||
func TestFindClientIndex(t *testing.T) {
|
||||
clients := []model.Client{
|
||||
{Email: "a@example.com"},
|
||||
{Email: "b@example.com"},
|
||||
{Email: "c@example.com"},
|
||||
}
|
||||
if got := findClientIndex(clients, "b@example.com"); got != 1 {
|
||||
t.Fatalf("findClientIndex middle = %d, want 1", got)
|
||||
}
|
||||
if got := findClientIndex(clients, "a@example.com"); got != 0 {
|
||||
t.Fatalf("findClientIndex first = %d, want 0", got)
|
||||
}
|
||||
if got := findClientIndex(clients, "missing@example.com"); got != -1 {
|
||||
t.Fatalf("findClientIndex missing = %d, want -1", got)
|
||||
}
|
||||
if got := findClientIndex(nil, "x"); got != -1 {
|
||||
t.Fatalf("findClientIndex on nil slice = %d, want -1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalStreamSettings(t *testing.T) {
|
||||
got := unmarshalStreamSettings(`{"network":"ws","wsSettings":{"path":"/api"}}`)
|
||||
if got["network"] != "ws" {
|
||||
t.Fatalf("network = %v, want ws", got["network"])
|
||||
}
|
||||
ws, ok := got["wsSettings"].(map[string]any)
|
||||
if !ok || ws["path"] != "/api" {
|
||||
t.Fatalf("wsSettings = %v, want map with path=/api", got["wsSettings"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalStreamSettings_InvalidJSON(t *testing.T) {
|
||||
if got := unmarshalStreamSettings("not json"); got != nil {
|
||||
t.Fatalf("invalid JSON should produce nil map, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchHost_StringValue(t *testing.T) {
|
||||
headers := map[string]any{"Host": "example.com"}
|
||||
if got := searchHost(headers); got != "example.com" {
|
||||
t.Fatalf("searchHost = %q, want example.com", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchHost_CaseInsensitiveKey(t *testing.T) {
|
||||
headers := map[string]any{"host": "example.com"}
|
||||
if got := searchHost(headers); got != "example.com" {
|
||||
t.Fatalf("searchHost = %q, want example.com", got)
|
||||
}
|
||||
headers2 := map[string]any{"HOST": "example.com"}
|
||||
if got := searchHost(headers2); got != "example.com" {
|
||||
t.Fatalf("searchHost uppercase = %q, want example.com", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchHost_ArrayValue(t *testing.T) {
|
||||
headers := map[string]any{"Host": []any{"first.example.com", "second.example.com"}}
|
||||
if got := searchHost(headers); got != "first.example.com" {
|
||||
t.Fatalf("searchHost array = %q, want first.example.com", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchHost_EmptyArray(t *testing.T) {
|
||||
headers := map[string]any{"Host": []any{}}
|
||||
if got := searchHost(headers); got != "" {
|
||||
t.Fatalf("searchHost empty array = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchHost_NoHostKey(t *testing.T) {
|
||||
headers := map[string]any{"X-Other": "value"}
|
||||
if got := searchHost(headers); got != "" {
|
||||
t.Fatalf("searchHost no host = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchHost_NotAMap(t *testing.T) {
|
||||
if got := searchHost("not a map"); got != "" {
|
||||
t.Fatalf("searchHost non-map = %q, want empty", got)
|
||||
}
|
||||
if got := searchHost(nil); got != "" {
|
||||
t.Fatalf("searchHost nil = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchKey_FoundAtTopLevel(t *testing.T) {
|
||||
data := map[string]any{"foo": 42, "bar": "x"}
|
||||
got, ok := searchKey(data, "foo")
|
||||
if !ok {
|
||||
t.Fatal("expected to find foo")
|
||||
}
|
||||
if got != 42 {
|
||||
t.Fatalf("got %v, want 42", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchKey_FoundInNested(t *testing.T) {
|
||||
data := map[string]any{
|
||||
"outer": map[string]any{
|
||||
"inner": map[string]any{
|
||||
"target": "hit",
|
||||
},
|
||||
},
|
||||
}
|
||||
got, ok := searchKey(data, "target")
|
||||
if !ok {
|
||||
t.Fatal("expected to find target in nested map")
|
||||
}
|
||||
if got != "hit" {
|
||||
t.Fatalf("got %v, want hit", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchKey_FoundInsideArray(t *testing.T) {
|
||||
data := map[string]any{
|
||||
"list": []any{
|
||||
map[string]any{"other": 1},
|
||||
map[string]any{"needle": "found"},
|
||||
},
|
||||
}
|
||||
got, ok := searchKey(data, "needle")
|
||||
if !ok {
|
||||
t.Fatal("expected to find needle in array element")
|
||||
}
|
||||
if got != "found" {
|
||||
t.Fatalf("got %v, want found", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchKey_NotFound(t *testing.T) {
|
||||
data := map[string]any{"foo": "bar"}
|
||||
if _, ok := searchKey(data, "missing"); ok {
|
||||
t.Fatal("expected ok=false for missing key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchKey_OnScalar(t *testing.T) {
|
||||
if _, ok := searchKey(42, "anything"); ok {
|
||||
t.Fatal("expected ok=false searching on a scalar")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneStringMap(t *testing.T) {
|
||||
src := map[string]string{"a": "1", "b": "2"}
|
||||
dst := cloneStringMap(src)
|
||||
if len(dst) != len(src) {
|
||||
t.Fatalf("clone length = %d, want %d", len(dst), len(src))
|
||||
}
|
||||
for k, v := range src {
|
||||
if dst[k] != v {
|
||||
t.Fatalf("clone[%q] = %q, want %q", k, dst[k], v)
|
||||
}
|
||||
}
|
||||
dst["a"] = "changed"
|
||||
if src["a"] == "changed" {
|
||||
t.Fatal("modifying clone leaked into source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneStringMap_Empty(t *testing.T) {
|
||||
dst := cloneStringMap(map[string]string{})
|
||||
if dst == nil {
|
||||
t.Fatal("clone of empty map should not be nil")
|
||||
}
|
||||
if len(dst) != 0 {
|
||||
t.Fatalf("clone of empty map should be empty, got %v", dst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHostFromXFH_HostOnly(t *testing.T) {
|
||||
got, err := getHostFromXFH("example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "example.com" {
|
||||
t.Fatalf("got %q, want example.com", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHostFromXFH_HostWithPort(t *testing.T) {
|
||||
got, err := getHostFromXFH("example.com:8443")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "example.com" {
|
||||
t.Fatalf("got %q, want example.com", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHostFromXFH_IPv6WithPort(t *testing.T) {
|
||||
got, err := getHostFromXFH("[2606:4700::1111]:443")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "2606:4700::1111" {
|
||||
t.Fatalf("got %q, want 2606:4700::1111", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHostFromXFH_BadHostPort(t *testing.T) {
|
||||
if _, err := getHostFromXFH("example.com:8443:9999"); err == nil {
|
||||
t.Fatal("expected error for malformed host:port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPositiveInt(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in any
|
||||
wantVal int
|
||||
wantOk bool
|
||||
}{
|
||||
{"int_positive", int(5), 5, true},
|
||||
{"int_zero", int(0), 0, false},
|
||||
{"int_negative", int(-3), -3, false},
|
||||
{"int32_positive", int32(7), 7, true},
|
||||
{"int64_positive", int64(99), 99, true},
|
||||
{"float64_positive", float64(12), 12, true},
|
||||
{"float64_zero", float64(0.0), 0, false},
|
||||
{"float64_negative", float64(-1.5), -1, false},
|
||||
{"float32_positive", float32(3), 3, true},
|
||||
{"string", "not a number", 0, false},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
gotVal, gotOk := readPositiveInt(c.in)
|
||||
if gotVal != c.wantVal || gotOk != c.wantOk {
|
||||
t.Fatalf("readPositiveInt(%v) = (%d, %v), want (%d, %v)", c.in, gotVal, gotOk, c.wantVal, c.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetStringParam(t *testing.T) {
|
||||
p := map[string]string{"existing": "value"}
|
||||
|
||||
setStringParam(p, "new", "hello")
|
||||
if p["new"] != "hello" {
|
||||
t.Fatalf("missing key after set: %v", p)
|
||||
}
|
||||
|
||||
setStringParam(p, "existing", "")
|
||||
if _, ok := p["existing"]; ok {
|
||||
t.Fatalf("empty value should delete the key, got %v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetIntParam(t *testing.T) {
|
||||
p := map[string]string{"existing": "10"}
|
||||
|
||||
setIntParam(p, "n", 42)
|
||||
if p["n"] != "42" {
|
||||
t.Fatalf("set positive int: got %v", p)
|
||||
}
|
||||
|
||||
setIntParam(p, "existing", 0)
|
||||
if _, ok := p["existing"]; ok {
|
||||
t.Fatalf("zero value should delete the key, got %v", p)
|
||||
}
|
||||
|
||||
p["other"] = "5"
|
||||
setIntParam(p, "other", -1)
|
||||
if _, ok := p["other"]; ok {
|
||||
t.Fatalf("negative value should delete the key, got %v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetStringField(t *testing.T) {
|
||||
f := map[string]any{"existing": "value"}
|
||||
|
||||
setStringField(f, "new", "hello")
|
||||
if f["new"] != "hello" {
|
||||
t.Fatalf("missing key after set: %v", f)
|
||||
}
|
||||
|
||||
setStringField(f, "existing", "")
|
||||
if _, ok := f["existing"]; ok {
|
||||
t.Fatalf("empty value should delete the key, got %v", f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetIntField(t *testing.T) {
|
||||
f := map[string]any{"existing": 10}
|
||||
|
||||
setIntField(f, "n", 7)
|
||||
if f["n"] != 7 {
|
||||
t.Fatalf("set positive int: got %v", f)
|
||||
}
|
||||
|
||||
setIntField(f, "existing", 0)
|
||||
if _, ok := f["existing"]; ok {
|
||||
t.Fatalf("zero value should delete the key, got %v", f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildVmessLink(t *testing.T) {
|
||||
obj := map[string]any{
|
||||
"v": "2",
|
||||
"ps": "remark",
|
||||
"add": "example.com",
|
||||
"port": 443,
|
||||
"net": "tcp",
|
||||
}
|
||||
link := buildVmessLink(obj)
|
||||
if !strings.HasPrefix(link, "vmess://") {
|
||||
t.Fatalf("missing vmess:// prefix: %q", link)
|
||||
}
|
||||
payload := strings.TrimPrefix(link, "vmess://")
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("base64 decode failed: %v", err)
|
||||
}
|
||||
var roundTrip map[string]any
|
||||
if err := json.Unmarshal(decoded, &roundTrip); err != nil {
|
||||
t.Fatalf("decoded payload is not JSON: %v\n%s", err, decoded)
|
||||
}
|
||||
if roundTrip["add"] != "example.com" {
|
||||
t.Fatalf("round-trip add = %v, want example.com", roundTrip["add"])
|
||||
}
|
||||
if roundTrip["ps"] != "remark" {
|
||||
t.Fatalf("round-trip ps = %v, want remark", roundTrip["ps"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneVmessShareObj_CopiesEverythingByDefault(t *testing.T) {
|
||||
base := map[string]any{
|
||||
"v": "2",
|
||||
"sni": "example.com",
|
||||
"alpn": "h2",
|
||||
"fp": "chrome",
|
||||
"net": "tcp",
|
||||
}
|
||||
out := cloneVmessShareObj(base, "tls")
|
||||
for _, key := range []string{"sni", "alpn", "fp", "net", "v"} {
|
||||
if _, ok := out[key]; !ok {
|
||||
t.Fatalf("expected key %q to be preserved when security=tls, got %v", key, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneVmessShareObj_NoneStripsTLSOnlyKeys(t *testing.T) {
|
||||
base := map[string]any{
|
||||
"v": "2",
|
||||
"sni": "example.com",
|
||||
"alpn": "h2",
|
||||
"fp": "chrome",
|
||||
"net": "tcp",
|
||||
}
|
||||
out := cloneVmessShareObj(base, "none")
|
||||
for _, key := range []string{"sni", "alpn", "fp"} {
|
||||
if _, ok := out[key]; ok {
|
||||
t.Fatalf("security=none should strip %q, got %v", key, out)
|
||||
}
|
||||
}
|
||||
if out["v"] != "2" || out["net"] != "tcp" {
|
||||
t.Fatalf("non-TLS keys should remain, got %v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractKcpShareFields_Defaults(t *testing.T) {
|
||||
stream := map[string]any{}
|
||||
got := extractKcpShareFields(stream)
|
||||
if got.headerType != "none" {
|
||||
t.Fatalf("default headerType = %q, want none", got.headerType)
|
||||
}
|
||||
if got.seed != "" || got.mtu != 0 || got.tti != 0 {
|
||||
t.Fatalf("default kcpShareFields should be zero except headerType, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractKcpShareFields_ReadsAllFields(t *testing.T) {
|
||||
stream := map[string]any{
|
||||
"kcpSettings": map[string]any{
|
||||
"header": map[string]any{"type": "wechat-video"},
|
||||
"seed": "secret-seed",
|
||||
"mtu": float64(1350),
|
||||
"tti": float64(50),
|
||||
},
|
||||
}
|
||||
got := extractKcpShareFields(stream)
|
||||
if got.headerType != "wechat-video" {
|
||||
t.Fatalf("headerType = %q, want wechat-video", got.headerType)
|
||||
}
|
||||
if got.seed != "secret-seed" {
|
||||
t.Fatalf("seed = %q, want secret-seed", got.seed)
|
||||
}
|
||||
if got.mtu != 1350 {
|
||||
t.Fatalf("mtu = %d, want 1350", got.mtu)
|
||||
}
|
||||
if got.tti != 50 {
|
||||
t.Fatalf("tti = %d, want 50", got.tti)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKcpShareFields_ApplyToParams(t *testing.T) {
|
||||
params := map[string]string{}
|
||||
kcpShareFields{headerType: "wechat-video", seed: "s", mtu: 1350, tti: 50}.applyToParams(params)
|
||||
if params["headerType"] != "wechat-video" {
|
||||
t.Fatalf("headerType param = %q", params["headerType"])
|
||||
}
|
||||
if params["seed"] != "s" {
|
||||
t.Fatalf("seed param = %q", params["seed"])
|
||||
}
|
||||
if params["mtu"] != "1350" {
|
||||
t.Fatalf("mtu param = %q", params["mtu"])
|
||||
}
|
||||
if params["tti"] != "50" {
|
||||
t.Fatalf("tti param = %q", params["tti"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestKcpShareFields_ApplyToParams_NoneHeaderNotAdded(t *testing.T) {
|
||||
params := map[string]string{}
|
||||
kcpShareFields{headerType: "none"}.applyToParams(params)
|
||||
if _, ok := params["headerType"]; ok {
|
||||
t.Fatalf("headerType=none should not be added, got %v", params)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalFinalMask_EmptyReturnsFalse(t *testing.T) {
|
||||
if _, ok := marshalFinalMask(map[string]any{}); ok {
|
||||
t.Fatal("expected ok=false for empty finalmask")
|
||||
}
|
||||
if _, ok := marshalFinalMask(nil); ok {
|
||||
t.Fatal("expected ok=false for nil finalmask")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalFinalMask_WithContent(t *testing.T) {
|
||||
fm := map[string]any{
|
||||
"tcp": []any{
|
||||
map[string]any{"type": "fragment"},
|
||||
},
|
||||
}
|
||||
out, ok := marshalFinalMask(fm)
|
||||
if !ok {
|
||||
t.Fatal("expected ok=true for finalmask with valid tcp mask")
|
||||
}
|
||||
if !strings.Contains(out, `"tcp"`) {
|
||||
t.Fatalf("marshaled finalmask missing tcp key: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "fragment") {
|
||||
t.Fatalf("marshaled finalmask missing mask type: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalFinalMask_UnknownTypeIsDropped(t *testing.T) {
|
||||
fm := map[string]any{
|
||||
"tcp": []any{
|
||||
map[string]any{"type": "not-a-real-mask"},
|
||||
},
|
||||
}
|
||||
if _, ok := marshalFinalMask(fm); ok {
|
||||
t.Fatal("unknown mask types should be dropped, leaving nothing to marshal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasFinalMaskContent(t *testing.T) {
|
||||
if hasFinalMaskContent(nil) {
|
||||
t.Fatal("nil should not count as content")
|
||||
}
|
||||
if hasFinalMaskContent(map[string]any{}) {
|
||||
t.Fatal("empty map should not count as content")
|
||||
}
|
||||
if !hasFinalMaskContent(map[string]any{"x": 1}) {
|
||||
t.Fatal("non-empty map should count as content")
|
||||
}
|
||||
}
|
||||
28
util/common/format_test.go
Normal file
28
util/common/format_test.go
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
package common
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFormatTraffic(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
bytes int64
|
||||
want string
|
||||
}{
|
||||
{"zero", 0, "0.00B"},
|
||||
{"under_one_kb", 512, "512.00B"},
|
||||
{"exactly_one_kb", 1024, "1.00KB"},
|
||||
{"one_and_a_half_kb", 1536, "1.50KB"},
|
||||
{"one_mb", 1024 * 1024, "1.00MB"},
|
||||
{"one_gb", 1024 * 1024 * 1024, "1.00GB"},
|
||||
{"one_tb", 1024 * 1024 * 1024 * 1024, "1.00TB"},
|
||||
{"one_pb", 1024 * 1024 * 1024 * 1024 * 1024, "1.00PB"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
got := FormatTraffic(c.bytes)
|
||||
if got != c.want {
|
||||
t.Fatalf("FormatTraffic(%d) = %q, want %q", c.bytes, got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
44
util/common/multi_error_test.go
Normal file
44
util/common/multi_error_test.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCombine_AllNilReturnsNil(t *testing.T) {
|
||||
if err := Combine(); err != nil {
|
||||
t.Fatalf("Combine() with no args = %v, want nil", err)
|
||||
}
|
||||
if err := Combine(nil, nil, nil); err != nil {
|
||||
t.Fatalf("Combine(nil, nil, nil) = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombine_SkipsNilErrors(t *testing.T) {
|
||||
e1 := errors.New("boom one")
|
||||
e2 := errors.New("boom two")
|
||||
|
||||
err := Combine(nil, e1, nil, e2, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected non-nil combined error")
|
||||
}
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "boom one") || !strings.Contains(msg, "boom two") {
|
||||
t.Fatalf("combined error %q does not contain both underlying messages", msg)
|
||||
}
|
||||
if !strings.HasPrefix(msg, "multierr: ") {
|
||||
t.Fatalf("combined error %q missing %q prefix", msg, "multierr: ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombine_SingleErrorStillWrapped(t *testing.T) {
|
||||
e := errors.New("only one")
|
||||
err := Combine(e)
|
||||
if err == nil {
|
||||
t.Fatal("expected non-nil error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "only one") {
|
||||
t.Fatalf("combined error %q missing underlying message", err.Error())
|
||||
}
|
||||
}
|
||||
69
util/crypto/crypto_test.go
Normal file
69
util/crypto/crypto_test.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHashPasswordAsBcrypt_RoundTrip(t *testing.T) {
|
||||
password := "correct horse battery staple"
|
||||
|
||||
hash, err := HashPasswordAsBcrypt(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPasswordAsBcrypt returned error: %v", err)
|
||||
}
|
||||
if hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if hash == password {
|
||||
t.Fatal("hash must not equal the plaintext password")
|
||||
}
|
||||
if !strings.HasPrefix(hash, "$2") {
|
||||
t.Fatalf("expected bcrypt prefix $2..., got %q", hash[:min(4, len(hash))])
|
||||
}
|
||||
|
||||
if !CheckPasswordHash(hash, password) {
|
||||
t.Fatal("CheckPasswordHash returned false for the matching password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPasswordHash_WrongPassword(t *testing.T) {
|
||||
hash, err := HashPasswordAsBcrypt("right-password")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPasswordAsBcrypt returned error: %v", err)
|
||||
}
|
||||
|
||||
if CheckPasswordHash(hash, "wrong-password") {
|
||||
t.Fatal("CheckPasswordHash returned true for a wrong password")
|
||||
}
|
||||
if CheckPasswordHash(hash, "") {
|
||||
t.Fatal("CheckPasswordHash returned true for an empty password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPasswordHash_InvalidHash(t *testing.T) {
|
||||
if CheckPasswordHash("", "anything") {
|
||||
t.Fatal("empty hash must not validate")
|
||||
}
|
||||
if CheckPasswordHash("not-a-bcrypt-hash", "anything") {
|
||||
t.Fatal("malformed hash must not validate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPasswordAsBcrypt_DifferentHashesForSamePassword(t *testing.T) {
|
||||
password := "same-password"
|
||||
h1, err := HashPasswordAsBcrypt(password)
|
||||
if err != nil {
|
||||
t.Fatalf("first hash failed: %v", err)
|
||||
}
|
||||
h2, err := HashPasswordAsBcrypt(password)
|
||||
if err != nil {
|
||||
t.Fatalf("second hash failed: %v", err)
|
||||
}
|
||||
if h1 == h2 {
|
||||
t.Fatal("expected bcrypt to produce different hashes (random salt) for the same password")
|
||||
}
|
||||
if !CheckPasswordHash(h1, password) || !CheckPasswordHash(h2, password) {
|
||||
t.Fatal("both hashes should still validate the original password")
|
||||
}
|
||||
}
|
||||
76
util/json_util/json_test.go
Normal file
76
util/json_util/json_test.go
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
package json_util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRawMessage_MarshalEmptyIsNull(t *testing.T) {
|
||||
var m RawMessage
|
||||
out, err := m.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalJSON on empty returned error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(out, []byte("null")) {
|
||||
t.Fatalf("empty RawMessage marshaled to %q, want %q", out, "null")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawMessage_MarshalPassthrough(t *testing.T) {
|
||||
payload := []byte(`{"a":1}`)
|
||||
m := RawMessage(payload)
|
||||
out, err := m.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalJSON returned error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(out, payload) {
|
||||
t.Fatalf("MarshalJSON = %q, want %q", out, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawMessage_UnmarshalCopiesData(t *testing.T) {
|
||||
var m RawMessage
|
||||
src := []byte(`{"k":"v"}`)
|
||||
if err := m.UnmarshalJSON(src); err != nil {
|
||||
t.Fatalf("UnmarshalJSON returned error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(m, src) {
|
||||
t.Fatalf("UnmarshalJSON stored %q, want %q", []byte(m), src)
|
||||
}
|
||||
|
||||
src[0] = 'X'
|
||||
if m[0] == 'X' {
|
||||
t.Fatal("UnmarshalJSON kept a reference to the caller's buffer; expected a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawMessage_UnmarshalNilReceiverErrors(t *testing.T) {
|
||||
var m *RawMessage
|
||||
if err := m.UnmarshalJSON([]byte("123")); err == nil {
|
||||
t.Fatal("expected error for nil receiver")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawMessage_RoundTripInsideStruct(t *testing.T) {
|
||||
type wrapper struct {
|
||||
Body RawMessage `json:"body"`
|
||||
}
|
||||
in := wrapper{Body: RawMessage(`{"x":42}`)}
|
||||
encoded, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal returned error: %v", err)
|
||||
}
|
||||
want := `{"body":{"x":42}}`
|
||||
if string(encoded) != want {
|
||||
t.Fatalf("Marshal = %s, want %s", encoded, want)
|
||||
}
|
||||
|
||||
var out wrapper
|
||||
if err := json.Unmarshal(encoded, &out); err != nil {
|
||||
t.Fatalf("json.Unmarshal returned error: %v", err)
|
||||
}
|
||||
if string(out.Body) != `{"x":42}` {
|
||||
t.Fatalf("round-trip Body = %s, want %s", out.Body, `{"x":42}`)
|
||||
}
|
||||
}
|
||||
127
util/netsafe/netsafe_test.go
Normal file
127
util/netsafe/netsafe_test.go
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
package netsafe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsBlockedIP(t *testing.T) {
|
||||
cases := []struct {
|
||||
ip string
|
||||
want bool
|
||||
}{
|
||||
{"127.0.0.1", true},
|
||||
{"::1", true},
|
||||
{"10.0.0.5", true},
|
||||
{"172.16.0.1", true},
|
||||
{"192.168.1.1", true},
|
||||
{"169.254.0.1", true},
|
||||
{"0.0.0.0", true},
|
||||
{"::", true},
|
||||
{"8.8.8.8", false},
|
||||
{"1.1.1.1", false},
|
||||
{"2606:4700:4700::1111", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.ip, func(t *testing.T) {
|
||||
ip := net.ParseIP(c.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("could not parse %q", c.ip)
|
||||
}
|
||||
if got := IsBlockedIP(ip); got != c.want {
|
||||
t.Fatalf("IsBlockedIP(%s) = %v, want %v", c.ip, got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllowPrivateFromContext_Default(t *testing.T) {
|
||||
if AllowPrivateFromContext(context.Background()) {
|
||||
t.Fatal("default context should report AllowPrivate=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllowPrivateFromContext_RoundTrip(t *testing.T) {
|
||||
ctx := ContextWithAllowPrivate(context.Background(), true)
|
||||
if !AllowPrivateFromContext(ctx) {
|
||||
t.Fatal("expected AllowPrivate=true after ContextWithAllowPrivate(true)")
|
||||
}
|
||||
ctx = ContextWithAllowPrivate(ctx, false)
|
||||
if AllowPrivateFromContext(ctx) {
|
||||
t.Fatal("expected AllowPrivate=false after overriding with false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeHost_Valid(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"example.com", "example.com"},
|
||||
{" example.com ", "example.com"},
|
||||
{"a.b.c.example.com", "a.b.c.example.com"},
|
||||
{"10.0.0.1", "10.0.0.1"},
|
||||
{"[2606:4700:4700::1111]", "2606:4700:4700::1111"},
|
||||
{"2606:4700:4700::1111", "2606:4700:4700::1111"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.in, func(t *testing.T) {
|
||||
got, err := NormalizeHost(c.in)
|
||||
if err != nil {
|
||||
t.Fatalf("NormalizeHost(%q) returned error: %v", c.in, err)
|
||||
}
|
||||
if !strings.EqualFold(got, c.want) {
|
||||
t.Fatalf("NormalizeHost(%q) = %q, want %q", c.in, got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeHost_Invalid(t *testing.T) {
|
||||
cases := []string{
|
||||
"",
|
||||
" ",
|
||||
"-leading-dash.com",
|
||||
"trailing-dash-.com",
|
||||
"bad host with spaces",
|
||||
"under_score.example.com",
|
||||
"exa$mple.com",
|
||||
strings.Repeat("a", 254),
|
||||
}
|
||||
for _, in := range cases {
|
||||
t.Run(in, func(t *testing.T) {
|
||||
if _, err := NormalizeHost(in); err == nil {
|
||||
t.Fatalf("NormalizeHost(%q) expected error, got nil", in)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSRFGuardedDialContext_BlocksLiteralPrivateIP(t *testing.T) {
|
||||
_, err := SSRFGuardedDialContext(context.Background(), "tcp", "127.0.0.1:1")
|
||||
if err == nil {
|
||||
t.Fatal("expected dial to 127.0.0.1 to be blocked")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "blocked") {
|
||||
t.Fatalf("expected 'blocked' in error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSRFGuardedDialContext_AllowPrivateBypassesGuard(t *testing.T) {
|
||||
ctx := ContextWithAllowPrivate(context.Background(), true)
|
||||
_, err := SSRFGuardedDialContext(ctx, "tcp", "127.0.0.1:1")
|
||||
if err == nil {
|
||||
t.Fatal("dial to a closed loopback port should still fail at the connect step")
|
||||
}
|
||||
if strings.Contains(err.Error(), "blocked private/internal address") {
|
||||
t.Fatalf("expected guard to be bypassed when AllowPrivate=true, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSRFGuardedDialContext_BadAddress(t *testing.T) {
|
||||
if _, err := SSRFGuardedDialContext(context.Background(), "tcp", "no-port"); err == nil {
|
||||
t.Fatal("expected error for address without port")
|
||||
}
|
||||
}
|
||||
63
util/random/random_test.go
Normal file
63
util/random/random_test.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
package random
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSeq_LengthAndAlphabet(t *testing.T) {
|
||||
for _, n := range []int{0, 1, 8, 64, 256} {
|
||||
s := Seq(n)
|
||||
if len(s) != n {
|
||||
t.Fatalf("Seq(%d) returned length %d", n, len(s))
|
||||
}
|
||||
for i, r := range s {
|
||||
isDigit := r >= '0' && r <= '9'
|
||||
isLower := r >= 'a' && r <= 'z'
|
||||
isUpper := r >= 'A' && r <= 'Z'
|
||||
if !(isDigit || isLower || isUpper) {
|
||||
t.Fatalf("Seq(%d) byte %d = %q is not alphanumeric", n, i, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeq_NotConstant(t *testing.T) {
|
||||
a := Seq(32)
|
||||
b := Seq(32)
|
||||
if a == b {
|
||||
t.Fatalf("two consecutive Seq(32) calls produced identical output: %q", a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNum_InRange(t *testing.T) {
|
||||
for _, upper := range []int{1, 2, 10, 1000} {
|
||||
for i := 0; i < 200; i++ {
|
||||
v := Num(upper)
|
||||
if v < 0 || v >= upper {
|
||||
t.Fatalf("Num(%d) returned %d, out of [0, %d)", upper, v, upper)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase64Bytes_DecodesToRequestedSize(t *testing.T) {
|
||||
for _, n := range []int{1, 16, 32, 64} {
|
||||
out := Base64Bytes(n)
|
||||
decoded, err := base64.StdEncoding.DecodeString(out)
|
||||
if err != nil {
|
||||
t.Fatalf("Base64Bytes(%d) produced invalid base64 %q: %v", n, out, err)
|
||||
}
|
||||
if len(decoded) != n {
|
||||
t.Fatalf("Base64Bytes(%d) decoded to %d bytes", n, len(decoded))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase64Bytes_Random(t *testing.T) {
|
||||
a := Base64Bytes(32)
|
||||
b := Base64Bytes(32)
|
||||
if a == b {
|
||||
t.Fatalf("two consecutive Base64Bytes(32) calls produced identical output: %q", a)
|
||||
}
|
||||
}
|
||||
73
web/job/node_traffic_sync_job_test.go
Normal file
73
web/job/node_traffic_sync_job_test.go
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
package job
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAtomicBool_DefaultIsFalse(t *testing.T) {
|
||||
var a atomicBool
|
||||
if a.takeAndReset() {
|
||||
t.Fatal("default atomicBool should report false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicBool_SetThenTakeReturnsTrueOnce(t *testing.T) {
|
||||
var a atomicBool
|
||||
a.set()
|
||||
if !a.takeAndReset() {
|
||||
t.Fatal("takeAndReset after set should return true")
|
||||
}
|
||||
if a.takeAndReset() {
|
||||
t.Fatal("second takeAndReset should return false (state was reset)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicBool_SetIsIdempotent(t *testing.T) {
|
||||
var a atomicBool
|
||||
a.set()
|
||||
a.set()
|
||||
a.set()
|
||||
if !a.takeAndReset() {
|
||||
t.Fatal("repeated set should still leave the flag true")
|
||||
}
|
||||
if a.takeAndReset() {
|
||||
t.Fatal("flag should be cleared after the first take")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicBool_ConcurrentSettersExactlyOneTakeWins(t *testing.T) {
|
||||
var a atomicBool
|
||||
const setters = 100
|
||||
const readers = 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < setters; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
a.set()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
trueCount := 0
|
||||
var rwg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
for i := 0; i < readers; i++ {
|
||||
rwg.Add(1)
|
||||
go func() {
|
||||
defer rwg.Done()
|
||||
if a.takeAndReset() {
|
||||
mu.Lock()
|
||||
trueCount++
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
rwg.Wait()
|
||||
|
||||
if trueCount != 1 {
|
||||
t.Fatalf("expected exactly one reader to observe true, got %d", trueCount)
|
||||
}
|
||||
}
|
||||
162
web/service/node_test.go
Normal file
162
web/service/node_test.go
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v3/database/model"
|
||||
)
|
||||
|
||||
func TestNormalizeBasePath(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", "/"},
|
||||
{" ", "/"},
|
||||
{"/", "/"},
|
||||
{"/panel", "/panel/"},
|
||||
{"panel", "/panel/"},
|
||||
{"panel/", "/panel/"},
|
||||
{"/panel/", "/panel/"},
|
||||
{" /panel ", "/panel/"},
|
||||
{"/a/b/c", "/a/b/c/"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.in, func(t *testing.T) {
|
||||
got := normalizeBasePath(c.in)
|
||||
if got != c.want {
|
||||
t.Fatalf("normalizeBasePath(%q) = %q, want %q", c.in, got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeMetricKey(t *testing.T) {
|
||||
cases := []struct {
|
||||
id int
|
||||
metric string
|
||||
want string
|
||||
}{
|
||||
{1, "cpu", "node:1:cpu"},
|
||||
{42, "mem", "node:42:mem"},
|
||||
{0, "anything", "node:0:anything"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := nodeMetricKey(c.id, c.metric)
|
||||
if got != c.want {
|
||||
t.Fatalf("nodeMetricKey(%d, %q) = %q, want %q", c.id, c.metric, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatPatch_ToUI_OnlineCopiesFields(t *testing.T) {
|
||||
p := HeartbeatPatch{
|
||||
Status: "ignored-source",
|
||||
LatencyMs: 42,
|
||||
XrayVersion: "1.8.4",
|
||||
PanelVersion: "3.0.0",
|
||||
CpuPct: 12.5,
|
||||
MemPct: 33.3,
|
||||
UptimeSecs: 12345,
|
||||
LastError: "",
|
||||
}
|
||||
ui := p.ToUI(true)
|
||||
if ui.Status != "online" {
|
||||
t.Fatalf("Status = %q, want online", ui.Status)
|
||||
}
|
||||
if ui.LatencyMs != 42 || ui.XrayVersion != "1.8.4" || ui.PanelVersion != "3.0.0" {
|
||||
t.Fatalf("scalar copy mismatch: %+v", ui)
|
||||
}
|
||||
if ui.CpuPct != 12.5 || ui.MemPct != 33.3 || ui.UptimeSecs != 12345 {
|
||||
t.Fatalf("metric copy mismatch: %+v", ui)
|
||||
}
|
||||
if ui.Error != "" {
|
||||
t.Fatalf("Error = %q, want empty", ui.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatPatch_ToUI_OfflinePreservesError(t *testing.T) {
|
||||
p := HeartbeatPatch{LastError: "connection refused"}
|
||||
ui := p.ToUI(false)
|
||||
if ui.Status != "offline" {
|
||||
t.Fatalf("Status = %q, want offline", ui.Status)
|
||||
}
|
||||
if ui.Error != "connection refused" {
|
||||
t.Fatalf("Error = %q, want %q", ui.Error, "connection refused")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeService_Normalize_Valid(t *testing.T) {
|
||||
s := &NodeService{}
|
||||
n := &model.Node{
|
||||
Name: " primary ",
|
||||
ApiToken: " abc ",
|
||||
Address: "example.com",
|
||||
Port: 8443,
|
||||
Scheme: "",
|
||||
BasePath: "panel",
|
||||
}
|
||||
if err := s.normalize(n); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if n.Name != "primary" {
|
||||
t.Fatalf("Name not trimmed: %q", n.Name)
|
||||
}
|
||||
if n.ApiToken != "abc" {
|
||||
t.Fatalf("ApiToken not trimmed: %q", n.ApiToken)
|
||||
}
|
||||
if n.Scheme != "https" {
|
||||
t.Fatalf("empty Scheme should default to https, got %q", n.Scheme)
|
||||
}
|
||||
if n.BasePath != "/panel/" {
|
||||
t.Fatalf("BasePath = %q, want /panel/", n.BasePath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeService_Normalize_KeepsValidScheme(t *testing.T) {
|
||||
s := &NodeService{}
|
||||
n := &model.Node{Name: "n", Address: "example.com", Port: 80, Scheme: "http"}
|
||||
if err := s.normalize(n); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if n.Scheme != "http" {
|
||||
t.Fatalf("Scheme = %q, want http", n.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeService_Normalize_RejectsEmptyName(t *testing.T) {
|
||||
s := &NodeService{}
|
||||
n := &model.Node{Name: " ", Address: "example.com", Port: 443}
|
||||
if err := s.normalize(n); err == nil {
|
||||
t.Fatal("expected error for empty name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeService_Normalize_RejectsBadHost(t *testing.T) {
|
||||
s := &NodeService{}
|
||||
n := &model.Node{Name: "n", Address: "bad host name with spaces", Port: 443}
|
||||
if err := s.normalize(n); err == nil {
|
||||
t.Fatal("expected error for invalid host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeService_Normalize_RejectsOutOfRangePort(t *testing.T) {
|
||||
s := &NodeService{}
|
||||
for _, port := range []int{0, -1, 65536, 100000} {
|
||||
n := &model.Node{Name: "n", Address: "example.com", Port: port}
|
||||
if err := s.normalize(n); err == nil {
|
||||
t.Fatalf("expected error for port %d", port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeService_Normalize_OverridesUnknownScheme(t *testing.T) {
|
||||
s := &NodeService{}
|
||||
n := &model.Node{Name: "n", Address: "example.com", Port: 443, Scheme: "ftp"}
|
||||
if err := s.normalize(n); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if n.Scheme != "https" {
|
||||
t.Fatalf("Scheme = %q, want https", n.Scheme)
|
||||
}
|
||||
}
|
||||
257
web/websocket/hub_test.go
Normal file
257
web/websocket/hub_test.go
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
xuilogger "github.com/mhsanaei/3x-ui/v3/logger"
|
||||
"github.com/op/go-logging"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = os.Setenv("XUI_LOG_FOLDER", os.TempDir())
|
||||
xuilogger.InitLogger(logging.ERROR)
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestNewClient_HasBufferedSendChannel(t *testing.T) {
|
||||
c := NewClient("client-1")
|
||||
if c.ID != "client-1" {
|
||||
t.Fatalf("ID = %q, want client-1", c.ID)
|
||||
}
|
||||
if cap(c.Send) != clientSendQueue {
|
||||
t.Fatalf("Send cap = %d, want %d", cap(c.Send), clientSendQueue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHub_NilReceiver_DoesNotPanic(t *testing.T) {
|
||||
var h *Hub
|
||||
if h.GetClientCount() != 0 {
|
||||
t.Fatal("nil hub GetClientCount should return 0")
|
||||
}
|
||||
h.Broadcast(MessageTypeStatus, "anything")
|
||||
h.Register(NewClient("x"))
|
||||
h.Unregister(NewClient("x"))
|
||||
h.Stop()
|
||||
}
|
||||
|
||||
func TestHub_BroadcastDropsWhenNoClients(t *testing.T) {
|
||||
h := NewHub()
|
||||
defer h.Stop()
|
||||
go h.Run()
|
||||
|
||||
h.Broadcast(MessageTypeStatus, "payload")
|
||||
|
||||
select {
|
||||
case <-h.broadcast:
|
||||
t.Fatal("Broadcast should drop when client count is zero")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestHub_BroadcastDropsNilPayload(t *testing.T) {
|
||||
h := NewHub()
|
||||
defer h.Stop()
|
||||
go h.Run()
|
||||
|
||||
c := NewClient("c1")
|
||||
h.Register(c)
|
||||
waitClientCount(t, h, 1)
|
||||
|
||||
h.Broadcast(MessageTypeStatus, nil)
|
||||
|
||||
select {
|
||||
case <-c.Send:
|
||||
t.Fatal("nil payload should be dropped, not delivered")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestHub_BroadcastDeliversToClient(t *testing.T) {
|
||||
h := NewHub()
|
||||
defer h.Stop()
|
||||
go h.Run()
|
||||
|
||||
c := NewClient("c1")
|
||||
h.Register(c)
|
||||
waitClientCount(t, h, 1)
|
||||
|
||||
h.Broadcast(MessageTypeStatus, map[string]string{"k": "v"})
|
||||
|
||||
select {
|
||||
case raw := <-c.Send:
|
||||
var m Message
|
||||
if err := json.Unmarshal(raw, &m); err != nil {
|
||||
t.Fatalf("payload is not valid JSON: %v\n%s", err, raw)
|
||||
}
|
||||
if m.Type != MessageTypeStatus {
|
||||
t.Fatalf("Type = %q, want %q", m.Type, MessageTypeStatus)
|
||||
}
|
||||
if m.Time == 0 {
|
||||
t.Fatal("Time should be set to a non-zero unix-millis value")
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("timed out waiting for broadcast to reach client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHub_UnregisterClosesSendAndDecrementsCount(t *testing.T) {
|
||||
h := NewHub()
|
||||
defer h.Stop()
|
||||
go h.Run()
|
||||
|
||||
c := NewClient("c1")
|
||||
h.Register(c)
|
||||
waitClientCount(t, h, 1)
|
||||
|
||||
h.Unregister(c)
|
||||
waitClientCount(t, h, 0)
|
||||
|
||||
select {
|
||||
case _, ok := <-c.Send:
|
||||
if ok {
|
||||
t.Fatal("expected Send channel to be closed after Unregister")
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("Send channel was not closed after Unregister")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHub_StopClosesAllClients(t *testing.T) {
|
||||
h := NewHub()
|
||||
go h.Run()
|
||||
|
||||
c1 := NewClient("c1")
|
||||
c2 := NewClient("c2")
|
||||
h.Register(c1)
|
||||
h.Register(c2)
|
||||
waitClientCount(t, h, 2)
|
||||
|
||||
h.Stop()
|
||||
|
||||
for _, c := range []*Client{c1, c2} {
|
||||
select {
|
||||
case _, ok := <-c.Send:
|
||||
if ok {
|
||||
t.Fatalf("client %s Send should be closed after Stop", c.ID)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatalf("client %s Send not closed after Stop", c.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHub_ShouldThrottle(t *testing.T) {
|
||||
h := NewHub()
|
||||
defer h.Stop()
|
||||
|
||||
if h.shouldThrottle(MessageTypeStatus) {
|
||||
t.Fatal("non-gated message type should never throttle")
|
||||
}
|
||||
if h.shouldThrottle(MessageTypeStatus) {
|
||||
t.Fatal("non-gated message type should never throttle on second call")
|
||||
}
|
||||
|
||||
if h.shouldThrottle(MessageTypeTraffic) {
|
||||
t.Fatal("first call for gated type should not throttle")
|
||||
}
|
||||
if !h.shouldThrottle(MessageTypeTraffic) {
|
||||
t.Fatal("immediate second call for gated type should throttle")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHub_ShouldThrottle_DistinctTypesIndependent(t *testing.T) {
|
||||
h := NewHub()
|
||||
defer h.Stop()
|
||||
|
||||
if h.shouldThrottle(MessageTypeTraffic) {
|
||||
t.Fatal("first Traffic call should not throttle")
|
||||
}
|
||||
if h.shouldThrottle(MessageTypeInbounds) {
|
||||
t.Fatal("first Inbounds call should not throttle even after Traffic")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrySend_SucceedsWithRoom(t *testing.T) {
|
||||
c := &Client{ID: "c", Send: make(chan []byte, 1)}
|
||||
if !trySend(c, []byte("hi")) {
|
||||
t.Fatal("trySend should succeed when buffer has room")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrySend_FailsWhenFull(t *testing.T) {
|
||||
c := &Client{ID: "c", Send: make(chan []byte, 1)}
|
||||
c.Send <- []byte("first")
|
||||
if trySend(c, []byte("second")) {
|
||||
t.Fatal("trySend should fail when buffer is full")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrySend_FailsOnClosedChannel(t *testing.T) {
|
||||
c := &Client{ID: "c", Send: make(chan []byte, 1)}
|
||||
close(c.Send)
|
||||
if trySend(c, []byte("after-close")) {
|
||||
t.Fatal("trySend should fail (not panic) when channel is closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHub_FanoutEvictsSlowClient(t *testing.T) {
|
||||
h := NewHub()
|
||||
defer h.Stop()
|
||||
go h.Run()
|
||||
|
||||
slow := &Client{ID: "slow", Send: make(chan []byte, 1)}
|
||||
slow.Send <- []byte("buffer-already-full")
|
||||
h.Register(slow)
|
||||
waitClientCount(t, h, 1)
|
||||
|
||||
h.Broadcast(MessageTypeStatus, "payload")
|
||||
waitClientCount(t, h, 0)
|
||||
|
||||
select {
|
||||
case _, ok := <-slow.Send:
|
||||
if ok {
|
||||
_, ok = <-slow.Send
|
||||
if ok {
|
||||
t.Fatal("slow client Send should eventually be closed by fanout eviction")
|
||||
}
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("slow client Send channel was not closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHub_ConcurrentRegisterUnregister(t *testing.T) {
|
||||
h := NewHub()
|
||||
defer h.Stop()
|
||||
go h.Run()
|
||||
|
||||
const n = 50
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < n; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
c := NewClient("c")
|
||||
h.Register(c)
|
||||
h.Unregister(c)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
waitClientCount(t, h, 0)
|
||||
}
|
||||
|
||||
func waitClientCount(t *testing.T, h *Hub, want int) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if h.GetClientCount() == want {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("client count never reached %d (last seen %d)", want, h.GetClientCount())
|
||||
}
|
||||
82
xray/api_test.go
Normal file
82
xray/api_test.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package xray
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetRequiredUserString_Present(t *testing.T) {
|
||||
user := map[string]any{"email": "alice@example.com"}
|
||||
got, err := getRequiredUserString(user, "email")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "alice@example.com" {
|
||||
t.Fatalf("got %q, want %q", got, "alice@example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequiredUserString_Missing(t *testing.T) {
|
||||
user := map[string]any{}
|
||||
if _, err := getRequiredUserString(user, "email"); err == nil {
|
||||
t.Fatal("expected error for missing key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequiredUserString_NilValue(t *testing.T) {
|
||||
user := map[string]any{"email": nil}
|
||||
if _, err := getRequiredUserString(user, "email"); err == nil {
|
||||
t.Fatal("expected error for nil value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequiredUserString_WrongType(t *testing.T) {
|
||||
user := map[string]any{"email": 42}
|
||||
_, err := getRequiredUserString(user, "email")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-string value")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid type") {
|
||||
t.Fatalf("expected %q in error, got: %v", "invalid type", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOptionalUserString_Present(t *testing.T) {
|
||||
user := map[string]any{"flow": "xtls-rprx-vision"}
|
||||
got, err := getOptionalUserString(user, "flow")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "xtls-rprx-vision" {
|
||||
t.Fatalf("got %q, want %q", got, "xtls-rprx-vision")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOptionalUserString_MissingReturnsEmptyNoError(t *testing.T) {
|
||||
user := map[string]any{}
|
||||
got, err := getOptionalUserString(user, "flow")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for missing optional field: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Fatalf("got %q, want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOptionalUserString_NilReturnsEmptyNoError(t *testing.T) {
|
||||
user := map[string]any{"flow": nil}
|
||||
got, err := getOptionalUserString(user, "flow")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for nil optional field: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Fatalf("got %q, want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOptionalUserString_WrongTypeErrors(t *testing.T) {
|
||||
user := map[string]any{"flow": []string{"a", "b"}}
|
||||
if _, err := getOptionalUserString(user, "flow"); err == nil {
|
||||
t.Fatal("expected error for non-string optional value")
|
||||
}
|
||||
}
|
||||
91
xray/config_test.go
Normal file
91
xray/config_test.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package xray
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v3/util/json_util"
|
||||
)
|
||||
|
||||
func makeConfig() *Config {
|
||||
return &Config{
|
||||
LogConfig: json_util.RawMessage(`{"loglevel":"warning"}`),
|
||||
RouterConfig: json_util.RawMessage(`{}`),
|
||||
OutboundConfigs: json_util.RawMessage(`[]`),
|
||||
Policy: json_util.RawMessage(`{}`),
|
||||
API: json_util.RawMessage(`{}`),
|
||||
Stats: json_util.RawMessage(`{}`),
|
||||
Metrics: json_util.RawMessage(`{}`),
|
||||
InboundConfigs: []InboundConfig{
|
||||
{
|
||||
Port: 1080,
|
||||
Protocol: "vless",
|
||||
Tag: "inbound-1080",
|
||||
Listen: json_util.RawMessage(`"0.0.0.0"`),
|
||||
Settings: json_util.RawMessage(`{"clients":[]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigEquals_IdenticalConfigs(t *testing.T) {
|
||||
a := makeConfig()
|
||||
b := makeConfig()
|
||||
if !a.Equals(b) {
|
||||
t.Fatal("two identical configs should be Equals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigEquals_DifferentInboundCount(t *testing.T) {
|
||||
a := makeConfig()
|
||||
b := makeConfig()
|
||||
b.InboundConfigs = append(b.InboundConfigs, InboundConfig{Port: 2080, Protocol: "vmess", Tag: "inbound-2080"})
|
||||
if a.Equals(b) {
|
||||
t.Fatal("configs with different inbound counts should not be Equals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigEquals_DifferentInboundContent(t *testing.T) {
|
||||
a := makeConfig()
|
||||
b := makeConfig()
|
||||
b.InboundConfigs[0].Port = 9999
|
||||
if a.Equals(b) {
|
||||
t.Fatal("config with changed inbound port should not be Equals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigEquals_DifferentLogConfig(t *testing.T) {
|
||||
a := makeConfig()
|
||||
b := makeConfig()
|
||||
b.LogConfig = json_util.RawMessage(`{"loglevel":"debug"}`)
|
||||
if a.Equals(b) {
|
||||
t.Fatal("config with changed log section should not be Equals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigEquals_RawSectionsCompared(t *testing.T) {
|
||||
fields := []struct {
|
||||
name string
|
||||
mutator func(c *Config)
|
||||
}{
|
||||
{"RouterConfig", func(c *Config) { c.RouterConfig = json_util.RawMessage(`{"changed":true}`) }},
|
||||
{"DNSConfig", func(c *Config) { c.DNSConfig = json_util.RawMessage(`{"servers":["1.1.1.1"]}`) }},
|
||||
{"OutboundConfigs", func(c *Config) { c.OutboundConfigs = json_util.RawMessage(`[{"tag":"x"}]`) }},
|
||||
{"Transport", func(c *Config) { c.Transport = json_util.RawMessage(`{"x":1}`) }},
|
||||
{"Policy", func(c *Config) { c.Policy = json_util.RawMessage(`{"levels":{}}`) }},
|
||||
{"API", func(c *Config) { c.API = json_util.RawMessage(`{"tag":"api"}`) }},
|
||||
{"Stats", func(c *Config) { c.Stats = json_util.RawMessage(`{"on":true}`) }},
|
||||
{"Reverse", func(c *Config) { c.Reverse = json_util.RawMessage(`{"bridges":[]}`) }},
|
||||
{"FakeDNS", func(c *Config) { c.FakeDNS = json_util.RawMessage(`[]`) }},
|
||||
{"Metrics", func(c *Config) { c.Metrics = json_util.RawMessage(`{"tag":"m"}`) }},
|
||||
}
|
||||
for _, f := range fields {
|
||||
t.Run(f.name, func(t *testing.T) {
|
||||
a := makeConfig()
|
||||
b := makeConfig()
|
||||
f.mutator(b)
|
||||
if a.Equals(b) {
|
||||
t.Fatalf("mutating %s should break Equals", f.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
52
xray/inbound_test.go
Normal file
52
xray/inbound_test.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package xray
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v3/util/json_util"
|
||||
)
|
||||
|
||||
func makeInbound() InboundConfig {
|
||||
return InboundConfig{
|
||||
Listen: json_util.RawMessage(`"0.0.0.0"`),
|
||||
Port: 1234,
|
||||
Protocol: "vless",
|
||||
Settings: json_util.RawMessage(`{"clients":[{"id":"abc"}]}`),
|
||||
StreamSettings: json_util.RawMessage(`{"network":"tcp"}`),
|
||||
Tag: "inbound-1234",
|
||||
Sniffing: json_util.RawMessage(`{"enabled":false}`),
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundConfigEquals_Identical(t *testing.T) {
|
||||
a := makeInbound()
|
||||
b := makeInbound()
|
||||
if !a.Equals(&b) {
|
||||
t.Fatal("two identical inbounds should be Equals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundConfigEquals_MutationsBreakEquality(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
mutator func(c *InboundConfig)
|
||||
}{
|
||||
{"Listen", func(c *InboundConfig) { c.Listen = json_util.RawMessage(`"127.0.0.1"`) }},
|
||||
{"Port", func(c *InboundConfig) { c.Port = 9999 }},
|
||||
{"Protocol", func(c *InboundConfig) { c.Protocol = "vmess" }},
|
||||
{"Settings", func(c *InboundConfig) { c.Settings = json_util.RawMessage(`{"clients":[]}`) }},
|
||||
{"StreamSettings", func(c *InboundConfig) { c.StreamSettings = json_util.RawMessage(`{"network":"ws"}`) }},
|
||||
{"Tag", func(c *InboundConfig) { c.Tag = "inbound-other" }},
|
||||
{"Sniffing", func(c *InboundConfig) { c.Sniffing = json_util.RawMessage(`{"enabled":true}`) }},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a := makeInbound()
|
||||
b := makeInbound()
|
||||
tc.mutator(&b)
|
||||
if a.Equals(&b) {
|
||||
t.Fatalf("mutating %s should break Equals", tc.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue