diff --git a/web/service/inbound.go b/web/service/inbound.go index b89f030a..da690bfa 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -1789,15 +1789,19 @@ func (s *InboundService) setRemoteTrafficLocked(nodeID int, snap *runtime.Traffi } func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (needRestart bool, clientsDisabled bool, err error) { + var disabledNodeIDs []int err = submitTrafficWrite(func() error { var inner error - needRestart, clientsDisabled, inner = s.addTrafficLocked(inboundTraffics, clientTraffics) + needRestart, clientsDisabled, disabledNodeIDs, inner = s.addTrafficLocked(inboundTraffics, clientTraffics) return inner }) + if err == nil && len(disabledNodeIDs) > 0 { + s.restartRemoteNodesOnDisable(disabledNodeIDs) + } return } -func (s *InboundService) addTrafficLocked(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (bool, bool, error) { +func (s *InboundService) addTrafficLocked(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (bool, bool, []int, error) { var err error db := database.GetDB() tx := db.Begin() @@ -1811,11 +1815,11 @@ func (s *InboundService) addTrafficLocked(inboundTraffics []*xray.Traffic, clien }() err = s.addInboundTraffic(tx, inboundTraffics) if err != nil { - return false, false, err + return false, false, nil, err } err = s.addClientTraffic(tx, clientTraffics) if err != nil { - return false, false, err + return false, false, nil, err } needRestart0, count, err := s.autoRenewClients(tx) @@ -1826,7 +1830,7 @@ func (s *InboundService) addTrafficLocked(inboundTraffics []*xray.Traffic, clien } disabledClientsCount := int64(0) - needRestart1, count, err := s.disableInvalidClients(tx) + needRestart1, count, disabledNodeIDs, err := s.disableInvalidClients(tx) if err != nil { logger.Warning("Error in disabling invalid clients:", err) } else if count > 0 { @@ -1840,7 +1844,7 @@ func (s *InboundService) addTrafficLocked(inboundTraffics []*xray.Traffic, clien } else if count > 0 { logger.Debugf("%v inbounds disabled", count) } - return needRestart0 || needRestart1 || needRestart2, disabledClientsCount > 0, nil + return needRestart0 || needRestart1 || needRestart2, disabledClientsCount > 0, disabledNodeIDs, nil } func (s *InboundService) addInboundTraffic(tx *gorm.DB, traffics []*xray.Traffic) error { @@ -2196,7 +2200,7 @@ func (s *InboundService) disableInvalidInbounds(tx *gorm.DB) (bool, int64, error return needRestart, count, err } -func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) { +func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, []int, error) { now := time.Now().Unix() * 1000 needRestart := false @@ -2205,10 +2209,10 @@ func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) Where("((total > 0 AND up + down >= total) OR (expiry_time > 0 AND expiry_time <= ?)) AND enable = ?", now, true). Find(&depletedRows).Error if err != nil { - return false, 0, err + return false, 0, nil, err } if len(depletedRows) == 0 { - return false, 0, nil + return false, 0, nil, nil } depletedEmails := make([]string, 0, len(depletedRows)) @@ -2236,7 +2240,7 @@ func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) WHERE clients.email IN ? `, depletedEmails).Scan(&targets).Error if err != nil { - return false, 0, err + return false, 0, nil, err } } @@ -2283,7 +2287,7 @@ func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) err = result.Error count := result.RowsAffected if err != nil { - return needRestart, count, err + return needRestart, count, nil, err } if len(depletedEmails) > 0 { @@ -2294,8 +2298,6 @@ func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) } } - // Track which remote nodes had clients successfully disabled so we can - // restart their xray and kill the existing active connections. disabledNodeIDs := make(map[int]struct{}) for inboundID, group := range remoteByInbound { emails := make(map[string]struct{}, len(group)) @@ -2314,26 +2316,34 @@ func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) } } - // Restart xray on each affected remote node so active connections are - // dropped immediately, not just blocked on the next reconnect attempt. - if len(disabledNodeIDs) > 0 { - restartOnDisable, _ := (&SettingService{}).GetRestartXrayOnClientDisable() - if restartOnDisable { - for nodeID := range disabledNodeIDs { - nodeIDCopy := nodeID - rt, rtErr := runtime.GetManager().RuntimeFor(&nodeIDCopy) - if rtErr != nil { - logger.Warning("disableInvalidClients: get runtime for node", nodeID, "failed:", rtErr) - continue - } - if rtErr = rt.RestartXray(context.Background()); rtErr != nil { - logger.Warning("disableInvalidClients: restart xray on node", nodeID, "failed:", rtErr) - } - } - } + nodeIDs := make([]int, 0, len(disabledNodeIDs)) + for nodeID := range disabledNodeIDs { + nodeIDs = append(nodeIDs, nodeID) } - return needRestart, count, nil + return needRestart, count, nodeIDs, nil +} + +func (s *InboundService) restartRemoteNodesOnDisable(nodeIDs []int) { + restartOnDisable, err := (&SettingService{}).GetRestartXrayOnClientDisable() + if err != nil { + logger.Warning("disableInvalidClients: get RestartXrayOnClientDisable failed:", err) + return + } + if !restartOnDisable { + return + } + for _, nodeID := range nodeIDs { + nodeIDCopy := nodeID + rt, rtErr := runtime.GetManager().RuntimeFor(&nodeIDCopy) + if rtErr != nil { + logger.Warning("disableInvalidClients: get runtime for node", nodeID, "failed:", rtErr) + continue + } + if rtErr = rt.RestartXray(context.Background()); rtErr != nil { + logger.Warning("disableInvalidClients: restart xray on node", nodeID, "failed:", rtErr) + } + } } // markClientsDisabledInSettings flips client.enable=false in the inbound's