mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-06 21:24:10 +00:00
feat: guard shared writes and bump version transactionally
This commit is contained in:
parent
fd0af148cb
commit
34b9f01d0a
3 changed files with 270 additions and 27 deletions
|
|
@ -26,6 +26,14 @@ type InboundService struct {
|
|||
xrayApi xray.XrayAPI
|
||||
}
|
||||
|
||||
func ensureSharedWriteAllowed() error {
|
||||
return RequireMaster()
|
||||
}
|
||||
|
||||
func bumpSharedVersion(tx *gorm.DB) error {
|
||||
return database.BumpSharedAccountsVersion(tx)
|
||||
}
|
||||
|
||||
// GetInbounds retrieves all inbounds for a specific user.
|
||||
// Returns a slice of inbound models with their associated client statistics.
|
||||
func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) {
|
||||
|
|
@ -228,6 +236,10 @@ func (s *InboundService) checkEmailExistInInbound(inbound *model.Inbound, email
|
|||
// then saves the inbound to the database and optionally adds it to the running Xray instance.
|
||||
// Returns the created inbound, whether Xray needs restart, and any error.
|
||||
func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, bool, error) {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return inbound, false, err
|
||||
}
|
||||
|
||||
exist, err := s.checkPortExist(inbound.Listen, inbound.Port, 0)
|
||||
if err != nil {
|
||||
return inbound, false, err
|
||||
|
|
@ -311,6 +323,10 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo
|
|||
} else {
|
||||
return inbound, false, err
|
||||
}
|
||||
err = bumpSharedVersion(tx)
|
||||
if err != nil {
|
||||
return inbound, false, err
|
||||
}
|
||||
|
||||
needRestart := false
|
||||
if inbound.Enable {
|
||||
|
|
@ -337,6 +353,10 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo
|
|||
// It removes the inbound from the database and the running Xray instance if active.
|
||||
// Returns whether Xray needs restart and any error.
|
||||
func (s *InboundService) DelInbound(id int) (bool, error) {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
db := database.GetDB()
|
||||
|
||||
var tag string
|
||||
|
|
@ -356,11 +376,6 @@ func (s *InboundService) DelInbound(id int) (bool, error) {
|
|||
logger.Debug("No enabled inbound founded to removing by api", tag)
|
||||
}
|
||||
|
||||
// Delete client traffics of inbounds
|
||||
err := db.Where("inbound_id = ?", id).Delete(xray.ClientTraffic{}).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
inbound, err := s.GetInbound(id)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
|
@ -369,14 +384,35 @@ func (s *InboundService) DelInbound(id int) (bool, error) {
|
|||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
} else {
|
||||
tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
err = tx.Where("inbound_id = ?", id).Delete(xray.ClientTraffic{}).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, client := range clients {
|
||||
err := s.DelClientIPs(db, client.Email)
|
||||
err := s.DelClientIPs(tx, client.Email)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
return needRestart, db.Delete(model.Inbound{}, id).Error
|
||||
err = tx.Delete(model.Inbound{}, id).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
err = bumpSharedVersion(tx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return needRestart, nil
|
||||
}
|
||||
|
||||
func (s *InboundService) GetInbound(id int) (*model.Inbound, error) {
|
||||
|
|
@ -472,6 +508,10 @@ func (s *InboundService) DelInboundClientByEmailForUser(userID int, isAdmin bool
|
|||
// It validates changes, updates the database, and syncs with the running Xray instance.
|
||||
// Returns the updated inbound, whether Xray needs restart, and any error.
|
||||
func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, bool, error) {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return inbound, false, err
|
||||
}
|
||||
|
||||
exist, err := s.checkPortExist(inbound.Listen, inbound.Port, inbound.Id)
|
||||
if err != nil {
|
||||
return inbound, false, err
|
||||
|
|
@ -604,7 +644,15 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound,
|
|||
}
|
||||
s.xrayApi.Close()
|
||||
|
||||
return inbound, needRestart, tx.Save(oldInbound).Error
|
||||
err = tx.Save(oldInbound).Error
|
||||
if err != nil {
|
||||
return inbound, false, err
|
||||
}
|
||||
err = bumpSharedVersion(tx)
|
||||
if err != nil {
|
||||
return inbound, false, err
|
||||
}
|
||||
return inbound, needRestart, nil
|
||||
}
|
||||
|
||||
func (s *InboundService) updateClientTraffics(tx *gorm.DB, oldInbound *model.Inbound, newInbound *model.Inbound) error {
|
||||
|
|
@ -653,6 +701,10 @@ func (s *InboundService) updateClientTraffics(tx *gorm.DB, oldInbound *model.Inb
|
|||
}
|
||||
|
||||
func (s *InboundService) AddInboundClient(data *model.Inbound) (bool, error) {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
clients, err := s.GetClients(data)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
|
@ -773,10 +825,22 @@ func (s *InboundService) AddInboundClient(data *model.Inbound) (bool, error) {
|
|||
}
|
||||
s.xrayApi.Close()
|
||||
|
||||
return needRestart, tx.Save(oldInbound).Error
|
||||
err = tx.Save(oldInbound).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
err = bumpSharedVersion(tx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return needRestart, nil
|
||||
}
|
||||
|
||||
func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool, error) {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldInbound, err := s.GetInbound(inboundId)
|
||||
if err != nil {
|
||||
logger.Error("Load Old Data Error")
|
||||
|
|
@ -824,8 +888,16 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool,
|
|||
oldInbound.Settings = string(newSettings)
|
||||
|
||||
db := database.GetDB()
|
||||
tx := db.Begin()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
} else {
|
||||
tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
err = s.DelClientIPs(db, email)
|
||||
err = s.DelClientIPs(tx, email)
|
||||
if err != nil {
|
||||
logger.Error("Error in delete client IPs")
|
||||
return false, err
|
||||
|
|
@ -841,7 +913,7 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool,
|
|||
return false, err
|
||||
}
|
||||
notDepleted = err == nil && traffic.Enable
|
||||
err = s.DelClientStat(db, inboundId, email)
|
||||
err = s.DelClientStat(tx, inboundId, email)
|
||||
if err != nil {
|
||||
logger.Error("Delete stats Data Error")
|
||||
return false, err
|
||||
|
|
@ -863,10 +935,22 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool,
|
|||
s.xrayApi.Close()
|
||||
}
|
||||
}
|
||||
return needRestart, db.Save(oldInbound).Error
|
||||
err = tx.Save(oldInbound).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
err = bumpSharedVersion(tx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return needRestart, nil
|
||||
}
|
||||
|
||||
func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId string) (bool, error) {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// TODO: check if TrafficReset field is updating
|
||||
clients, err := s.GetClients(data)
|
||||
if err != nil {
|
||||
|
|
@ -1049,7 +1133,15 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId strin
|
|||
logger.Debug("Client old email not found")
|
||||
needRestart = true
|
||||
}
|
||||
return needRestart, tx.Save(oldInbound).Error
|
||||
err = tx.Save(oldInbound).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
err = bumpSharedVersion(tx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return needRestart, nil
|
||||
}
|
||||
|
||||
func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (error, bool) {
|
||||
|
|
@ -1895,6 +1987,10 @@ func (s *InboundService) ResetClientTrafficByEmail(clientEmail string) error {
|
|||
}
|
||||
|
||||
func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, error) {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
needRestart := false
|
||||
|
||||
traffic, err := s.GetClientTrafficByEmail(clientEmail)
|
||||
|
|
@ -1948,7 +2044,20 @@ func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, e
|
|||
traffic.Enable = true
|
||||
|
||||
db := database.GetDB()
|
||||
err = db.Save(traffic).Error
|
||||
tx := db.Begin()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
} else {
|
||||
tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
err = tx.Save(traffic).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
err = bumpSharedVersion(tx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -1957,6 +2066,10 @@ func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, e
|
|||
}
|
||||
|
||||
func (s *InboundService) ResetAllClientTraffics(id int) error {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db := database.GetDB()
|
||||
now := time.Now().Unix() * 1000
|
||||
|
||||
|
|
@ -1988,23 +2101,35 @@ func (s *InboundService) ResetAllClientTraffics(id int) error {
|
|||
result = tx.Model(model.Inbound{}).
|
||||
Where(inboundWhereText, id).
|
||||
Update("last_traffic_reset_time", now)
|
||||
|
||||
return result.Error
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
return bumpSharedVersion(tx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InboundService) ResetAllTraffics() error {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db := database.GetDB()
|
||||
|
||||
result := db.Model(model.Inbound{}).
|
||||
Where("user_id > ?", 0).
|
||||
Updates(map[string]any{"up": 0, "down": 0})
|
||||
|
||||
err := result.Error
|
||||
return err
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Model(model.Inbound{}).
|
||||
Where("user_id > ?", 0).
|
||||
Updates(map[string]any{"up": 0, "down": 0})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
return bumpSharedVersion(tx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InboundService) DelDepletedClients(id int) (err error) {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db := database.GetDB()
|
||||
tx := db.Begin()
|
||||
defer func() {
|
||||
|
|
@ -2085,6 +2210,12 @@ func (s *InboundService) DelDepletedClients(id int) (err error) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(depletedClients) > 0 {
|
||||
err = bumpSharedVersion(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -2551,6 +2682,10 @@ func (s *InboundService) FilterAndSortClientEmails(emails []string) ([]string, [
|
|||
return validEmails, extraEmails, nil
|
||||
}
|
||||
func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (bool, error) {
|
||||
if err := ensureSharedWriteAllowed(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldInbound, err := s.GetInbound(inboundId)
|
||||
if err != nil {
|
||||
logger.Error("Load Old Data Error")
|
||||
|
|
@ -2601,9 +2736,17 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b
|
|||
oldInbound.Settings = string(newSettings)
|
||||
|
||||
db := database.GetDB()
|
||||
tx := db.Begin()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
} else {
|
||||
tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
// remove IP bindings
|
||||
if err := s.DelClientIPs(db, email); err != nil {
|
||||
if err := s.DelClientIPs(tx, email); err != nil {
|
||||
logger.Error("Error in delete client IPs")
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -2617,7 +2760,7 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b
|
|||
return false, err
|
||||
}
|
||||
if traffic != nil {
|
||||
if err := s.DelClientStat(db, inboundId, email); err != nil {
|
||||
if err := s.DelClientStat(tx, inboundId, email); err != nil {
|
||||
logger.Error("Delete stats Data Error")
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -2640,5 +2783,13 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b
|
|||
}
|
||||
}
|
||||
|
||||
return needRestart, db.Save(oldInbound).Error
|
||||
err = tx.Save(oldInbound).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
err = bumpSharedVersion(tx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return needRestart, nil
|
||||
}
|
||||
|
|
|
|||
28
web/service/node_guard.go
Normal file
28
web/service/node_guard.go
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v2/config"
|
||||
)
|
||||
|
||||
var ErrSharedWriteRequiresMaster = errors.New("shared-account writes are only allowed on master nodes")
|
||||
|
||||
func IsWorker() bool {
|
||||
return config.GetNodeConfigFromJSON().Role == config.NodeRoleWorker
|
||||
}
|
||||
|
||||
func IsMaster() bool {
|
||||
return !IsWorker()
|
||||
}
|
||||
|
||||
func RequireMaster() error {
|
||||
if IsWorker() {
|
||||
return ErrSharedWriteRequiresMaster
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsSharedModeEnabled() bool {
|
||||
return config.GetDBConfigFromJSON().Type == "mariadb"
|
||||
}
|
||||
64
web/service/node_guard_test.go
Normal file
64
web/service/node_guard_test.go
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v2/config"
|
||||
"github.com/mhsanaei/3x-ui/v2/database"
|
||||
)
|
||||
|
||||
func writeNodeGuardSettings(t *testing.T, settings map[string]any) {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalIndent error: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(config.GetSettingPath(), data, 0644); err != nil {
|
||||
t.Fatalf("WriteFile error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireMasterRejectsWorker(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
writeNodeGuardSettings(t, map[string]any{
|
||||
"dbType": "mariadb",
|
||||
"nodeRole": "worker",
|
||||
"nodeId": "worker-1",
|
||||
})
|
||||
|
||||
if err := RequireMaster(); err == nil {
|
||||
t.Fatal("expected worker mode to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireMasterAllowsMaster(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
writeNodeGuardSettings(t, map[string]any{
|
||||
"dbType": "mariadb",
|
||||
"nodeRole": "master",
|
||||
})
|
||||
|
||||
if err := RequireMaster(); err != nil {
|
||||
t.Fatalf("expected master mode to pass: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBumpSharedAccountsVersionRollsBackWithTransaction(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
tx := database.GetDB().Begin()
|
||||
if err := database.BumpSharedAccountsVersion(tx); err != nil {
|
||||
t.Fatalf("BumpSharedAccountsVersion error: %v", err)
|
||||
}
|
||||
tx.Rollback()
|
||||
|
||||
version, err := database.GetSharedAccountsVersion(database.GetDB())
|
||||
if err != nil {
|
||||
t.Fatalf("GetSharedAccountsVersion error: %v", err)
|
||||
}
|
||||
if version != 0 {
|
||||
t.Fatalf("expected rolled-back version to remain 0, got %d", version)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue