From 1277285d0835fdf2f2184f067140e2390fa083bb Mon Sep 17 00:00:00 2001 From: MHSanaei Date: Sat, 26 Aug 2023 15:19:51 +0330 Subject: [PATCH] [db] Enbancement add traffic fully transactional Co-Authored-By: Alireza Ahmadi --- web/job/check_inbound_job.go | 37 ---------- web/job/xray_traffic_job.go | 8 +-- web/service/inbound.go | 130 ++++++++++++++++------------------- web/service/xray.go | 2 +- web/web.go | 3 - 5 files changed, 64 insertions(+), 116 deletions(-) delete mode 100644 web/job/check_inbound_job.go diff --git a/web/job/check_inbound_job.go b/web/job/check_inbound_job.go deleted file mode 100644 index 2044ce68..00000000 --- a/web/job/check_inbound_job.go +++ /dev/null @@ -1,37 +0,0 @@ -package job - -import ( - "x-ui/logger" - "x-ui/web/service" -) - -type CheckInboundJob struct { - xrayService service.XrayService - inboundService service.InboundService -} - -func NewCheckInboundJob() *CheckInboundJob { - return new(CheckInboundJob) -} - -func (j *CheckInboundJob) Run() { - needRestart, count, err := j.inboundService.DisableInvalidClients() - if err != nil { - logger.Warning("Error in disabling invalid clients:", err) - } else if count > 0 { - logger.Debugf("%v clients disabled", count) - if needRestart { - j.xrayService.SetToNeedRestart() - } - } - - needRestart, count, err = j.inboundService.DisableInvalidInbounds() - if err != nil { - logger.Warning("Error in disabling invalid inbounds:", err) - } else if count > 0 { - logger.Debugf("%v inbounds disabled", count) - if needRestart { - j.xrayService.SetToNeedRestart() - } - } -} diff --git a/web/job/xray_traffic_job.go b/web/job/xray_traffic_job.go index 3acdf44a..158930a4 100644 --- a/web/job/xray_traffic_job.go +++ b/web/job/xray_traffic_job.go @@ -24,14 +24,12 @@ func (j *XrayTrafficJob) Run() { logger.Warning("get xray traffic failed:", err) return } - err = j.inboundService.AddTraffic(traffics) + err, needRestart := j.inboundService.AddTraffic(traffics, clientTraffics) if err != nil { logger.Warning("add traffic failed:", err) } - - err = j.inboundService.AddClientTraffic(clientTraffics) - if err != nil { - logger.Warning("add client traffic failed:", err) + if needRestart { + j.xrayService.SetToNeedRestart() } } diff --git a/web/service/inbound.go b/web/service/inbound.go index da7f3c57..a9ac8d0e 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -194,38 +194,6 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo return inbound, needRestart, err } -func (s *InboundService) AddInbounds(inbounds []*model.Inbound) error { - for _, inbound := range inbounds { - exist, err := s.checkPortExist(inbound.Port, 0) - if err != nil { - return err - } - if exist { - return common.NewError("Port already exists:", inbound.Port) - } - } - - db := database.GetDB() - tx := db.Begin() - var err error - defer func() { - if err == nil { - tx.Commit() - } else { - tx.Rollback() - } - }() - - for _, inbound := range inbounds { - err = tx.Save(inbound).Error - if err != nil { - return err - } - } - - return nil -} - func (s *InboundService) DelInbound(id int) (bool, error) { db := database.GetDB() @@ -687,35 +655,8 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId strin return needRestart, tx.Save(oldInbound).Error } -func (s *InboundService) AddTraffic(traffics []*xray.Traffic) error { - if len(traffics) == 0 { - return nil - } - // Update traffics in a single transaction - err := database.GetDB().Transaction(func(tx *gorm.DB) error { - for _, traffic := range traffics { - if traffic.IsInbound { - update := tx.Model(&model.Inbound{}).Where("tag = ?", traffic.Tag). - Updates(map[string]interface{}{ - "up": gorm.Expr("up + ?", traffic.Up), - "down": gorm.Expr("down + ?", traffic.Down), - }) - if update.Error != nil { - return update.Error - } - } - } - return nil - }) - - return err -} - -func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err error) { - if len(traffics) == 0 { - return nil - } - +func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (error, bool) { + var err error db := database.GetDB() tx := db.Begin() @@ -726,13 +667,64 @@ func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err e tx.Commit() } }() + err = s.addInboundTraffic(tx, inboundTraffics) + if err != nil { + return err, false + } + err = s.addClientTraffic(tx, clientTraffics) + if err != nil { + return err, false + } + + needRestart1, count, err := s.disableInvalidClients(tx) + if err != nil { + logger.Warning("Error in disabling invalid clients:", err) + } else if count > 0 { + logger.Debugf("%v clients disabled", count) + } + + needRestart2, count, err := s.disableInvalidInbounds(tx) + if err != nil { + logger.Warning("Error in disabling invalid inbounds:", err) + } else if count > 0 { + logger.Debugf("%v inbounds disabled", count) + } + return nil, (needRestart1 || needRestart2) +} + +func (s *InboundService) addInboundTraffic(tx *gorm.DB, traffics []*xray.Traffic) error { + if len(traffics) == 0 { + return nil + } + + var err error + + for _, traffic := range traffics { + if traffic.IsInbound { + err = tx.Model(&model.Inbound{}).Where("tag = ?", traffic.Tag). + Updates(map[string]interface{}{ + "up": gorm.Expr("up + ?", traffic.Up), + "down": gorm.Expr("down + ?", traffic.Down), + }).Error + if err != nil { + return err + } + } + } + return nil +} + +func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTraffic) (err error) { + if len(traffics) == 0 { + return nil + } emails := make([]string, 0, len(traffics)) for _, traffic := range traffics { emails = append(emails, traffic.Email) } dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics)) - err = db.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error + err = tx.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error if err != nil { return err } @@ -817,14 +809,13 @@ func (s *InboundService) adjustTraffics(tx *gorm.DB, dbClientTraffics []*xray.Cl return dbClientTraffics, nil } -func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) { - db := database.GetDB() +func (s *InboundService) disableInvalidInbounds(tx *gorm.DB) (bool, int64, error) { now := time.Now().Unix() * 1000 needRestart := false if p != nil { var tags []string - err := db.Table("inbounds"). + err := tx.Table("inbounds"). Select("inbounds.tag"). Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true). Scan(&tags).Error @@ -844,7 +835,7 @@ func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) { s.xrayApi.Close() } - result := db.Model(model.Inbound{}). + result := tx.Model(model.Inbound{}). Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true). Update("enable", false) err := result.Error @@ -852,8 +843,7 @@ func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) { return needRestart, count, err } -func (s *InboundService) DisableInvalidClients() (bool, int64, error) { - db := database.GetDB() +func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) { now := time.Now().Unix() * 1000 needRestart := false @@ -863,7 +853,7 @@ func (s *InboundService) DisableInvalidClients() (bool, int64, error) { Email string } - err := db.Table("inbounds"). + err := tx.Table("inbounds"). Select("inbounds.tag, client_traffics.email"). Joins("JOIN client_traffics ON inbounds.id = client_traffics.inbound_id"). Where("((client_traffics.total > 0 AND client_traffics.up + client_traffics.down >= client_traffics.total) OR (client_traffics.expiry_time > 0 AND client_traffics.expiry_time <= ?)) AND client_traffics.enable = ?", now, true). @@ -883,7 +873,7 @@ func (s *InboundService) DisableInvalidClients() (bool, int64, error) { } s.xrayApi.Close() } - result := db.Model(xray.ClientTraffic{}). + result := tx.Model(xray.ClientTraffic{}). Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true). Update("enable", false) err := result.Error diff --git a/web/service/xray.go b/web/service/xray.go index 4a9bf7b2..2dbc092d 100644 --- a/web/service/xray.go +++ b/web/service/xray.go @@ -69,7 +69,7 @@ func (s *XrayService) GetXrayConfig() (*xray.Config, error) { return nil, err } - s.inboundService.DisableInvalidClients() + s.inboundService.AddTraffic(nil, nil) inbounds, err := s.inboundService.GetAllInbounds() if err != nil { diff --git a/web/web.go b/web/web.go index 78f9e6f4..c3a88b66 100644 --- a/web/web.go +++ b/web/web.go @@ -247,9 +247,6 @@ func (s *Server) startTask() { s.cron.AddJob("@every 10s", job.NewXrayTrafficJob()) }() - // Check the inbound traffic every 30 seconds that the traffic exceeds and expires - s.cron.AddJob("@every 30s", job.NewCheckInboundJob()) - // check client ips from log file every 10 sec s.cron.AddJob("@every 10s", job.NewCheckClientIpJob())