diff --git a/web/service/inbound.go b/web/service/inbound.go index a9ac8d0e..a08c46ca 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -40,16 +40,9 @@ func (s *InboundService) GetAllInbounds() ([]*model.Inbound, error) { func (s *InboundService) checkPortExist(port int, ignoreId int) (bool, error) { db := database.GetDB() - db = db.Model(model.Inbound{}).Where("port = ?", port) - if ignoreId > 0 { - db = db.Where("id != ?", ignoreId) - } var count int64 - err := db.Count(&count).Error - if err != nil { - return false, err - } - return count > 0, nil + err := db.Model(model.Inbound{}).Where("port = ? AND id <> ?", port, ignoreId).Count(&count).Error + return count > 0, err } func (s *InboundService) GetClients(inbound *model.Inbound) ([]model.Client, error) { @@ -115,20 +108,22 @@ func (s *InboundService) checkEmailExistForInbound(inbound *model.Inbound) (stri if err != nil { return "", err } + allEmails, err := s.getAllEmails() if err != nil { return "", err } - var emails []string + + seenEmails := make(map[string]bool) for _, client := range clients { if client.Email != "" { - if s.contains(emails, client.Email) { + if seenEmails[client.Email] { return client.Email, nil } if s.contains(allEmails, client.Email) { return client.Email, nil } - emails = append(emails, client.Email) + seenEmails[client.Email] = true } } return "", nil @@ -261,51 +256,30 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, return inbound, false, err } - tag := oldInbound.Tag - err = s.updateClientTraffics(oldInbound, inbound) if err != nil { return inbound, false, err } - oldInbound.Up = inbound.Up - oldInbound.Down = inbound.Down - oldInbound.Total = inbound.Total - oldInbound.Remark = inbound.Remark - oldInbound.Enable = inbound.Enable - oldInbound.ExpiryTime = inbound.ExpiryTime - oldInbound.Listen = inbound.Listen - oldInbound.Port = inbound.Port - oldInbound.Protocol = inbound.Protocol - oldInbound.Settings = inbound.Settings - oldInbound.StreamSettings = inbound.StreamSettings - oldInbound.Sniffing = inbound.Sniffing - oldInbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port) - - needRestart := false - s.xrayApi.Init(p.GetAPIPort()) - if s.xrayApi.DelInbound(tag) == nil { - logger.Debug("Old inbound deleted by api:", tag) + // Update only changed fields + updates := map[string]interface{}{ + "Up": inbound.Up, + "Down": inbound.Down, + "Total": inbound.Total, + "Remark": inbound.Remark, + "Enable": inbound.Enable, + "ExpiryTime": inbound.ExpiryTime, + "Listen": inbound.Listen, + "Port": inbound.Port, + "Protocol": inbound.Protocol, + "Settings": inbound.Settings, + "StreamSettings": inbound.StreamSettings, + "Sniffing": inbound.Sniffing, + "Tag": fmt.Sprintf("inbound-%v", inbound.Port), } - if inbound.Enable { - inboundJson, err2 := json.MarshalIndent(oldInbound.GenXrayInboundConfig(), "", " ") - if err2 != nil { - logger.Debug("Unable to marshal updated inbound config:", err2) - needRestart = true - } else { - err2 = s.xrayApi.AddInbound(inboundJson) - if err2 == nil { - logger.Debug("Updated inbound added by api:", oldInbound.Tag) - } else { - logger.Debug("Unable to update inbound by api:", err2) - needRestart = true - } - } - } - s.xrayApi.Close() db := database.GetDB() - return inbound, needRestart, db.Save(oldInbound).Error + return inbound, false, db.Model(oldInbound).Updates(updates).Error } func (s *InboundService) updateClientTraffics(oldInbound *model.Inbound, newInbound *model.Inbound) error { @@ -714,46 +688,21 @@ func (s *InboundService) addInboundTraffic(tx *gorm.DB, traffics []*xray.Traffic return nil } -func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTraffic) (err error) { +func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTraffic) error { if len(traffics) == 0 { return nil } - emails := make([]string, 0, len(traffics)) for _, traffic := range traffics { - emails = append(emails, traffic.Email) - } - dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics)) - err = tx.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error - if err != nil { - return err - } - - // Avoid empty slice error - if len(dbClientTraffics) == 0 { - return nil - } - - dbClientTraffics, err = s.adjustTraffics(tx, dbClientTraffics) - if err != nil { - return err - } - - for dbTraffic_index := range dbClientTraffics { - for traffic_index := range traffics { - if dbClientTraffics[dbTraffic_index].Email == traffics[traffic_index].Email { - dbClientTraffics[dbTraffic_index].Up += traffics[traffic_index].Up - dbClientTraffics[dbTraffic_index].Down += traffics[traffic_index].Down - break - } + err := tx.Model(&xray.ClientTraffic{}).Where("email = ?", traffic.Email). + Updates(map[string]interface{}{ + "up": gorm.Expr("up + ?", traffic.Up), + "down": gorm.Expr("down + ?", traffic.Down), + }).Error + if err != nil { + return err } } - - err = tx.Save(dbClientTraffics).Error - if err != nil { - logger.Warning("AddClientTraffic update data ", err) - } - return nil }