diff --git a/web/service/inbound.go b/web/service/inbound.go index 461efb99..669827b4 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -548,7 +548,7 @@ func (s *InboundService) updateClientTraffics(tx *gorm.DB, oldInbound *model.Inb } } if !emailExists { - err = s.DelClientStat(tx, oldClient.Email) + err = s.DelClientStat(tx, oldInbound.Id, oldClient.Email) if err != nil { return err } @@ -754,12 +754,14 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool, if len(email) > 0 { notDepleted := true - err = db.Model(xray.ClientTraffic{}).Select("enable").Where("email = ?", email).First(¬Depleted).Error - if err != nil { + var traffic xray.ClientTraffic + err = db.Model(xray.ClientTraffic{}).Where("inbound_id = ? AND email = ?", inboundId, email).First(&traffic).Error + if err != nil && err != gorm.ErrRecordNotFound { logger.Error("Get stats error") return false, err } - err = s.DelClientStat(db, email) + notDepleted = err == nil && traffic.Enable + err = s.DelClientStat(db, inboundId, email) if err != nil { logger.Error("Delete stats Data Error") return false, err @@ -905,7 +907,7 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId strin if len(clients[0].Email) > 0 { if len(oldEmail) > 0 { - err = s.UpdateClientStat(tx, oldEmail, &clients[0]) + err = s.UpdateClientStat(tx, data.Id, oldEmail, &clients[0]) if err != nil { return false, err } @@ -917,7 +919,7 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId strin s.AddClientStat(tx, data.Id, &clients[0]) } } else { - err = s.DelClientStat(tx, oldEmail) + err = s.DelClientStat(tx, data.Id, oldEmail) if err != nil { return false, err } @@ -1368,9 +1370,9 @@ func (s *InboundService) AddClientStat(tx *gorm.DB, inboundId int, client *model return err } -func (s *InboundService) UpdateClientStat(tx *gorm.DB, email string, client *model.Client) error { +func (s *InboundService) UpdateClientStat(tx *gorm.DB, inboundId int, email string, client *model.Client) error { result := tx.Model(xray.ClientTraffic{}). - Where("email = ?", email). + Where("inbound_id = ? AND email = ?", inboundId, email). Updates(map[string]any{ "enable": client.Enable, "email": client.Email, @@ -1386,8 +1388,8 @@ func (s *InboundService) UpdateClientIPs(tx *gorm.DB, oldEmail string, newEmail return tx.Model(model.InboundClientIps{}).Where("client_email = ?", oldEmail).Update("client_email", newEmail).Error } -func (s *InboundService) DelClientStat(tx *gorm.DB, email string) error { - return tx.Where("email = ?", email).Delete(xray.ClientTraffic{}).Error +func (s *InboundService) DelClientStat(tx *gorm.DB, inboundId int, email string) error { + return tx.Where("inbound_id = ? AND email = ?", inboundId, email).Delete(xray.ClientTraffic{}).Error } func (s *InboundService) DelClientIPs(tx *gorm.DB, email string) error { @@ -2070,6 +2072,23 @@ func (s *InboundService) GetClientTrafficByEmail(email string) (traffic *xray.Cl return nil, nil } +func (s *InboundService) GetClientTrafficByInboundAndEmail(inboundId int, email string) (*xray.ClientTraffic, error) { + db := database.GetDB() + traffic := &xray.ClientTraffic{} + + err := db.Model(xray.ClientTraffic{}). + Where("inbound_id = ? AND email = ?", inboundId, email). + First(traffic).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, nil + } + logger.Warningf("Error retrieving ClientTraffic with inbound %d and email %s: %v", inboundId, email, err) + return nil, err + } + return traffic, nil +} + func (s *InboundService) UpdateClientTrafficByEmail(email string, upload int64, download int64) error { db := database.GetDB() @@ -2330,7 +2349,9 @@ func (s *InboundService) MigrationRequirements() { for _, modelClient := range modelClients { if len(modelClient.Email) > 0 { var count int64 - tx.Model(xray.ClientTraffic{}).Where("email = ?", modelClient.Email).Count(&count) + tx.Model(xray.ClientTraffic{}). + Where("inbound_id = ? AND email = ?", inbounds[inbound_index].Id, modelClient.Email). + Count(&count) if count == 0 { s.AddClientStat(tx, inbounds[inbound_index].Id, &modelClient) } @@ -2508,12 +2529,12 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b // remove stats too if len(email) > 0 { - traffic, err := s.GetClientTrafficByEmail(email) + traffic, err := s.GetClientTrafficByInboundAndEmail(inboundId, email) if err != nil { return false, err } if traffic != nil { - if err := s.DelClientStat(db, email); err != nil { + if err := s.DelClientStat(db, inboundId, email); err != nil { logger.Error("Delete stats Data Error") return false, err } diff --git a/web/service/inbound_test.go b/web/service/inbound_test.go new file mode 100644 index 00000000..b78a0f7a --- /dev/null +++ b/web/service/inbound_test.go @@ -0,0 +1,196 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/mhsanaei/3x-ui/v2/database" + "github.com/mhsanaei/3x-ui/v2/database/model" + "github.com/mhsanaei/3x-ui/v2/xray" +) + +func mustMarshalInboundSettings(t *testing.T, clients ...model.Client) string { + t.Helper() + + settings := map[string]any{ + "clients": clients, + } + data, err := json.Marshal(settings) + if err != nil { + t.Fatalf("marshal inbound settings failed: %v", err) + } + return string(data) +} + +func mustCreateInboundWithClients(t *testing.T, svc *InboundService, inbound model.Inbound, clients ...model.Client) *model.Inbound { + t.Helper() + + inbound.Settings = mustMarshalInboundSettings(t, clients...) + if err := database.GetDB().Create(&inbound).Error; err != nil { + t.Fatalf("create inbound failed: %v", err) + } + + for i := range clients { + if clients[i].Email == "" { + continue + } + if err := svc.AddClientStat(database.GetDB(), inbound.Id, &clients[i]); err != nil { + t.Fatalf("create client traffic failed: %v", err) + } + } + + return &inbound +} + +func countClientTraffic(t *testing.T, inboundID int, email string) int64 { + t.Helper() + + var count int64 + query := database.GetDB().Model(&xray.ClientTraffic{}).Where("email = ?", email) + if inboundID > 0 { + query = query.Where("inbound_id = ?", inboundID) + } + if err := query.Count(&count).Error; err != nil { + t.Fatalf("count client traffic failed: %v", err) + } + return count +} + +func TestDelInboundClientByEmail_ScopedToInbound(t *testing.T) { + setupTestDB(t) + + svc := &InboundService{} + duplicateEmail := "shared@example.com" + + inbound1 := mustCreateInboundWithClients(t, svc, model.Inbound{ + UserId: 1, + Port: 10001, + Protocol: model.VLESS, + Tag: "inbound-test-1", + }, model.Client{ + ID: "client-1", + Email: duplicateEmail, + Enable: false, + }, model.Client{ + ID: "client-2", + Email: "unique-1@example.com", + Enable: false, + }) + + inbound2 := mustCreateInboundWithClients(t, svc, model.Inbound{ + UserId: 1, + Port: 10002, + Protocol: model.VLESS, + Tag: "inbound-test-2", + }, model.Client{ + ID: "client-3", + Email: duplicateEmail, + Enable: false, + }, model.Client{ + ID: "client-4", + Email: "unique-2@example.com", + Enable: false, + }) + + if got := countClientTraffic(t, 0, duplicateEmail); got != 2 { + t.Fatalf("expected 2 traffic rows before deletion, got %d", got) + } + + if _, err := svc.DelInboundClientByEmail(inbound1.Id, duplicateEmail); err != nil { + t.Fatalf("first delete failed: %v", err) + } + if got := countClientTraffic(t, inbound1.Id, duplicateEmail); got != 0 { + t.Fatalf("expected inbound1 traffic to be deleted, got %d", got) + } + if got := countClientTraffic(t, inbound2.Id, duplicateEmail); got != 1 { + t.Fatalf("expected inbound2 traffic to remain, got %d", got) + } + + if _, err := svc.DelInboundClientByEmail(inbound2.Id, duplicateEmail); err != nil { + t.Fatalf("second delete failed: %v", err) + } + if got := countClientTraffic(t, 0, duplicateEmail); got != 0 { + t.Fatalf("expected all duplicate-email traffics to be deleted, got %d", got) + } +} + +func TestUpdateInboundClient_DoesNotUpdateOtherInboundTraffic(t *testing.T) { + setupTestDB(t) + + p = xray.NewProcess(&xray.Config{}) + svc := &InboundService{} + duplicateEmail := "shared@example.com" + renamedEmail := "renamed@example.com" + + inbound1 := mustCreateInboundWithClients(t, svc, model.Inbound{ + UserId: 1, + Port: 11001, + Protocol: model.VLESS, + Tag: "inbound-edit-1", + }, model.Client{ + ID: "client-1", + Email: duplicateEmail, + Enable: false, + TotalGB: 10, + ExpiryTime: 111, + }, model.Client{ + ID: "client-2", + Email: "unique-1@example.com", + Enable: false, + }) + + inbound2 := mustCreateInboundWithClients(t, svc, model.Inbound{ + UserId: 1, + Port: 11002, + Protocol: model.VLESS, + Tag: "inbound-edit-2", + }, model.Client{ + ID: "client-3", + Email: duplicateEmail, + Enable: false, + TotalGB: 20, + ExpiryTime: 222, + }, model.Client{ + ID: "client-4", + Email: "unique-2@example.com", + Enable: false, + }) + + updatePayload := &model.Inbound{ + Id: inbound1.Id, + Settings: mustMarshalInboundSettings(t, model.Client{ + ID: "client-1", + Email: renamedEmail, + Enable: false, + TotalGB: 30, + ExpiryTime: 333, + }), + } + + if _, err := svc.UpdateInboundClient(updatePayload, "client-1"); err != nil { + t.Fatalf("update inbound client failed: %v", err) + } + + var inbound1Traffic xray.ClientTraffic + if err := database.GetDB(). + Where("inbound_id = ? AND email = ?", inbound1.Id, renamedEmail). + First(&inbound1Traffic).Error; err != nil { + t.Fatalf("expected updated inbound1 traffic row: %v", err) + } + if inbound1Traffic.Total != 30 || inbound1Traffic.ExpiryTime != 333 { + t.Fatalf("unexpected inbound1 traffic values: %+v", inbound1Traffic) + } + + var inbound2Traffic xray.ClientTraffic + if err := database.GetDB(). + Where("inbound_id = ? AND email = ?", inbound2.Id, duplicateEmail). + First(&inbound2Traffic).Error; err != nil { + t.Fatalf("expected inbound2 traffic row to remain unchanged: %v", err) + } + if inbound2Traffic.Total != 20 || inbound2Traffic.ExpiryTime != 222 { + t.Fatalf("unexpected inbound2 traffic values: %+v", inbound2Traffic) + } + if got := countClientTraffic(t, inbound2.Id, renamedEmail); got != 0 { + t.Fatalf("expected renamed email to stay isolated to inbound1, got %d rows in inbound2", got) + } +}