feat: guard shared writes and bump version transactionally

This commit is contained in:
Sora39831 2026-04-10 11:22:49 +08:00
parent fd0af148cb
commit 34b9f01d0a
3 changed files with 270 additions and 27 deletions

View file

@ -26,6 +26,14 @@ type InboundService struct {
xrayApi xray.XrayAPI
}
func ensureSharedWriteAllowed() error {
return RequireMaster()
}
func bumpSharedVersion(tx *gorm.DB) error {
return database.BumpSharedAccountsVersion(tx)
}
// GetInbounds retrieves all inbounds for a specific user.
// Returns a slice of inbound models with their associated client statistics.
func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) {
@ -228,6 +236,10 @@ func (s *InboundService) checkEmailExistInInbound(inbound *model.Inbound, email
// then saves the inbound to the database and optionally adds it to the running Xray instance.
// Returns the created inbound, whether Xray needs restart, and any error.
func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, bool, error) {
if err := ensureSharedWriteAllowed(); err != nil {
return inbound, false, err
}
exist, err := s.checkPortExist(inbound.Listen, inbound.Port, 0)
if err != nil {
return inbound, false, err
@ -311,6 +323,10 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo
} else {
return inbound, false, err
}
err = bumpSharedVersion(tx)
if err != nil {
return inbound, false, err
}
needRestart := false
if inbound.Enable {
@ -337,6 +353,10 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo
// It removes the inbound from the database and the running Xray instance if active.
// Returns whether Xray needs restart and any error.
func (s *InboundService) DelInbound(id int) (bool, error) {
if err := ensureSharedWriteAllowed(); err != nil {
return false, err
}
db := database.GetDB()
var tag string
@ -356,11 +376,6 @@ func (s *InboundService) DelInbound(id int) (bool, error) {
logger.Debug("No enabled inbound founded to removing by api", tag)
}
// Delete client traffics of inbounds
err := db.Where("inbound_id = ?", id).Delete(xray.ClientTraffic{}).Error
if err != nil {
return false, err
}
inbound, err := s.GetInbound(id)
if err != nil {
return false, err
@ -369,14 +384,35 @@ func (s *InboundService) DelInbound(id int) (bool, error) {
if err != nil {
return false, err
}
tx := db.Begin()
defer func() {
if err != nil {
tx.Rollback()
} else {
tx.Commit()
}
}()
err = tx.Where("inbound_id = ?", id).Delete(xray.ClientTraffic{}).Error
if err != nil {
return false, err
}
for _, client := range clients {
err := s.DelClientIPs(db, client.Email)
err := s.DelClientIPs(tx, client.Email)
if err != nil {
return false, err
}
}
return needRestart, db.Delete(model.Inbound{}, id).Error
err = tx.Delete(model.Inbound{}, id).Error
if err != nil {
return false, err
}
err = bumpSharedVersion(tx)
if err != nil {
return false, err
}
return needRestart, nil
}
func (s *InboundService) GetInbound(id int) (*model.Inbound, error) {
@ -472,6 +508,10 @@ func (s *InboundService) DelInboundClientByEmailForUser(userID int, isAdmin bool
// It validates changes, updates the database, and syncs with the running Xray instance.
// Returns the updated inbound, whether Xray needs restart, and any error.
func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, bool, error) {
if err := ensureSharedWriteAllowed(); err != nil {
return inbound, false, err
}
exist, err := s.checkPortExist(inbound.Listen, inbound.Port, inbound.Id)
if err != nil {
return inbound, false, err
@ -604,7 +644,15 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound,
}
s.xrayApi.Close()
return inbound, needRestart, tx.Save(oldInbound).Error
err = tx.Save(oldInbound).Error
if err != nil {
return inbound, false, err
}
err = bumpSharedVersion(tx)
if err != nil {
return inbound, false, err
}
return inbound, needRestart, nil
}
func (s *InboundService) updateClientTraffics(tx *gorm.DB, oldInbound *model.Inbound, newInbound *model.Inbound) error {
@ -653,6 +701,10 @@ func (s *InboundService) updateClientTraffics(tx *gorm.DB, oldInbound *model.Inb
}
func (s *InboundService) AddInboundClient(data *model.Inbound) (bool, error) {
if err := ensureSharedWriteAllowed(); err != nil {
return false, err
}
clients, err := s.GetClients(data)
if err != nil {
return false, err
@ -773,10 +825,22 @@ func (s *InboundService) AddInboundClient(data *model.Inbound) (bool, error) {
}
s.xrayApi.Close()
return needRestart, tx.Save(oldInbound).Error
err = tx.Save(oldInbound).Error
if err != nil {
return false, err
}
err = bumpSharedVersion(tx)
if err != nil {
return false, err
}
return needRestart, nil
}
func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool, error) {
if err := ensureSharedWriteAllowed(); err != nil {
return false, err
}
oldInbound, err := s.GetInbound(inboundId)
if err != nil {
logger.Error("Load Old Data Error")
@ -824,8 +888,16 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool,
oldInbound.Settings = string(newSettings)
db := database.GetDB()
tx := db.Begin()
defer func() {
if err != nil {
tx.Rollback()
} else {
tx.Commit()
}
}()
err = s.DelClientIPs(db, email)
err = s.DelClientIPs(tx, email)
if err != nil {
logger.Error("Error in delete client IPs")
return false, err
@ -841,7 +913,7 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool,
return false, err
}
notDepleted = err == nil && traffic.Enable
err = s.DelClientStat(db, inboundId, email)
err = s.DelClientStat(tx, inboundId, email)
if err != nil {
logger.Error("Delete stats Data Error")
return false, err
@ -863,10 +935,22 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool,
s.xrayApi.Close()
}
}
return needRestart, db.Save(oldInbound).Error
err = tx.Save(oldInbound).Error
if err != nil {
return false, err
}
err = bumpSharedVersion(tx)
if err != nil {
return false, err
}
return needRestart, nil
}
func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId string) (bool, error) {
if err := ensureSharedWriteAllowed(); err != nil {
return false, err
}
// TODO: check if TrafficReset field is updating
clients, err := s.GetClients(data)
if err != nil {
@ -1049,7 +1133,15 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId strin
logger.Debug("Client old email not found")
needRestart = true
}
return needRestart, tx.Save(oldInbound).Error
err = tx.Save(oldInbound).Error
if err != nil {
return false, err
}
err = bumpSharedVersion(tx)
if err != nil {
return false, err
}
return needRestart, nil
}
func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (error, bool) {
@ -1895,6 +1987,10 @@ func (s *InboundService) ResetClientTrafficByEmail(clientEmail string) error {
}
func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, error) {
if err := ensureSharedWriteAllowed(); err != nil {
return false, err
}
needRestart := false
traffic, err := s.GetClientTrafficByEmail(clientEmail)
@ -1948,7 +2044,20 @@ func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, e
traffic.Enable = true
db := database.GetDB()
err = db.Save(traffic).Error
tx := db.Begin()
defer func() {
if err != nil {
tx.Rollback()
} else {
tx.Commit()
}
}()
err = tx.Save(traffic).Error
if err != nil {
return false, err
}
err = bumpSharedVersion(tx)
if err != nil {
return false, err
}
@ -1957,6 +2066,10 @@ func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, e
}
func (s *InboundService) ResetAllClientTraffics(id int) error {
if err := ensureSharedWriteAllowed(); err != nil {
return err
}
db := database.GetDB()
now := time.Now().Unix() * 1000
@ -1988,23 +2101,35 @@ func (s *InboundService) ResetAllClientTraffics(id int) error {
result = tx.Model(model.Inbound{}).
Where(inboundWhereText, id).
Update("last_traffic_reset_time", now)
return result.Error
if result.Error != nil {
return result.Error
}
return bumpSharedVersion(tx)
})
}
func (s *InboundService) ResetAllTraffics() error {
if err := ensureSharedWriteAllowed(); err != nil {
return err
}
db := database.GetDB()
result := db.Model(model.Inbound{}).
Where("user_id > ?", 0).
Updates(map[string]any{"up": 0, "down": 0})
err := result.Error
return err
return db.Transaction(func(tx *gorm.DB) error {
result := tx.Model(model.Inbound{}).
Where("user_id > ?", 0).
Updates(map[string]any{"up": 0, "down": 0})
if result.Error != nil {
return result.Error
}
return bumpSharedVersion(tx)
})
}
func (s *InboundService) DelDepletedClients(id int) (err error) {
if err := ensureSharedWriteAllowed(); err != nil {
return err
}
db := database.GetDB()
tx := db.Begin()
defer func() {
@ -2085,6 +2210,12 @@ func (s *InboundService) DelDepletedClients(id int) (err error) {
if err != nil {
return err
}
if len(depletedClients) > 0 {
err = bumpSharedVersion(tx)
if err != nil {
return err
}
}
return nil
}
@ -2551,6 +2682,10 @@ func (s *InboundService) FilterAndSortClientEmails(emails []string) ([]string, [
return validEmails, extraEmails, nil
}
func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (bool, error) {
if err := ensureSharedWriteAllowed(); err != nil {
return false, err
}
oldInbound, err := s.GetInbound(inboundId)
if err != nil {
logger.Error("Load Old Data Error")
@ -2601,9 +2736,17 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b
oldInbound.Settings = string(newSettings)
db := database.GetDB()
tx := db.Begin()
defer func() {
if err != nil {
tx.Rollback()
} else {
tx.Commit()
}
}()
// remove IP bindings
if err := s.DelClientIPs(db, email); err != nil {
if err := s.DelClientIPs(tx, email); err != nil {
logger.Error("Error in delete client IPs")
return false, err
}
@ -2617,7 +2760,7 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b
return false, err
}
if traffic != nil {
if err := s.DelClientStat(db, inboundId, email); err != nil {
if err := s.DelClientStat(tx, inboundId, email); err != nil {
logger.Error("Delete stats Data Error")
return false, err
}
@ -2640,5 +2783,13 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b
}
}
return needRestart, db.Save(oldInbound).Error
err = tx.Save(oldInbound).Error
if err != nil {
return false, err
}
err = bumpSharedVersion(tx)
if err != nil {
return false, err
}
return needRestart, nil
}

28
web/service/node_guard.go Normal file
View file

@ -0,0 +1,28 @@
package service
import (
"errors"
"github.com/mhsanaei/3x-ui/v2/config"
)
var ErrSharedWriteRequiresMaster = errors.New("shared-account writes are only allowed on master nodes")
func IsWorker() bool {
return config.GetNodeConfigFromJSON().Role == config.NodeRoleWorker
}
func IsMaster() bool {
return !IsWorker()
}
func RequireMaster() error {
if IsWorker() {
return ErrSharedWriteRequiresMaster
}
return nil
}
func IsSharedModeEnabled() bool {
return config.GetDBConfigFromJSON().Type == "mariadb"
}

View file

@ -0,0 +1,64 @@
package service
import (
"encoding/json"
"os"
"testing"
"github.com/mhsanaei/3x-ui/v2/config"
"github.com/mhsanaei/3x-ui/v2/database"
)
func writeNodeGuardSettings(t *testing.T, settings map[string]any) {
t.Helper()
data, err := json.MarshalIndent(settings, "", " ")
if err != nil {
t.Fatalf("MarshalIndent error: %v", err)
}
if err := os.WriteFile(config.GetSettingPath(), data, 0644); err != nil {
t.Fatalf("WriteFile error: %v", err)
}
}
func TestRequireMasterRejectsWorker(t *testing.T) {
setupTestDB(t)
writeNodeGuardSettings(t, map[string]any{
"dbType": "mariadb",
"nodeRole": "worker",
"nodeId": "worker-1",
})
if err := RequireMaster(); err == nil {
t.Fatal("expected worker mode to be rejected")
}
}
func TestRequireMasterAllowsMaster(t *testing.T) {
setupTestDB(t)
writeNodeGuardSettings(t, map[string]any{
"dbType": "mariadb",
"nodeRole": "master",
})
if err := RequireMaster(); err != nil {
t.Fatalf("expected master mode to pass: %v", err)
}
}
func TestBumpSharedAccountsVersionRollsBackWithTransaction(t *testing.T) {
setupTestDB(t)
tx := database.GetDB().Begin()
if err := database.BumpSharedAccountsVersion(tx); err != nil {
t.Fatalf("BumpSharedAccountsVersion error: %v", err)
}
tx.Rollback()
version, err := database.GetSharedAccountsVersion(database.GetDB())
if err != nil {
t.Fatalf("GetSharedAccountsVersion error: %v", err)
}
if version != 0 {
t.Fatalf("expected rolled-back version to remain 0, got %d", version)
}
}