diff --git a/web/job/check_device_limit_job.go b/web/job/check_device_limit_job.go index b3ae68a2..c7299321 100644 --- a/web/job/check_device_limit_job.go +++ b/web/job/check_device_limit_job.go @@ -37,6 +37,17 @@ type CheckDeviceLimitJob struct { violationStartTime map[string]time.Time violationMu sync.Mutex running atomic.Bool + + isXrayRunning func() bool + getAPIPort func() int + loadAllInbounds func() ([]*model.Inbound, error) + getClientTraffic func(email string) (*xray.ClientTraffic, error) + getClientByEmail func(email string) (*xray.ClientTraffic, *model.Client, error) + apiInit func(apiPort int) error + apiClose func() + removeUser func(inboundTag, email string) error + addUser func(protocol, inboundTag string, user map[string]any) error + sleep func(time.Duration) } type deviceInboundInfo struct { @@ -47,10 +58,37 @@ type deviceInboundInfo struct { } func NewCheckDeviceLimitJob(xrayService *service.XrayService) *CheckDeviceLimitJob { - return &CheckDeviceLimitJob{ + j := &CheckDeviceLimitJob{ xrayService: xrayService, violationStartTime: make(map[string]time.Time), } + j.isXrayRunning = func() bool { + return j.xrayService != nil && j.xrayService.IsXrayRunning() + } + j.getAPIPort = func() int { + if j.xrayService == nil { + return 0 + } + return j.xrayService.GetAPIPort() + } + j.loadAllInbounds = func() ([]*model.Inbound, error) { + db := database.GetDB() + var inbounds []*model.Inbound + err := db.Find(&inbounds).Error + return inbounds, err + } + j.getClientTraffic = func(email string) (*xray.ClientTraffic, error) { + return j.inboundService.GetClientTrafficByEmail(email) + } + j.getClientByEmail = func(email string) (*xray.ClientTraffic, *model.Client, error) { + return j.inboundService.GetClientByEmail(email) + } + j.apiInit = j.xrayAPI.Init + j.apiClose = j.xrayAPI.Close + j.removeUser = j.xrayAPI.RemoveUser + j.addUser = j.xrayAPI.AddUser + j.sleep = time.Sleep + return j } func (j *CheckDeviceLimitJob) Run() { @@ -60,7 +98,7 @@ func (j *CheckDeviceLimitJob) Run() { } defer j.running.Store(false) - if j.xrayService == nil || !j.xrayService.IsXrayRunning() { + if j.isXrayRunning == nil || !j.isXrayRunning() { return } j.cleanupExpiredIPs() @@ -142,20 +180,22 @@ func (j *CheckDeviceLimitJob) parseAccessLog() { } func (j *CheckDeviceLimitJob) checkAllClientsLimit() { - db := database.GetDB() - var allInbounds []*model.Inbound - if err := db.Find(&allInbounds).Error; err != nil || len(allInbounds) == 0 { + if j.loadAllInbounds == nil { + return + } + allInbounds, err := j.loadAllInbounds() + if err != nil || len(allInbounds) == 0 { return } - apiPort := j.xrayService.GetAPIPort() + apiPort := j.getAPIPort() if apiPort == 0 { return } - if err := j.xrayAPI.Init(apiPort); err != nil { + if err := j.apiInit(apiPort); err != nil { return } - defer j.xrayAPI.Close() + defer j.apiClose() inboundInfoMap := make(map[int]deviceInboundInfo, len(allInbounds)) for _, inbound := range allInbounds { @@ -184,7 +224,7 @@ func (j *CheckDeviceLimitJob) checkAllClientsLimit() { clientStatusMu.RUnlock() for email, activeIPCount := range activeCounts { - traffic, err := j.inboundService.GetClientTrafficByEmail(email) + traffic, err := j.getClientTraffic(email) if err != nil || traffic == nil { continue } @@ -244,7 +284,7 @@ func (j *CheckDeviceLimitJob) checkAllClientsLimit() { if _, online := activeCounts[email]; online { continue } - traffic, err := j.inboundService.GetClientTrafficByEmail(email) + traffic, err := j.getClientTraffic(email) if err != nil || traffic == nil { continue } @@ -267,14 +307,14 @@ func (j *CheckDeviceLimitJob) checkAllClientsLimit() { } func (j *CheckDeviceLimitJob) banUser(email string, activeIPCount int, info deviceInboundInfo) { - _, client, err := j.inboundService.GetClientByEmail(email) + _, client, err := j.getClientByEmail(email) if err != nil || client == nil { return } logger.Infof("[DeviceLimit] banning email=%s limit=%d current=%d", email, info.Limit, activeIPCount) - _ = j.xrayAPI.RemoveUser(info.Tag, email) - time.Sleep(5 * time.Second) + _ = j.removeUser(info.Tag, email) + j.sleep(5 * time.Second) tempClient := *client if tempClient.ID != "" { @@ -288,7 +328,7 @@ func (j *CheckDeviceLimitJob) banUser(email string, activeIPCount int, info devi clientJSON, _ := json.Marshal(tempClient) _ = json.Unmarshal(clientJSON, &clientMap) - if err = j.xrayAPI.AddUser(string(info.Protocol), info.Tag, clientMap); err != nil { + if err = j.addUser(string(info.Protocol), info.Tag, clientMap); err != nil { logger.Warningf("[DeviceLimit] failed to ban user %s: %v", email, err) return } @@ -296,20 +336,20 @@ func (j *CheckDeviceLimitJob) banUser(email string, activeIPCount int, info devi } func (j *CheckDeviceLimitJob) unbanUser(email string, activeIPCount int, info deviceInboundInfo) { - _, client, err := j.inboundService.GetClientByEmail(email) + _, client, err := j.getClientByEmail(email) if err != nil || client == nil { return } logger.Infof("[DeviceLimit] unbanning email=%s limit=%d current=%d", email, info.Limit, activeIPCount) - _ = j.xrayAPI.RemoveUser(info.Tag, email) - time.Sleep(5 * time.Second) + _ = j.removeUser(info.Tag, email) + j.sleep(5 * time.Second) clientMap := map[string]any{} clientJSON, _ := json.Marshal(client) _ = json.Unmarshal(clientJSON, &clientMap) - if err = j.xrayAPI.AddUser(string(info.Protocol), info.Tag, clientMap); err != nil { + if err = j.addUser(string(info.Protocol), info.Tag, clientMap); err != nil { logger.Warningf("[DeviceLimit] failed to restore user %s: %v", email, err) return } diff --git a/web/job/check_device_limit_job_test.go b/web/job/check_device_limit_job_test.go new file mode 100644 index 00000000..61a08efb --- /dev/null +++ b/web/job/check_device_limit_job_test.go @@ -0,0 +1,129 @@ +package job + +import ( + "testing" + "time" + + "github.com/mhsanaei/3x-ui/v2/database/model" + "github.com/mhsanaei/3x-ui/v2/xray" +) + +func resetDeviceLimitJobGlobals() { + activeClientsLock.Lock() + activeClientIPs = make(map[string]map[string]time.Time) + activeClientsLock.Unlock() + + clientStatusMu.Lock() + clientStatus = make(map[string]bool) + clientStatusMu.Unlock() +} + +func TestCheckDeviceLimitJob_Run_SkipWhenAlreadyRunning(t *testing.T) { + resetDeviceLimitJobGlobals() + + j := NewCheckDeviceLimitJob(nil) + j.running.Store(true) + j.isXrayRunning = func() bool { + t.Fatal("Run should skip execution when already running") + return true + } + + j.Run() +} + +func TestCheckDeviceLimitJob_UnbanWhenEnforcementDisabled(t *testing.T) { + resetDeviceLimitJobGlobals() + + activeClientsLock.Lock() + activeClientIPs["alice@example.com"] = map[string]time.Time{ + "1.2.3.4": time.Now(), + } + activeClientsLock.Unlock() + + clientStatusMu.Lock() + clientStatus["alice@example.com"] = true + clientStatusMu.Unlock() + + j := NewCheckDeviceLimitJob(nil) + j.getAPIPort = func() int { return 10085 } + j.apiInit = func(int) error { return nil } + j.apiClose = func() {} + j.sleep = func(time.Duration) {} + j.loadAllInbounds = func() ([]*model.Inbound, error) { + return []*model.Inbound{ + { + Id: 1, + Enable: false, // Enforcement disabled + DeviceLimit: 0, + Tag: "inbound-10001", + Protocol: model.VLESS, + }, + }, nil + } + j.getClientTraffic = func(email string) (*xray.ClientTraffic, error) { + return &xray.ClientTraffic{InboundId: 1, Email: email}, nil + } + j.getClientByEmail = func(email string) (*xray.ClientTraffic, *model.Client, error) { + return &xray.ClientTraffic{InboundId: 1, Email: email}, &model.Client{ID: "orig-id", Email: email}, nil + } + + removeCalls := 0 + addCalls := 0 + j.removeUser = func(inboundTag, email string) error { + removeCalls++ + return nil + } + j.addUser = func(protocol, inboundTag string, user map[string]any) error { + addCalls++ + return nil + } + + j.checkAllClientsLimit() + + if removeCalls != 1 || addCalls != 1 { + t.Fatalf("expected one restore cycle, got remove=%d add=%d", removeCalls, addCalls) + } + if j.isClientBanned("alice@example.com") { + t.Fatal("expected client ban flag to be cleared when enforcement is disabled") + } +} + +func TestCheckDeviceLimitJob_ClearStaleBanWhenInboundMissing(t *testing.T) { + resetDeviceLimitJobGlobals() + + clientStatusMu.Lock() + clientStatus["ghost@example.com"] = true + clientStatusMu.Unlock() + + j := NewCheckDeviceLimitJob(nil) + j.getAPIPort = func() int { return 10085 } + j.apiInit = func(int) error { return nil } + j.apiClose = func() {} + j.sleep = func(time.Duration) {} + j.loadAllInbounds = func() ([]*model.Inbound, error) { + return []*model.Inbound{ + {Id: 2, Enable: true, DeviceLimit: 1, Tag: "inbound-10002", Protocol: model.VLESS}, + }, nil + } + j.getClientTraffic = func(email string) (*xray.ClientTraffic, error) { + return &xray.ClientTraffic{InboundId: 999, Email: email}, nil + } + j.getClientByEmail = func(email string) (*xray.ClientTraffic, *model.Client, error) { + t.Fatal("GetClientByEmail should not be called when inbound is missing") + return nil, nil, nil + } + j.removeUser = func(inboundTag, email string) error { + t.Fatal("RemoveUser should not be called when inbound is missing") + return nil + } + j.addUser = func(protocol, inboundTag string, user map[string]any) error { + t.Fatal("AddUser should not be called when inbound is missing") + return nil + } + + j.checkAllClientsLimit() + + if j.isClientBanned("ghost@example.com") { + t.Fatal("expected stale banned status to be cleared when inbound no longer exists") + } +}