diff --git a/web/job/check_client_ip_job.go b/web/job/check_client_ip_job.go index 6439252f..29aa9775 100644 --- a/web/job/check_client_ip_job.go +++ b/web/job/check_client_ip_job.go @@ -246,6 +246,37 @@ func mergeClientIps(old, new []IPWithTimestamp, staleCutoff int64) map[string]in return ipMap } +// partitionLiveIps splits the merged ip map into live (seen in the +// current scan) and historical (only in the db blob, still inside the +// staleness window). +// +// only live ips count toward the per-client limit. historical ones stay +// in the db so the panel keeps showing them, but they must not take a +// protected slot. the 30min cutoff alone isn't tight enough: an ip that +// stopped connecting a few minutes ago still looks fresh to +// mergeClientIps, and since the over-limit picker sorts ascending and +// keeps the oldest, those idle entries used to win the slot while the +// ip actually connecting got classified as excess and sent to fail2ban +// every tick. see #4077 / #4091. +// +// live is sorted ascending so the "protect original, ban newcomer" +// rule still holds when several ips are really connecting at once. +func partitionLiveIps(ipMap map[string]int64, observedThisScan map[string]bool) (live, historical []IPWithTimestamp) { + live = make([]IPWithTimestamp, 0, len(observedThisScan)) + historical = make([]IPWithTimestamp, 0, len(ipMap)) + for ip, ts := range ipMap { + entry := IPWithTimestamp{IP: ip, Timestamp: ts} + if observedThisScan[ip] { + live = append(live, entry) + } else { + historical = append(historical, entry) + } + } + sort.Slice(live, func(i, j int) bool { return live[i].Timestamp < live[j].Timestamp }) + sort.Slice(historical, func(i, j int) bool { return historical[i].Timestamp < historical[j].Timestamp }) + return live, historical +} + func (j *CheckClientIpJob) checkFail2BanInstalled() bool { cmd := "fail2ban-client" args := []string{"-h"} @@ -358,15 +389,13 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun // re-observed in a while. See mergeClientIps / #4077 for why. ipMap := mergeClientIps(oldIpsWithTime, newIpsWithTime, time.Now().Unix()-ipStaleAfterSeconds) - // Convert back to slice and sort by timestamp (oldest first) - // This ensures we always protect the original/current connections and ban new excess ones. - allIps := make([]IPWithTimestamp, 0, len(ipMap)) - for ip, timestamp := range ipMap { - allIps = append(allIps, IPWithTimestamp{IP: ip, Timestamp: timestamp}) + // only ips seen in this scan count toward the limit. see + // partitionLiveIps. + observedThisScan := make(map[string]bool, len(newIpsWithTime)) + for _, ipTime := range newIpsWithTime { + observedThisScan[ipTime.IP] = true } - sort.Slice(allIps, func(i, j int) bool { - return allIps[i].Timestamp < allIps[j].Timestamp // Ascending order (oldest first) - }) + liveIps, historicalIps := partitionLiveIps(ipMap, observedThisScan) shouldCleanLog := false j.disAllowedIps = []string{} @@ -381,35 +410,39 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun log.SetOutput(logIpFile) log.SetFlags(log.LstdFlags) - // Check if we exceed the limit - if len(allIps) > limitIp { + // historical db-only ips are excluded from this count on purpose. + var keptLive []IPWithTimestamp + if len(liveIps) > limitIp { shouldCleanLog = true - // Keep the oldest IPs (currently active connections) and ban the new excess ones. - keptIps := allIps[:limitIp] - bannedIps := allIps[limitIp:] + // protect the oldest live ip, ban newcomers. + keptLive = liveIps[:limitIp] + bannedLive := liveIps[limitIp:] - // Log banned IPs in the format fail2ban filters expect: [LIMIT_IP] Email = X || Disconnecting OLD IP = Y || Timestamp = Z - for _, ipTime := range bannedIps { + // log format is load-bearing: x-ui.sh create_iplimit_jails builds + // filter.d/3x-ipl.conf with + // failregex = \[LIMIT_IP\]\s*Email\s*=\s*.+\s*\|\|\s*Disconnecting OLD IP\s*=\s*\s*\|\|\s*Timestamp\s*=\s*\d+ + // don't change the wording. + for _, ipTime := range bannedLive { j.disAllowedIps = append(j.disAllowedIps, ipTime.IP) log.Printf("[LIMIT_IP] Email = %s || Disconnecting OLD IP = %s || Timestamp = %d", clientEmail, ipTime.IP, ipTime.Timestamp) } - // Actually disconnect banned IPs by temporarily removing and re-adding user - // This forces Xray to drop existing connections from banned IPs - if len(bannedIps) > 0 { - j.disconnectClientTemporarily(inbound, clientEmail, clients) - } - - // Update database with only the currently active (kept) IPs - jsonIps, _ := json.Marshal(keptIps) - inboundClientIps.Ips = string(jsonIps) + // force xray to drop existing connections from banned ips + j.disconnectClientTemporarily(inbound, clientEmail, clients) } else { - // Under limit, save all IPs - jsonIps, _ := json.Marshal(allIps) - inboundClientIps.Ips = string(jsonIps) + keptLive = liveIps } + // keep kept-live + historical in the blob so the panel keeps showing + // recently seen ips. banned live ips are already in the fail2ban log + // and will reappear in the next scan if they reconnect. + dbIps := make([]IPWithTimestamp, 0, len(keptLive)+len(historicalIps)) + dbIps = append(dbIps, keptLive...) + dbIps = append(dbIps, historicalIps...) + jsonIps, _ := json.Marshal(dbIps) + inboundClientIps.Ips = string(jsonIps) + db := database.GetDB() err = db.Save(inboundClientIps).Error if err != nil { @@ -418,7 +451,7 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun } if len(j.disAllowedIps) > 0 { - logger.Infof("[LIMIT_IP] Client %s: Kept %d current IPs, queued %d new IPs for fail2ban", clientEmail, limitIp, len(j.disAllowedIps)) + logger.Infof("[LIMIT_IP] Client %s: Kept %d live IPs, queued %d new IPs for fail2ban", clientEmail, len(keptLive), len(j.disAllowedIps)) } return shouldCleanLog diff --git a/web/job/check_client_ip_job_integration_test.go b/web/job/check_client_ip_job_integration_test.go new file mode 100644 index 00000000..5a7aa705 --- /dev/null +++ b/web/job/check_client_ip_job_integration_test.go @@ -0,0 +1,250 @@ +package job + +import ( + "encoding/json" + "log" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/mhsanaei/3x-ui/v2/database" + "github.com/mhsanaei/3x-ui/v2/database/model" + xuilogger "github.com/mhsanaei/3x-ui/v2/logger" + "github.com/op/go-logging" +) + +// 3x-ui logger must be initialised once before any code path that can +// log a warning. otherwise log.Warningf panics on a nil logger. +var loggerInitOnce sync.Once + +// setupIntegrationDB wires a temp sqlite db and log folder so +// updateInboundClientIps can run end to end. closes the db before +// TempDir cleanup so windows doesn't complain about the file being in +// use. +func setupIntegrationDB(t *testing.T) { + t.Helper() + + loggerInitOnce.Do(func() { + xuilogger.InitLogger(logging.ERROR) + }) + + dbDir := t.TempDir() + logDir := t.TempDir() + + t.Setenv("XUI_DB_FOLDER", dbDir) + t.Setenv("XUI_LOG_FOLDER", logDir) + + // updateInboundClientIps calls log.SetOutput on the package global, + // which would leak to other tests in the same binary. + origLogWriter := log.Writer() + origLogFlags := log.Flags() + t.Cleanup(func() { + log.SetOutput(origLogWriter) + log.SetFlags(origLogFlags) + }) + + if err := database.InitDB(filepath.Join(dbDir, "3x-ui.db")); err != nil { + t.Fatalf("database.InitDB failed: %v", err) + } + // LIFO cleanup order: this runs before t.TempDir's own cleanup. + t.Cleanup(func() { + if err := database.CloseDB(); err != nil { + t.Logf("database.CloseDB warning: %v", err) + } + }) +} + +// seed an inbound whose settings json has a single client with the +// given email and ip limit. +func seedInboundWithClient(t *testing.T, tag, email string, limitIp int) { + t.Helper() + settings := map[string]any{ + "clients": []map[string]any{ + { + "email": email, + "limitIp": limitIp, + "enable": true, + }, + }, + } + settingsJSON, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal settings: %v", err) + } + inbound := &model.Inbound{ + Tag: tag, + Enable: true, + Protocol: model.VLESS, + Port: 4321, + Settings: string(settingsJSON), + } + if err := database.GetDB().Create(inbound).Error; err != nil { + t.Fatalf("seed inbound: %v", err) + } +} + +// seed an InboundClientIps row with the given blob. +func seedClientIps(t *testing.T, email string, ips []IPWithTimestamp) *model.InboundClientIps { + t.Helper() + blob, err := json.Marshal(ips) + if err != nil { + t.Fatalf("marshal ips: %v", err) + } + row := &model.InboundClientIps{ + ClientEmail: email, + Ips: string(blob), + } + if err := database.GetDB().Create(row).Error; err != nil { + t.Fatalf("seed InboundClientIps: %v", err) + } + return row +} + +// read the persisted blob and parse it back. +func readClientIps(t *testing.T, email string) []IPWithTimestamp { + t.Helper() + row := &model.InboundClientIps{} + if err := database.GetDB().Where("client_email = ?", email).First(row).Error; err != nil { + t.Fatalf("read InboundClientIps for %s: %v", email, err) + } + if row.Ips == "" { + return nil + } + var out []IPWithTimestamp + if err := json.Unmarshal([]byte(row.Ips), &out); err != nil { + t.Fatalf("unmarshal Ips blob %q: %v", row.Ips, err) + } + return out +} + +// make a lookup map so asserts don't depend on slice order. +func ipSet(entries []IPWithTimestamp) map[string]int64 { + out := make(map[string]int64, len(entries)) + for _, e := range entries { + out[e.IP] = e.Timestamp + } + return out +} + +// #4091 repro: client has limit=3, db still holds 3 idle ips from a +// few minutes ago, only one live ip is actually connecting. pre-fix: +// live ip got banned every tick and never appeared in the panel. +// post-fix: no ban, live ip persisted, historical ips still visible. +func TestUpdateInboundClientIps_LiveIpNotBannedByStillFreshHistoricals(t *testing.T) { + setupIntegrationDB(t) + + const email = "pr4091-repro" + seedInboundWithClient(t, "inbound-pr4091", email, 3) + + now := time.Now().Unix() + // idle but still within the 30min staleness window. + row := seedClientIps(t, email, []IPWithTimestamp{ + {IP: "10.0.0.1", Timestamp: now - 20*60}, + {IP: "10.0.0.2", Timestamp: now - 15*60}, + {IP: "10.0.0.3", Timestamp: now - 10*60}, + }) + + j := NewCheckClientIpJob() + // the one that's actually connecting (user's 128.71.x.x). + live := []IPWithTimestamp{ + {IP: "128.71.1.1", Timestamp: now}, + } + + shouldCleanLog := j.updateInboundClientIps(row, email, live) + + if shouldCleanLog { + t.Fatalf("shouldCleanLog must be false, nothing should have been banned with 1 live ip under limit 3") + } + if len(j.disAllowedIps) != 0 { + t.Fatalf("disAllowedIps must be empty, got %v", j.disAllowedIps) + } + + persisted := ipSet(readClientIps(t, email)) + for _, want := range []string{"128.71.1.1", "10.0.0.1", "10.0.0.2", "10.0.0.3"} { + if _, ok := persisted[want]; !ok { + t.Errorf("expected %s to be persisted in inbound_client_ips.ips; got %v", want, persisted) + } + } + if got := persisted["128.71.1.1"]; got != now { + t.Errorf("live ip timestamp should match the scan timestamp %d, got %d", now, got) + } + + // 3xipl.log must not contain a ban line. + if info, err := os.Stat(readIpLimitLogPath()); err == nil && info.Size() > 0 { + body, _ := os.ReadFile(readIpLimitLogPath()) + t.Fatalf("3xipl.log should be empty when no ips are banned, got:\n%s", body) + } +} + +// opposite invariant: when several ips are actually live and exceed +// the limit, the newcomer still gets banned. +func TestUpdateInboundClientIps_ExcessLiveIpIsStillBanned(t *testing.T) { + setupIntegrationDB(t) + + const email = "pr4091-abuse" + seedInboundWithClient(t, "inbound-pr4091-abuse", email, 1) + + now := time.Now().Unix() + row := seedClientIps(t, email, []IPWithTimestamp{ + {IP: "10.1.0.1", Timestamp: now - 60}, // original connection + }) + + j := NewCheckClientIpJob() + // both live, limit=1. use distinct timestamps so sort-by-timestamp + // is deterministic: 10.1.0.1 is the original (older), 192.0.2.9 + // joined later and must get banned. + live := []IPWithTimestamp{ + {IP: "10.1.0.1", Timestamp: now - 5}, + {IP: "192.0.2.9", Timestamp: now}, + } + + shouldCleanLog := j.updateInboundClientIps(row, email, live) + + if !shouldCleanLog { + t.Fatalf("shouldCleanLog must be true when the live set exceeds the limit") + } + if len(j.disAllowedIps) != 1 || j.disAllowedIps[0] != "192.0.2.9" { + t.Fatalf("expected 192.0.2.9 to be banned; disAllowedIps = %v", j.disAllowedIps) + } + + persisted := ipSet(readClientIps(t, email)) + if _, ok := persisted["10.1.0.1"]; !ok { + t.Errorf("original IP 10.1.0.1 must still be persisted; got %v", persisted) + } + if _, ok := persisted["192.0.2.9"]; ok { + t.Errorf("banned IP 192.0.2.9 must NOT be persisted; got %v", persisted) + } + + // 3xipl.log must contain the ban line in the exact fail2ban format. + body, err := os.ReadFile(readIpLimitLogPath()) + if err != nil { + t.Fatalf("read 3xipl.log: %v", err) + } + wantSubstr := "[LIMIT_IP] Email = pr4091-abuse || Disconnecting OLD IP = 192.0.2.9" + if !contains(string(body), wantSubstr) { + t.Fatalf("3xipl.log missing expected ban line %q\nfull log:\n%s", wantSubstr, body) + } +} + +// readIpLimitLogPath reads the 3xipl.log path the same way the job +// does via xray.GetIPLimitLogPath but without importing xray here +// just for the path helper (which would pull a lot more deps into the +// test binary). The env-derived log folder is deterministic. +func readIpLimitLogPath() string { + folder := os.Getenv("XUI_LOG_FOLDER") + if folder == "" { + folder = filepath.Join(".", "log") + } + return filepath.Join(folder, "3xipl.log") +} + +func contains(haystack, needle string) bool { + for i := 0; i+len(needle) <= len(haystack); i++ { + if haystack[i:i+len(needle)] == needle { + return true + } + } + return false +} diff --git a/web/job/check_client_ip_job_test.go b/web/job/check_client_ip_job_test.go index 8bd1a73b..fd745f80 100644 --- a/web/job/check_client_ip_job_test.go +++ b/web/job/check_client_ip_job_test.go @@ -75,3 +75,72 @@ func TestMergeClientIps_NoStaleCutoffStillWorks(t *testing.T) { t.Fatalf("zero cutoff should keep everything\ngot: %v\nwant: %v", got, want) } } + +func collectIps(entries []IPWithTimestamp) []string { + out := make([]string, 0, len(entries)) + for _, e := range entries { + out = append(out, e.IP) + } + return out +} + +func TestPartitionLiveIps_SingleLiveNotStarvedByStillFreshHistoricals(t *testing.T) { + // #4091: db holds A, B, C from minutes ago (still in the 30min + // window) but they're not connecting anymore. only D is. old code + // merged all four, sorted ascending, kept [A,B,C] and banned D + // every tick. pin the new rule: only live ips count toward the limit. + ipMap := map[string]int64{ + "A": 1000, + "B": 1100, + "C": 1200, + "D": 2000, + } + observed := map[string]bool{"D": true} + + live, historical := partitionLiveIps(ipMap, observed) + + if got := collectIps(live); !reflect.DeepEqual(got, []string{"D"}) { + t.Fatalf("live set should only contain the ip observed this scan\ngot: %v\nwant: [D]", got) + } + if got := collectIps(historical); !reflect.DeepEqual(got, []string{"A", "B", "C"}) { + t.Fatalf("historical set should contain db-only ips in ascending order\ngot: %v\nwant: [A B C]", got) + } +} + +func TestPartitionLiveIps_ConcurrentLiveIpsStillBanNewcomers(t *testing.T) { + // keep the "protect original, ban newcomer" policy when several ips + // are really live. with limit=1, A must stay and B must be banned. + ipMap := map[string]int64{ + "A": 5000, + "B": 5500, + } + observed := map[string]bool{"A": true, "B": true} + + live, historical := partitionLiveIps(ipMap, observed) + + if got := collectIps(live); !reflect.DeepEqual(got, []string{"A", "B"}) { + t.Fatalf("both live ips should be in the live set, ascending\ngot: %v\nwant: [A B]", got) + } + if len(historical) != 0 { + t.Fatalf("no historical ips expected, got %v", historical) + } +} + +func TestPartitionLiveIps_EmptyScanLeavesDbIntact(t *testing.T) { + // quiet tick: nothing observed => nothing live. everything merged + // is historical. keeps the panel from wiping recent-but-idle ips. + ipMap := map[string]int64{ + "A": 1000, + "B": 1100, + } + observed := map[string]bool{} + + live, historical := partitionLiveIps(ipMap, observed) + + if len(live) != 0 { + t.Fatalf("no live ips expected, got %v", live) + } + if got := collectIps(historical); !reflect.DeepEqual(got, []string{"A", "B"}) { + t.Fatalf("all merged entries should flow to historical\ngot: %v\nwant: [A B]", got) + } +}