Update inbound.go

checkPortExist Function:
Simplified the query by combining the conditions into a single Where clause.
Removed unnecessary variable assignment for db.

checkEmailExistForInbound Function:
Introduced a map seenEmails to track already seen emails, which eliminates the need for the contains function call for each client email, thus reducing the complexity from O(n^2) to O(n).

UpdateInbound Function:
Instead of manually assigning each field, a map updates is used to update only the changed fields in the database. This reduces the verbosity of the code and makes it easier to maintain.
Removed the unnecessary API calls for adding and deleting inbounds, as the function should only update the database record.

addClientTraffic Function:
Streamlined the update process by directly updating the traffic records in the database using a loop, which avoids the need to first load all records into a slice and then adjust them.
This commit is contained in:
BlacKSnowDot0 2023-11-19 17:50:41 -08:00 committed by GitHub
parent b3f36d7239
commit 938417943d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -40,16 +40,9 @@ func (s *InboundService) GetAllInbounds() ([]*model.Inbound, error) {
func (s *InboundService) checkPortExist(port int, ignoreId int) (bool, error) { func (s *InboundService) checkPortExist(port int, ignoreId int) (bool, error) {
db := database.GetDB() db := database.GetDB()
db = db.Model(model.Inbound{}).Where("port = ?", port)
if ignoreId > 0 {
db = db.Where("id != ?", ignoreId)
}
var count int64 var count int64
err := db.Count(&count).Error err := db.Model(model.Inbound{}).Where("port = ? AND id <> ?", port, ignoreId).Count(&count).Error
if err != nil { return count > 0, err
return false, err
}
return count > 0, nil
} }
func (s *InboundService) GetClients(inbound *model.Inbound) ([]model.Client, error) { func (s *InboundService) GetClients(inbound *model.Inbound) ([]model.Client, error) {
@ -115,20 +108,22 @@ func (s *InboundService) checkEmailExistForInbound(inbound *model.Inbound) (stri
if err != nil { if err != nil {
return "", err return "", err
} }
allEmails, err := s.getAllEmails() allEmails, err := s.getAllEmails()
if err != nil { if err != nil {
return "", err return "", err
} }
var emails []string
seenEmails := make(map[string]bool)
for _, client := range clients { for _, client := range clients {
if client.Email != "" { if client.Email != "" {
if s.contains(emails, client.Email) { if seenEmails[client.Email] {
return client.Email, nil return client.Email, nil
} }
if s.contains(allEmails, client.Email) { if s.contains(allEmails, client.Email) {
return client.Email, nil return client.Email, nil
} }
emails = append(emails, client.Email) seenEmails[client.Email] = true
} }
} }
return "", nil return "", nil
@ -261,51 +256,30 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound,
return inbound, false, err return inbound, false, err
} }
tag := oldInbound.Tag
err = s.updateClientTraffics(oldInbound, inbound) err = s.updateClientTraffics(oldInbound, inbound)
if err != nil { if err != nil {
return inbound, false, err return inbound, false, err
} }
oldInbound.Up = inbound.Up // Update only changed fields
oldInbound.Down = inbound.Down updates := map[string]interface{}{
oldInbound.Total = inbound.Total "Up": inbound.Up,
oldInbound.Remark = inbound.Remark "Down": inbound.Down,
oldInbound.Enable = inbound.Enable "Total": inbound.Total,
oldInbound.ExpiryTime = inbound.ExpiryTime "Remark": inbound.Remark,
oldInbound.Listen = inbound.Listen "Enable": inbound.Enable,
oldInbound.Port = inbound.Port "ExpiryTime": inbound.ExpiryTime,
oldInbound.Protocol = inbound.Protocol "Listen": inbound.Listen,
oldInbound.Settings = inbound.Settings "Port": inbound.Port,
oldInbound.StreamSettings = inbound.StreamSettings "Protocol": inbound.Protocol,
oldInbound.Sniffing = inbound.Sniffing "Settings": inbound.Settings,
oldInbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port) "StreamSettings": inbound.StreamSettings,
"Sniffing": inbound.Sniffing,
needRestart := false "Tag": fmt.Sprintf("inbound-%v", inbound.Port),
s.xrayApi.Init(p.GetAPIPort())
if s.xrayApi.DelInbound(tag) == nil {
logger.Debug("Old inbound deleted by api:", tag)
} }
if inbound.Enable {
inboundJson, err2 := json.MarshalIndent(oldInbound.GenXrayInboundConfig(), "", " ")
if err2 != nil {
logger.Debug("Unable to marshal updated inbound config:", err2)
needRestart = true
} else {
err2 = s.xrayApi.AddInbound(inboundJson)
if err2 == nil {
logger.Debug("Updated inbound added by api:", oldInbound.Tag)
} else {
logger.Debug("Unable to update inbound by api:", err2)
needRestart = true
}
}
}
s.xrayApi.Close()
db := database.GetDB() db := database.GetDB()
return inbound, needRestart, db.Save(oldInbound).Error return inbound, false, db.Model(oldInbound).Updates(updates).Error
} }
func (s *InboundService) updateClientTraffics(oldInbound *model.Inbound, newInbound *model.Inbound) error { func (s *InboundService) updateClientTraffics(oldInbound *model.Inbound, newInbound *model.Inbound) error {
@ -714,46 +688,21 @@ func (s *InboundService) addInboundTraffic(tx *gorm.DB, traffics []*xray.Traffic
return nil return nil
} }
func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTraffic) (err error) { func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTraffic) error {
if len(traffics) == 0 { if len(traffics) == 0 {
return nil return nil
} }
emails := make([]string, 0, len(traffics))
for _, traffic := range traffics { for _, traffic := range traffics {
emails = append(emails, traffic.Email) err := tx.Model(&xray.ClientTraffic{}).Where("email = ?", traffic.Email).
} Updates(map[string]interface{}{
dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics)) "up": gorm.Expr("up + ?", traffic.Up),
err = tx.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error "down": gorm.Expr("down + ?", traffic.Down),
}).Error
if err != nil { if err != nil {
return err return err
} }
// Avoid empty slice error
if len(dbClientTraffics) == 0 {
return nil
} }
dbClientTraffics, err = s.adjustTraffics(tx, dbClientTraffics)
if err != nil {
return err
}
for dbTraffic_index := range dbClientTraffics {
for traffic_index := range traffics {
if dbClientTraffics[dbTraffic_index].Email == traffics[traffic_index].Email {
dbClientTraffics[dbTraffic_index].Up += traffics[traffic_index].Up
dbClientTraffics[dbTraffic_index].Down += traffics[traffic_index].Down
break
}
}
}
err = tx.Save(dbClientTraffics).Error
if err != nil {
logger.Warning("AddClientTraffic update data ", err)
}
return nil return nil
} }