[db] Enbancement add traffic fully transactional

Co-Authored-By: Alireza Ahmadi <alireza7@gmail.com>
This commit is contained in:
MHSanaei 2023-08-26 15:19:51 +03:30
parent 75df8a05f1
commit 1277285d08
5 changed files with 64 additions and 116 deletions

View file

@ -1,37 +0,0 @@
package job
import (
"x-ui/logger"
"x-ui/web/service"
)
type CheckInboundJob struct {
xrayService service.XrayService
inboundService service.InboundService
}
func NewCheckInboundJob() *CheckInboundJob {
return new(CheckInboundJob)
}
func (j *CheckInboundJob) Run() {
needRestart, count, err := j.inboundService.DisableInvalidClients()
if err != nil {
logger.Warning("Error in disabling invalid clients:", err)
} else if count > 0 {
logger.Debugf("%v clients disabled", count)
if needRestart {
j.xrayService.SetToNeedRestart()
}
}
needRestart, count, err = j.inboundService.DisableInvalidInbounds()
if err != nil {
logger.Warning("Error in disabling invalid inbounds:", err)
} else if count > 0 {
logger.Debugf("%v inbounds disabled", count)
if needRestart {
j.xrayService.SetToNeedRestart()
}
}
}

View file

@ -24,14 +24,12 @@ func (j *XrayTrafficJob) Run() {
logger.Warning("get xray traffic failed:", err) logger.Warning("get xray traffic failed:", err)
return return
} }
err = j.inboundService.AddTraffic(traffics) err, needRestart := j.inboundService.AddTraffic(traffics, clientTraffics)
if err != nil { if err != nil {
logger.Warning("add traffic failed:", err) logger.Warning("add traffic failed:", err)
} }
if needRestart {
err = j.inboundService.AddClientTraffic(clientTraffics) j.xrayService.SetToNeedRestart()
if err != nil {
logger.Warning("add client traffic failed:", err)
} }
} }

View file

@ -194,38 +194,6 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo
return inbound, needRestart, err return inbound, needRestart, err
} }
func (s *InboundService) AddInbounds(inbounds []*model.Inbound) error {
for _, inbound := range inbounds {
exist, err := s.checkPortExist(inbound.Port, 0)
if err != nil {
return err
}
if exist {
return common.NewError("Port already exists:", inbound.Port)
}
}
db := database.GetDB()
tx := db.Begin()
var err error
defer func() {
if err == nil {
tx.Commit()
} else {
tx.Rollback()
}
}()
for _, inbound := range inbounds {
err = tx.Save(inbound).Error
if err != nil {
return err
}
}
return nil
}
func (s *InboundService) DelInbound(id int) (bool, error) { func (s *InboundService) DelInbound(id int) (bool, error) {
db := database.GetDB() db := database.GetDB()
@ -687,35 +655,8 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId strin
return needRestart, tx.Save(oldInbound).Error return needRestart, tx.Save(oldInbound).Error
} }
func (s *InboundService) AddTraffic(traffics []*xray.Traffic) error { func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (error, bool) {
if len(traffics) == 0 { var err error
return nil
}
// Update traffics in a single transaction
err := database.GetDB().Transaction(func(tx *gorm.DB) error {
for _, traffic := range traffics {
if traffic.IsInbound {
update := tx.Model(&model.Inbound{}).Where("tag = ?", traffic.Tag).
Updates(map[string]interface{}{
"up": gorm.Expr("up + ?", traffic.Up),
"down": gorm.Expr("down + ?", traffic.Down),
})
if update.Error != nil {
return update.Error
}
}
}
return nil
})
return err
}
func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err error) {
if len(traffics) == 0 {
return nil
}
db := database.GetDB() db := database.GetDB()
tx := db.Begin() tx := db.Begin()
@ -726,13 +667,64 @@ func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err e
tx.Commit() tx.Commit()
} }
}() }()
err = s.addInboundTraffic(tx, inboundTraffics)
if err != nil {
return err, false
}
err = s.addClientTraffic(tx, clientTraffics)
if err != nil {
return err, false
}
needRestart1, count, err := s.disableInvalidClients(tx)
if err != nil {
logger.Warning("Error in disabling invalid clients:", err)
} else if count > 0 {
logger.Debugf("%v clients disabled", count)
}
needRestart2, count, err := s.disableInvalidInbounds(tx)
if err != nil {
logger.Warning("Error in disabling invalid inbounds:", err)
} else if count > 0 {
logger.Debugf("%v inbounds disabled", count)
}
return nil, (needRestart1 || needRestart2)
}
func (s *InboundService) addInboundTraffic(tx *gorm.DB, traffics []*xray.Traffic) error {
if len(traffics) == 0 {
return nil
}
var err error
for _, traffic := range traffics {
if traffic.IsInbound {
err = tx.Model(&model.Inbound{}).Where("tag = ?", traffic.Tag).
Updates(map[string]interface{}{
"up": gorm.Expr("up + ?", traffic.Up),
"down": gorm.Expr("down + ?", traffic.Down),
}).Error
if err != nil {
return err
}
}
}
return nil
}
func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTraffic) (err error) {
if len(traffics) == 0 {
return nil
}
emails := make([]string, 0, len(traffics)) emails := make([]string, 0, len(traffics))
for _, traffic := range traffics { for _, traffic := range traffics {
emails = append(emails, traffic.Email) emails = append(emails, traffic.Email)
} }
dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics)) dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics))
err = db.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error err = tx.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error
if err != nil { if err != nil {
return err return err
} }
@ -817,14 +809,13 @@ func (s *InboundService) adjustTraffics(tx *gorm.DB, dbClientTraffics []*xray.Cl
return dbClientTraffics, nil return dbClientTraffics, nil
} }
func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) { func (s *InboundService) disableInvalidInbounds(tx *gorm.DB) (bool, int64, error) {
db := database.GetDB()
now := time.Now().Unix() * 1000 now := time.Now().Unix() * 1000
needRestart := false needRestart := false
if p != nil { if p != nil {
var tags []string var tags []string
err := db.Table("inbounds"). err := tx.Table("inbounds").
Select("inbounds.tag"). Select("inbounds.tag").
Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true). Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true).
Scan(&tags).Error Scan(&tags).Error
@ -844,7 +835,7 @@ func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) {
s.xrayApi.Close() s.xrayApi.Close()
} }
result := db.Model(model.Inbound{}). result := tx.Model(model.Inbound{}).
Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true). Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true).
Update("enable", false) Update("enable", false)
err := result.Error err := result.Error
@ -852,8 +843,7 @@ func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) {
return needRestart, count, err return needRestart, count, err
} }
func (s *InboundService) DisableInvalidClients() (bool, int64, error) { func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) {
db := database.GetDB()
now := time.Now().Unix() * 1000 now := time.Now().Unix() * 1000
needRestart := false needRestart := false
@ -863,7 +853,7 @@ func (s *InboundService) DisableInvalidClients() (bool, int64, error) {
Email string Email string
} }
err := db.Table("inbounds"). err := tx.Table("inbounds").
Select("inbounds.tag, client_traffics.email"). Select("inbounds.tag, client_traffics.email").
Joins("JOIN client_traffics ON inbounds.id = client_traffics.inbound_id"). Joins("JOIN client_traffics ON inbounds.id = client_traffics.inbound_id").
Where("((client_traffics.total > 0 AND client_traffics.up + client_traffics.down >= client_traffics.total) OR (client_traffics.expiry_time > 0 AND client_traffics.expiry_time <= ?)) AND client_traffics.enable = ?", now, true). Where("((client_traffics.total > 0 AND client_traffics.up + client_traffics.down >= client_traffics.total) OR (client_traffics.expiry_time > 0 AND client_traffics.expiry_time <= ?)) AND client_traffics.enable = ?", now, true).
@ -883,7 +873,7 @@ func (s *InboundService) DisableInvalidClients() (bool, int64, error) {
} }
s.xrayApi.Close() s.xrayApi.Close()
} }
result := db.Model(xray.ClientTraffic{}). result := tx.Model(xray.ClientTraffic{}).
Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true). Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true).
Update("enable", false) Update("enable", false)
err := result.Error err := result.Error

View file

@ -69,7 +69,7 @@ func (s *XrayService) GetXrayConfig() (*xray.Config, error) {
return nil, err return nil, err
} }
s.inboundService.DisableInvalidClients() s.inboundService.AddTraffic(nil, nil)
inbounds, err := s.inboundService.GetAllInbounds() inbounds, err := s.inboundService.GetAllInbounds()
if err != nil { if err != nil {

View file

@ -247,9 +247,6 @@ func (s *Server) startTask() {
s.cron.AddJob("@every 10s", job.NewXrayTrafficJob()) s.cron.AddJob("@every 10s", job.NewXrayTrafficJob())
}() }()
// Check the inbound traffic every 30 seconds that the traffic exceeds and expires
s.cron.AddJob("@every 30s", job.NewCheckInboundJob())
// check client ips from log file every 10 sec // check client ips from log file every 10 sec
s.cron.AddJob("@every 10s", job.NewCheckClientIpJob()) s.cron.AddJob("@every 10s", job.NewCheckClientIpJob())