mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-08 22:24:15 +00:00
feat: guard shared writes and bump version transactionally
This commit is contained in:
parent
fd0af148cb
commit
34b9f01d0a
3 changed files with 270 additions and 27 deletions
|
|
@ -26,6 +26,14 @@ type InboundService struct {
|
||||||
xrayApi xray.XrayAPI
|
xrayApi xray.XrayAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureSharedWriteAllowed() error {
|
||||||
|
return RequireMaster()
|
||||||
|
}
|
||||||
|
|
||||||
|
func bumpSharedVersion(tx *gorm.DB) error {
|
||||||
|
return database.BumpSharedAccountsVersion(tx)
|
||||||
|
}
|
||||||
|
|
||||||
// GetInbounds retrieves all inbounds for a specific user.
|
// GetInbounds retrieves all inbounds for a specific user.
|
||||||
// Returns a slice of inbound models with their associated client statistics.
|
// Returns a slice of inbound models with their associated client statistics.
|
||||||
func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) {
|
func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) {
|
||||||
|
|
@ -228,6 +236,10 @@ func (s *InboundService) checkEmailExistInInbound(inbound *model.Inbound, email
|
||||||
// then saves the inbound to the database and optionally adds it to the running Xray instance.
|
// then saves the inbound to the database and optionally adds it to the running Xray instance.
|
||||||
// Returns the created inbound, whether Xray needs restart, and any error.
|
// Returns the created inbound, whether Xray needs restart, and any error.
|
||||||
func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, bool, error) {
|
func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, bool, error) {
|
||||||
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
return inbound, false, err
|
||||||
|
}
|
||||||
|
|
||||||
exist, err := s.checkPortExist(inbound.Listen, inbound.Port, 0)
|
exist, err := s.checkPortExist(inbound.Listen, inbound.Port, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return inbound, false, err
|
return inbound, false, err
|
||||||
|
|
@ -311,6 +323,10 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo
|
||||||
} else {
|
} else {
|
||||||
return inbound, false, err
|
return inbound, false, err
|
||||||
}
|
}
|
||||||
|
err = bumpSharedVersion(tx)
|
||||||
|
if err != nil {
|
||||||
|
return inbound, false, err
|
||||||
|
}
|
||||||
|
|
||||||
needRestart := false
|
needRestart := false
|
||||||
if inbound.Enable {
|
if inbound.Enable {
|
||||||
|
|
@ -337,6 +353,10 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo
|
||||||
// It removes the inbound from the database and the running Xray instance if active.
|
// It removes the inbound from the database and the running Xray instance if active.
|
||||||
// Returns whether Xray needs restart and any error.
|
// Returns whether Xray needs restart and any error.
|
||||||
func (s *InboundService) DelInbound(id int) (bool, error) {
|
func (s *InboundService) DelInbound(id int) (bool, error) {
|
||||||
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
db := database.GetDB()
|
db := database.GetDB()
|
||||||
|
|
||||||
var tag string
|
var tag string
|
||||||
|
|
@ -356,11 +376,6 @@ func (s *InboundService) DelInbound(id int) (bool, error) {
|
||||||
logger.Debug("No enabled inbound founded to removing by api", tag)
|
logger.Debug("No enabled inbound founded to removing by api", tag)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete client traffics of inbounds
|
|
||||||
err := db.Where("inbound_id = ?", id).Delete(xray.ClientTraffic{}).Error
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
inbound, err := s.GetInbound(id)
|
inbound, err := s.GetInbound(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|
@ -369,14 +384,35 @@ func (s *InboundService) DelInbound(id int) (bool, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tx := db.Begin()
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
} else {
|
||||||
|
tx.Commit()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = tx.Where("inbound_id = ?", id).Delete(xray.ClientTraffic{}).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
for _, client := range clients {
|
for _, client := range clients {
|
||||||
err := s.DelClientIPs(db, client.Email)
|
err := s.DelClientIPs(tx, client.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
err = tx.Delete(model.Inbound{}, id).Error
|
||||||
return needRestart, db.Delete(model.Inbound{}, id).Error
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
err = bumpSharedVersion(tx)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return needRestart, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) GetInbound(id int) (*model.Inbound, error) {
|
func (s *InboundService) GetInbound(id int) (*model.Inbound, error) {
|
||||||
|
|
@ -472,6 +508,10 @@ func (s *InboundService) DelInboundClientByEmailForUser(userID int, isAdmin bool
|
||||||
// It validates changes, updates the database, and syncs with the running Xray instance.
|
// It validates changes, updates the database, and syncs with the running Xray instance.
|
||||||
// Returns the updated inbound, whether Xray needs restart, and any error.
|
// Returns the updated inbound, whether Xray needs restart, and any error.
|
||||||
func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, bool, error) {
|
func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, bool, error) {
|
||||||
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
return inbound, false, err
|
||||||
|
}
|
||||||
|
|
||||||
exist, err := s.checkPortExist(inbound.Listen, inbound.Port, inbound.Id)
|
exist, err := s.checkPortExist(inbound.Listen, inbound.Port, inbound.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return inbound, false, err
|
return inbound, false, err
|
||||||
|
|
@ -604,7 +644,15 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound,
|
||||||
}
|
}
|
||||||
s.xrayApi.Close()
|
s.xrayApi.Close()
|
||||||
|
|
||||||
return inbound, needRestart, tx.Save(oldInbound).Error
|
err = tx.Save(oldInbound).Error
|
||||||
|
if err != nil {
|
||||||
|
return inbound, false, err
|
||||||
|
}
|
||||||
|
err = bumpSharedVersion(tx)
|
||||||
|
if err != nil {
|
||||||
|
return inbound, false, err
|
||||||
|
}
|
||||||
|
return inbound, needRestart, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) updateClientTraffics(tx *gorm.DB, oldInbound *model.Inbound, newInbound *model.Inbound) error {
|
func (s *InboundService) updateClientTraffics(tx *gorm.DB, oldInbound *model.Inbound, newInbound *model.Inbound) error {
|
||||||
|
|
@ -653,6 +701,10 @@ func (s *InboundService) updateClientTraffics(tx *gorm.DB, oldInbound *model.Inb
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) AddInboundClient(data *model.Inbound) (bool, error) {
|
func (s *InboundService) AddInboundClient(data *model.Inbound) (bool, error) {
|
||||||
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
clients, err := s.GetClients(data)
|
clients, err := s.GetClients(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|
@ -773,10 +825,22 @@ func (s *InboundService) AddInboundClient(data *model.Inbound) (bool, error) {
|
||||||
}
|
}
|
||||||
s.xrayApi.Close()
|
s.xrayApi.Close()
|
||||||
|
|
||||||
return needRestart, tx.Save(oldInbound).Error
|
err = tx.Save(oldInbound).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
err = bumpSharedVersion(tx)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return needRestart, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool, error) {
|
func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool, error) {
|
||||||
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
oldInbound, err := s.GetInbound(inboundId)
|
oldInbound, err := s.GetInbound(inboundId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Load Old Data Error")
|
logger.Error("Load Old Data Error")
|
||||||
|
|
@ -824,8 +888,16 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool,
|
||||||
oldInbound.Settings = string(newSettings)
|
oldInbound.Settings = string(newSettings)
|
||||||
|
|
||||||
db := database.GetDB()
|
db := database.GetDB()
|
||||||
|
tx := db.Begin()
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
} else {
|
||||||
|
tx.Commit()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
err = s.DelClientIPs(db, email)
|
err = s.DelClientIPs(tx, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error in delete client IPs")
|
logger.Error("Error in delete client IPs")
|
||||||
return false, err
|
return false, err
|
||||||
|
|
@ -841,7 +913,7 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool,
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
notDepleted = err == nil && traffic.Enable
|
notDepleted = err == nil && traffic.Enable
|
||||||
err = s.DelClientStat(db, inboundId, email)
|
err = s.DelClientStat(tx, inboundId, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Delete stats Data Error")
|
logger.Error("Delete stats Data Error")
|
||||||
return false, err
|
return false, err
|
||||||
|
|
@ -863,10 +935,22 @@ func (s *InboundService) DelInboundClient(inboundId int, clientId string) (bool,
|
||||||
s.xrayApi.Close()
|
s.xrayApi.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return needRestart, db.Save(oldInbound).Error
|
err = tx.Save(oldInbound).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
err = bumpSharedVersion(tx)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return needRestart, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId string) (bool, error) {
|
func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId string) (bool, error) {
|
||||||
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: check if TrafficReset field is updating
|
// TODO: check if TrafficReset field is updating
|
||||||
clients, err := s.GetClients(data)
|
clients, err := s.GetClients(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -1049,7 +1133,15 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId strin
|
||||||
logger.Debug("Client old email not found")
|
logger.Debug("Client old email not found")
|
||||||
needRestart = true
|
needRestart = true
|
||||||
}
|
}
|
||||||
return needRestart, tx.Save(oldInbound).Error
|
err = tx.Save(oldInbound).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
err = bumpSharedVersion(tx)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return needRestart, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (error, bool) {
|
func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (error, bool) {
|
||||||
|
|
@ -1895,6 +1987,10 @@ func (s *InboundService) ResetClientTrafficByEmail(clientEmail string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, error) {
|
func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, error) {
|
||||||
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
needRestart := false
|
needRestart := false
|
||||||
|
|
||||||
traffic, err := s.GetClientTrafficByEmail(clientEmail)
|
traffic, err := s.GetClientTrafficByEmail(clientEmail)
|
||||||
|
|
@ -1948,7 +2044,20 @@ func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, e
|
||||||
traffic.Enable = true
|
traffic.Enable = true
|
||||||
|
|
||||||
db := database.GetDB()
|
db := database.GetDB()
|
||||||
err = db.Save(traffic).Error
|
tx := db.Begin()
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
} else {
|
||||||
|
tx.Commit()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = tx.Save(traffic).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
err = bumpSharedVersion(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
@ -1957,6 +2066,10 @@ func (s *InboundService) ResetClientTraffic(id int, clientEmail string) (bool, e
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) ResetAllClientTraffics(id int) error {
|
func (s *InboundService) ResetAllClientTraffics(id int) error {
|
||||||
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
db := database.GetDB()
|
db := database.GetDB()
|
||||||
now := time.Now().Unix() * 1000
|
now := time.Now().Unix() * 1000
|
||||||
|
|
||||||
|
|
@ -1988,23 +2101,35 @@ func (s *InboundService) ResetAllClientTraffics(id int) error {
|
||||||
result = tx.Model(model.Inbound{}).
|
result = tx.Model(model.Inbound{}).
|
||||||
Where(inboundWhereText, id).
|
Where(inboundWhereText, id).
|
||||||
Update("last_traffic_reset_time", now)
|
Update("last_traffic_reset_time", now)
|
||||||
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
|
}
|
||||||
|
return bumpSharedVersion(tx)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) ResetAllTraffics() error {
|
func (s *InboundService) ResetAllTraffics() error {
|
||||||
db := database.GetDB()
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
|
||||||
result := db.Model(model.Inbound{}).
|
|
||||||
Where("user_id > ?", 0).
|
|
||||||
Updates(map[string]any{"up": 0, "down": 0})
|
|
||||||
|
|
||||||
err := result.Error
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db := database.GetDB()
|
||||||
|
return db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
result := tx.Model(model.Inbound{}).
|
||||||
|
Where("user_id > ?", 0).
|
||||||
|
Updates(map[string]any{"up": 0, "down": 0})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
return bumpSharedVersion(tx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *InboundService) DelDepletedClients(id int) (err error) {
|
func (s *InboundService) DelDepletedClients(id int) (err error) {
|
||||||
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
db := database.GetDB()
|
db := database.GetDB()
|
||||||
tx := db.Begin()
|
tx := db.Begin()
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|
@ -2085,6 +2210,12 @@ func (s *InboundService) DelDepletedClients(id int) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if len(depletedClients) > 0 {
|
||||||
|
err = bumpSharedVersion(tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -2551,6 +2682,10 @@ func (s *InboundService) FilterAndSortClientEmails(emails []string) ([]string, [
|
||||||
return validEmails, extraEmails, nil
|
return validEmails, extraEmails, nil
|
||||||
}
|
}
|
||||||
func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (bool, error) {
|
func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (bool, error) {
|
||||||
|
if err := ensureSharedWriteAllowed(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
oldInbound, err := s.GetInbound(inboundId)
|
oldInbound, err := s.GetInbound(inboundId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Load Old Data Error")
|
logger.Error("Load Old Data Error")
|
||||||
|
|
@ -2601,9 +2736,17 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b
|
||||||
oldInbound.Settings = string(newSettings)
|
oldInbound.Settings = string(newSettings)
|
||||||
|
|
||||||
db := database.GetDB()
|
db := database.GetDB()
|
||||||
|
tx := db.Begin()
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
} else {
|
||||||
|
tx.Commit()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// remove IP bindings
|
// remove IP bindings
|
||||||
if err := s.DelClientIPs(db, email); err != nil {
|
if err := s.DelClientIPs(tx, email); err != nil {
|
||||||
logger.Error("Error in delete client IPs")
|
logger.Error("Error in delete client IPs")
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
@ -2617,7 +2760,7 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
if traffic != nil {
|
if traffic != nil {
|
||||||
if err := s.DelClientStat(db, inboundId, email); err != nil {
|
if err := s.DelClientStat(tx, inboundId, email); err != nil {
|
||||||
logger.Error("Delete stats Data Error")
|
logger.Error("Delete stats Data Error")
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
@ -2640,5 +2783,13 @@ func (s *InboundService) DelInboundClientByEmail(inboundId int, email string) (b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return needRestart, db.Save(oldInbound).Error
|
err = tx.Save(oldInbound).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
err = bumpSharedVersion(tx)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return needRestart, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
28
web/service/node_guard.go
Normal file
28
web/service/node_guard.go
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/mhsanaei/3x-ui/v2/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrSharedWriteRequiresMaster = errors.New("shared-account writes are only allowed on master nodes")
|
||||||
|
|
||||||
|
func IsWorker() bool {
|
||||||
|
return config.GetNodeConfigFromJSON().Role == config.NodeRoleWorker
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsMaster() bool {
|
||||||
|
return !IsWorker()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RequireMaster() error {
|
||||||
|
if IsWorker() {
|
||||||
|
return ErrSharedWriteRequiresMaster
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsSharedModeEnabled() bool {
|
||||||
|
return config.GetDBConfigFromJSON().Type == "mariadb"
|
||||||
|
}
|
||||||
64
web/service/node_guard_test.go
Normal file
64
web/service/node_guard_test.go
Normal file
|
|
@ -0,0 +1,64 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mhsanaei/3x-ui/v2/config"
|
||||||
|
"github.com/mhsanaei/3x-ui/v2/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func writeNodeGuardSettings(t *testing.T, settings map[string]any) {
|
||||||
|
t.Helper()
|
||||||
|
data, err := json.MarshalIndent(settings, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarshalIndent error: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(config.GetSettingPath(), data, 0644); err != nil {
|
||||||
|
t.Fatalf("WriteFile error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireMasterRejectsWorker(t *testing.T) {
|
||||||
|
setupTestDB(t)
|
||||||
|
writeNodeGuardSettings(t, map[string]any{
|
||||||
|
"dbType": "mariadb",
|
||||||
|
"nodeRole": "worker",
|
||||||
|
"nodeId": "worker-1",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := RequireMaster(); err == nil {
|
||||||
|
t.Fatal("expected worker mode to be rejected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireMasterAllowsMaster(t *testing.T) {
|
||||||
|
setupTestDB(t)
|
||||||
|
writeNodeGuardSettings(t, map[string]any{
|
||||||
|
"dbType": "mariadb",
|
||||||
|
"nodeRole": "master",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := RequireMaster(); err != nil {
|
||||||
|
t.Fatalf("expected master mode to pass: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBumpSharedAccountsVersionRollsBackWithTransaction(t *testing.T) {
|
||||||
|
setupTestDB(t)
|
||||||
|
|
||||||
|
tx := database.GetDB().Begin()
|
||||||
|
if err := database.BumpSharedAccountsVersion(tx); err != nil {
|
||||||
|
t.Fatalf("BumpSharedAccountsVersion error: %v", err)
|
||||||
|
}
|
||||||
|
tx.Rollback()
|
||||||
|
|
||||||
|
version, err := database.GetSharedAccountsVersion(database.GetDB())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSharedAccountsVersion error: %v", err)
|
||||||
|
}
|
||||||
|
if version != 0 {
|
||||||
|
t.Fatalf("expected rolled-back version to remain 0, got %d", version)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue