diff --git a/web/service/inbound.go b/web/service/inbound.go index 68c1cc42..bfa52c3d 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -26,6 +26,14 @@ type InboundService struct { xrayApi xray.XrayAPI } +func ensureSharedWriteAllowed() error { + return RequireMaster() +} + +func bumpSharedVersion(tx *gorm.DB) error { + return database.BumpSharedAccountsVersion(tx) +} + // GetInbounds retrieves all inbounds for a specific user. // Returns a slice of inbound models with their associated client statistics. func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) { @@ -228,6 +236,10 @@ func (s *InboundService) checkEmailExistInInbound(inbound *model.Inbound, email // then saves the inbound to the database and optionally adds it to the running Xray instance. // Returns the created inbound, whether Xray needs restart, and any error. func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, bool, error) { + if err := ensureSharedWriteAllowed(); err != nil { + return inbound, false, err + } + exist, err := s.checkPortExist(inbound.Listen, inbound.Port, 0) if err != nil { return inbound, false, err @@ -311,6 +323,10 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo } else { return inbound, false, err } + err = bumpSharedVersion(tx) + if err != nil { + return inbound, false, err + } needRestart := false if inbound.Enable { @@ -337,6 +353,10 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo // It removes the inbound from the database and the running Xray instance if active. // Returns whether Xray needs restart and any error. func (s *InboundService) DelInbound(id int) (bool, error) { + if err := ensureSharedWriteAllowed(); err != nil { + return false, err + } + db := database.GetDB() var tag string @@ -356,11 +376,6 @@ func (s *InboundService) DelInbound(id int) (bool, error) { logger.Debug("No enabled inbound founded to removing by api", tag) } - // Delete client traffics of inbounds - err := db.Where("inbound_id = ?", id).Delete(xray.ClientTraffic{}).Error - if err != nil { - return false, err - } inbound, err := s.GetInbound(id) if err != nil { return false, err @@ -369,14 +384,35 @@ func (s *InboundService) DelInbound(id int) (bool, error) { if err != nil { return false, err } + + tx := db.Begin() + defer func() { + if err != nil { + tx.Rollback() + } else { + tx.Commit() + } + }() + + err = tx.Where("inbound_id = ?", id).Delete(xray.ClientTraffic{}).Error + if err != nil { + return false, err + } for _, client := range clients { - err := s.DelClientIPs(db, client.Email) + err := s.DelClientIPs(tx, client.Email) if err != nil { return false, err } } - - return needRestart, db.Delete(model.Inbound{}, id).Error + err = tx.Delete(model.Inbound{}, id).Error + if err != nil { + return false, err + } + err = bumpSharedVersion(tx) + if err != nil { + return false, err + } + return needRestart, nil } func (s *InboundService) GetInbound(id int) (*model.Inbound, error) { @@ -472,6 +508,10 @@ func (s *InboundService) DelInboundClientByEmailForUser(userID int, isAdmin bool // It validates changes, updates the database, and syncs with the running Xray instance. // Returns the updated inbound, whether Xray needs restart, and any error. func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, bool, error) { + if err := ensureSharedWriteAllowed(); err != nil { + return inbound, false, err + } + exist, err := s.checkPortExist(inbound.Listen, inbound.Port, inbound.Id) if err != nil { return inbound, false, err @@ -604,7 +644,15 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, } s.xrayApi.Close() - return inbound, needRestart, tx.Save(oldInbound).Error + err = tx.Save(oldInbound).Error + if err != nil { + return inbound, false, err + } + err = bumpSharedVersion(tx) + if err != nil { + return inbound, false, err + } + return inbound, needRestart, nil } func (s *InboundService) updateClientTraffics(tx *gorm.DB, oldInbound *model.Inbound, newInbound *model.Inbound) error { @@ -653,6 +701,10 @@ func (s *InboundService) updateClientTraffics(tx *gorm.DB, oldInbound *model.Inb } func (s *InboundService) AddInboundClient(data *model.Inbound) (bool, error) { + if err := ensureSharedWriteAllowed(); err != nil { + return false, err + } + clients, err := s.GetClients(data) if err != nil { return false, err @@ -773,10 +825,22 @@ func (s *InboundService) AddInboundClient(data *model.Inbound) (bool, error) { } s.xrayApi.Close() - return needRestart, tx.Save(oldInbound).Error + err = tx.Save(oldInbound).Error + if err != nil { + return false, err + } + err = bumpSharedVersion(tx) + if err != nil { + return false, err + } + return needRestart, nil } func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool, error) { + if err := ensureSharedWriteAllowed(); err != nil { + return false, err + } + oldInbound, err := s.GetInbound(inboundId) if err != nil { logger.Error("Load Old Data Error") @@ -824,8 +888,16 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool, oldInbound.Settings = string(newSettings) db := database.GetDB() + tx := db.Begin() + defer func() { + if err != nil { + tx.Rollback() + } else { + tx.Commit() + } + }() - err = s.DelClientIPs(db, email) + err = s.DelClientIPs(tx, email) if err != nil { logger.Error("Error in delete client IPs") return false, err @@ -841,7 +913,7 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool, return false, err } notDepleted = err == nil && traffic.Enable - err = s.DelClientStat(db, inboundId, email) + err = s.DelClientStat(tx, inboundId, email) if err != nil { logger.Error("Delete stats Data Error") return false, err @@ -863,10 +935,22 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool, s.xrayApi.Close() } } - return needRestart, db.Save(oldInbound).Error + err = tx.Save(oldInbound).Error + if err != nil { + return false, err + } + err = bumpSharedVersion(tx) + if err != nil { + return false, err + } + return needRestart, nil } func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId string) (bool, error) { + if err := ensureSharedWriteAllowed(); err != nil { + return false, err + } + // TODO: check if TrafficReset field is updating clients, err := s.GetClients(data) if err != nil { @@ -1049,7 +1133,15 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId strin logger.Debug("Client old email not found") needRestart = true } - return needRestart, tx.Save(oldInbound).Error + err = tx.Save(oldInbound).Error + if err != nil { + return false, err + } + err = bumpSharedVersion(tx) + if err != nil { + return false, err + } + return needRestart, nil } func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (error, bool) { @@ -1895,6 +1987,10 @@ func (s *InboundService) ResetClientTrafficByEmail(clientEmail string) error { } func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, error) { + if err := ensureSharedWriteAllowed(); err != nil { + return false, err + } + needRestart := false traffic, err := s.GetClientTrafficByEmail(clientEmail) @@ -1948,7 +2044,20 @@ func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, e traffic.Enable = true db := database.GetDB() - err = db.Save(traffic).Error + tx := db.Begin() + defer func() { + if err != nil { + tx.Rollback() + } else { + tx.Commit() + } + }() + + err = tx.Save(traffic).Error + if err != nil { + return false, err + } + err = bumpSharedVersion(tx) if err != nil { return false, err } @@ -1957,6 +2066,10 @@ func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, e } func (s *InboundService) ResetAllClientTraffics(id int) error { + if err := ensureSharedWriteAllowed(); err != nil { + return err + } + db := database.GetDB() now := time.Now().Unix() * 1000 @@ -1988,23 +2101,35 @@ func (s *InboundService) ResetAllClientTraffics(id int) error { result = tx.Model(model.Inbound{}). Where(inboundWhereText, id). Update("last_traffic_reset_time", now) - - return result.Error + if result.Error != nil { + return result.Error + } + return bumpSharedVersion(tx) }) } func (s *InboundService) ResetAllTraffics() error { + if err := ensureSharedWriteAllowed(); err != nil { + return err + } + db := database.GetDB() - - result := db.Model(model.Inbound{}). - Where("user_id > ?", 0). - Updates(map[string]any{"up": 0, "down": 0}) - - err := result.Error - return err + return db.Transaction(func(tx *gorm.DB) error { + result := tx.Model(model.Inbound{}). + Where("user_id > ?", 0). + Updates(map[string]any{"up": 0, "down": 0}) + if result.Error != nil { + return result.Error + } + return bumpSharedVersion(tx) + }) } func (s *InboundService) DelDepletedClients(id int) (err error) { + if err := ensureSharedWriteAllowed(); err != nil { + return err + } + db := database.GetDB() tx := db.Begin() defer func() { @@ -2085,6 +2210,12 @@ func (s *InboundService) DelDepletedClients(id int) (err error) { if err != nil { return err } + if len(depletedClients) > 0 { + err = bumpSharedVersion(tx) + if err != nil { + return err + } + } return nil } @@ -2551,6 +2682,10 @@ func (s *InboundService) FilterAndSortClientEmails(emails []string) ([]string, [ return validEmails, extraEmails, nil } func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (bool, error) { + if err := ensureSharedWriteAllowed(); err != nil { + return false, err + } + oldInbound, err := s.GetInbound(inboundId) if err != nil { logger.Error("Load Old Data Error") @@ -2601,9 +2736,17 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b oldInbound.Settings = string(newSettings) db := database.GetDB() + tx := db.Begin() + defer func() { + if err != nil { + tx.Rollback() + } else { + tx.Commit() + } + }() // remove IP bindings - if err := s.DelClientIPs(db, email); err != nil { + if err := s.DelClientIPs(tx, email); err != nil { logger.Error("Error in delete client IPs") return false, err } @@ -2617,7 +2760,7 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b return false, err } if traffic != nil { - if err := s.DelClientStat(db, inboundId, email); err != nil { + if err := s.DelClientStat(tx, inboundId, email); err != nil { logger.Error("Delete stats Data Error") return false, err } @@ -2640,5 +2783,13 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b } } - return needRestart, db.Save(oldInbound).Error + err = tx.Save(oldInbound).Error + if err != nil { + return false, err + } + err = bumpSharedVersion(tx) + if err != nil { + return false, err + } + return needRestart, nil } diff --git a/web/service/node_guard.go b/web/service/node_guard.go new file mode 100644 index 00000000..37db35a4 --- /dev/null +++ b/web/service/node_guard.go @@ -0,0 +1,28 @@ +package service + +import ( + "errors" + + "github.com/mhsanaei/3x-ui/v2/config" +) + +var ErrSharedWriteRequiresMaster = errors.New("shared-account writes are only allowed on master nodes") + +func IsWorker() bool { + return config.GetNodeConfigFromJSON().Role == config.NodeRoleWorker +} + +func IsMaster() bool { + return !IsWorker() +} + +func RequireMaster() error { + if IsWorker() { + return ErrSharedWriteRequiresMaster + } + return nil +} + +func IsSharedModeEnabled() bool { + return config.GetDBConfigFromJSON().Type == "mariadb" +} diff --git a/web/service/node_guard_test.go b/web/service/node_guard_test.go new file mode 100644 index 00000000..dee2381d --- /dev/null +++ b/web/service/node_guard_test.go @@ -0,0 +1,64 @@ +package service + +import ( + "encoding/json" + "os" + "testing" + + "github.com/mhsanaei/3x-ui/v2/config" + "github.com/mhsanaei/3x-ui/v2/database" +) + +func writeNodeGuardSettings(t *testing.T, settings map[string]any) { + t.Helper() + data, err := json.MarshalIndent(settings, "", " ") + if err != nil { + t.Fatalf("MarshalIndent error: %v", err) + } + if err := os.WriteFile(config.GetSettingPath(), data, 0644); err != nil { + t.Fatalf("WriteFile error: %v", err) + } +} + +func TestRequireMasterRejectsWorker(t *testing.T) { + setupTestDB(t) + writeNodeGuardSettings(t, map[string]any{ + "dbType": "mariadb", + "nodeRole": "worker", + "nodeId": "worker-1", + }) + + if err := RequireMaster(); err == nil { + t.Fatal("expected worker mode to be rejected") + } +} + +func TestRequireMasterAllowsMaster(t *testing.T) { + setupTestDB(t) + writeNodeGuardSettings(t, map[string]any{ + "dbType": "mariadb", + "nodeRole": "master", + }) + + if err := RequireMaster(); err != nil { + t.Fatalf("expected master mode to pass: %v", err) + } +} + +func TestBumpSharedAccountsVersionRollsBackWithTransaction(t *testing.T) { + setupTestDB(t) + + tx := database.GetDB().Begin() + if err := database.BumpSharedAccountsVersion(tx); err != nil { + t.Fatalf("BumpSharedAccountsVersion error: %v", err) + } + tx.Rollback() + + version, err := database.GetSharedAccountsVersion(database.GetDB()) + if err != nil { + t.Fatalf("GetSharedAccountsVersion error: %v", err) + } + if version != 0 { + t.Fatalf("expected rolled-back version to remain 0, got %d", version) + } +}