mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-07 05:34:17 +00:00
fix: harden migration and setting writes
This commit is contained in:
parent
e298996d77
commit
cc6d3daa3a
10 changed files with 301 additions and 160 deletions
|
|
@ -91,18 +91,24 @@ func runSeeders(isUsersEmpty bool) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if empty && isUsersEmpty {
|
||||
hashSeeder := &model.HistoryOfSeeders{
|
||||
SeederName: "UserPasswordHash",
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
if empty && isUsersEmpty {
|
||||
hashSeeder := &model.HistoryOfSeeders{
|
||||
SeederName: "UserPasswordHash",
|
||||
}
|
||||
return tx.Create(hashSeeder).Error
|
||||
}
|
||||
return db.Create(hashSeeder).Error
|
||||
} else {
|
||||
|
||||
var seedersHistory []string
|
||||
db.Model(&model.HistoryOfSeeders{}).Pluck("seeder_name", &seedersHistory)
|
||||
if err := tx.Model(&model.HistoryOfSeeders{}).Pluck("seeder_name", &seedersHistory).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !slices.Contains(seedersHistory, "UserPasswordHash") && !isUsersEmpty {
|
||||
var users []model.User
|
||||
db.Find(&users)
|
||||
if err := tx.Find(&users).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
hashedPassword, err := crypto.HashPasswordAsBcrypt(user.Password)
|
||||
|
|
@ -110,13 +116,15 @@ func runSeeders(isUsersEmpty bool) error {
|
|||
log.Printf("Error hashing password for user '%s': %v", user.Username, err)
|
||||
return err
|
||||
}
|
||||
db.Model(&user).Update("password", hashedPassword)
|
||||
if err := tx.Model(&user).Update("password", hashedPassword).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hashSeeder := &model.HistoryOfSeeders{
|
||||
SeederName: "UserPasswordHash",
|
||||
}
|
||||
if err := db.Create(hashSeeder).Error; err != nil {
|
||||
if err := tx.Create(hashSeeder).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -125,21 +133,25 @@ func runSeeders(isUsersEmpty bool) error {
|
|||
// Drop the old unique index on client_traffics.email to allow
|
||||
// the same email across multiple inbounds
|
||||
dbType := config.GetDBTypeFromJSON()
|
||||
var execErr error
|
||||
if dbType == "mariadb" {
|
||||
db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email ON client_traffics")
|
||||
execErr = tx.Exec("DROP INDEX IF EXISTS idx_client_traffics_email ON client_traffics").Error
|
||||
} else {
|
||||
db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email")
|
||||
execErr = tx.Exec("DROP INDEX IF EXISTS idx_client_traffics_email").Error
|
||||
}
|
||||
if execErr != nil {
|
||||
return execErr
|
||||
}
|
||||
uniqueSeeder := &model.HistoryOfSeeders{
|
||||
SeederName: "RemoveClientTrafficEmailUnique",
|
||||
}
|
||||
if err := db.Create(uniqueSeeder).Error; err != nil {
|
||||
if err := tx.Create(uniqueSeeder).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// isTableEmpty returns true if the named table contains zero rows.
|
||||
|
|
|
|||
|
|
@ -212,3 +212,47 @@ func TestInitUser_OnlyOnce(t *testing.T) {
|
|||
t.Errorf("expected 1 user, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSeeders_DoesNotRecordHistoryWhenPasswordUpdateFails(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
if err := db.Exec("DELETE FROM history_of_seeders").Error; err != nil {
|
||||
t.Fatalf("clear seeders history failed: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Exec(`
|
||||
CREATE TRIGGER fail_user_password_update
|
||||
BEFORE UPDATE OF password ON users
|
||||
BEGIN
|
||||
SELECT RAISE(FAIL, 'boom');
|
||||
END;
|
||||
`).Error; err != nil {
|
||||
t.Fatalf("create trigger failed: %v", err)
|
||||
}
|
||||
|
||||
err := runSeeders(false)
|
||||
if err == nil {
|
||||
t.Fatalf("expected runSeeders to fail when user password update fails")
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&model.HistoryOfSeeders{}).
|
||||
Where("seeder_name = ?", "UserPasswordHash").
|
||||
Count(&count).Error; err != nil {
|
||||
t.Fatalf("count seeder history failed: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatalf("expected no UserPasswordHash history row after failed seeder, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingKey_IsUnique(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
if err := db.Create(&model.Setting{Key: "dup", Value: "one"}).Error; err != nil {
|
||||
t.Fatalf("first insert failed: %v", err)
|
||||
}
|
||||
if err := db.Create(&model.Setting{Key: "dup", Value: "two"}).Error; err == nil {
|
||||
t.Fatal("expected duplicate setting key insert to fail")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ func (i *Inbound) GenXrayInboundConfig() *xray.InboundConfig {
|
|||
// Setting stores key-value configuration settings for the 3x-ui panel.
|
||||
type Setting struct {
|
||||
Id int `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
|
||||
Key string `json:"key" form:"key"`
|
||||
Key string `json:"key" form:"key" gorm:"uniqueIndex"`
|
||||
Value string `json:"value" form:"value"`
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
import "testing"
|
||||
|
||||
func TestGenXrayInboundConfig_EmptyListen(t *testing.T) {
|
||||
in := &Inbound{
|
||||
|
|
|
|||
4
main.go
4
main.go
|
|
@ -411,7 +411,9 @@ func migrateDb() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("Start migrating database...")
|
||||
inboundService.MigrateDB()
|
||||
if err := inboundService.MigrateDB(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("Migration done!")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2333,167 +2333,170 @@ func (s *InboundService) SearchInbounds(query string) ([]*model.Inbound, error)
|
|||
return inbounds, nil
|
||||
}
|
||||
|
||||
func (s *InboundService) MigrationRequirements() {
|
||||
func (s *InboundService) MigrationRequirements() error {
|
||||
db := database.GetDB()
|
||||
tx := db.Begin()
|
||||
var err error
|
||||
defer func() {
|
||||
if err == nil {
|
||||
tx.Commit()
|
||||
if dbErr := db.Exec(`VACUUM "main"`).Error; dbErr != nil {
|
||||
logger.Warningf("VACUUM failed: %v", dbErr)
|
||||
}
|
||||
} else {
|
||||
tx.Rollback()
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(`
|
||||
UPDATE inbounds
|
||||
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
|
||||
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
|
||||
`).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Exec(`
|
||||
UPDATE client_traffics
|
||||
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
|
||||
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
|
||||
`).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}()
|
||||
|
||||
// Calculate and backfill all_time from up+down for inbounds and clients
|
||||
err = tx.Exec(`
|
||||
UPDATE inbounds
|
||||
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
|
||||
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
|
||||
`).Error
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = tx.Exec(`
|
||||
UPDATE client_traffics
|
||||
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
|
||||
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
|
||||
`).Error
|
||||
var inbounds []*model.Inbound
|
||||
err := tx.Model(model.Inbound{}).Where("protocol IN (?)", []string{"vmess", "vless", "trojan"}).Find(&inbounds).Error
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return err
|
||||
}
|
||||
for inboundIndex := range inbounds {
|
||||
settings := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(inbounds[inboundIndex].Settings), &settings); err != nil {
|
||||
return err
|
||||
}
|
||||
clients, ok := settings["clients"].([]any)
|
||||
if ok {
|
||||
var newClients []any
|
||||
for clientIndex := range clients {
|
||||
c, ok := clients[clientIndex].(map[string]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid client settings format for inbound %d", inbounds[inboundIndex].Id)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fix inbounds based problems
|
||||
var inbounds []*model.Inbound
|
||||
err = tx.Model(model.Inbound{}).Where("protocol IN (?)", []string{"vmess", "vless", "trojan"}).Find(&inbounds).Error
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return
|
||||
}
|
||||
for inbound_index := range inbounds {
|
||||
settings := map[string]any{}
|
||||
json.Unmarshal([]byte(inbounds[inbound_index].Settings), &settings)
|
||||
clients, ok := settings["clients"].([]any)
|
||||
if ok {
|
||||
// Fix Client configuration problems
|
||||
var newClients []any
|
||||
for client_index := range clients {
|
||||
c := clients[client_index].(map[string]any)
|
||||
|
||||
// Add email='' if it is not exists
|
||||
if _, ok := c["email"]; !ok {
|
||||
c["email"] = ""
|
||||
}
|
||||
|
||||
// Convert string tgId to int64
|
||||
if _, ok := c["tgId"]; ok {
|
||||
var tgId any = c["tgId"]
|
||||
if tgIdStr, ok2 := tgId.(string); ok2 {
|
||||
tgIdInt64, err := strconv.ParseInt(strings.ReplaceAll(tgIdStr, " ", ""), 10, 64)
|
||||
if err == nil {
|
||||
c["tgId"] = tgIdInt64
|
||||
if _, ok := c["email"]; !ok {
|
||||
c["email"] = ""
|
||||
}
|
||||
if _, ok := c["tgId"]; ok {
|
||||
tgId := c["tgId"]
|
||||
if tgIdStr, ok2 := tgId.(string); ok2 {
|
||||
tgIdInt64, err := strconv.ParseInt(strings.ReplaceAll(tgIdStr, " ", ""), 10, 64)
|
||||
if err == nil {
|
||||
c["tgId"] = tgIdInt64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove "flow": "xtls-rprx-direct"
|
||||
if _, ok := c["flow"]; ok {
|
||||
if c["flow"] == "xtls-rprx-direct" {
|
||||
if _, ok := c["flow"]; ok && c["flow"] == "xtls-rprx-direct" {
|
||||
c["flow"] = ""
|
||||
}
|
||||
if _, ok := c["created_at"]; !ok {
|
||||
c["created_at"] = time.Now().Unix() * 1000
|
||||
}
|
||||
c["updated_at"] = time.Now().Unix() * 1000
|
||||
newClients = append(newClients, c)
|
||||
}
|
||||
// Backfill created_at and updated_at
|
||||
if _, ok := c["created_at"]; !ok {
|
||||
c["created_at"] = time.Now().Unix() * 1000
|
||||
settings["clients"] = newClients
|
||||
modifiedSettings, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c["updated_at"] = time.Now().Unix() * 1000
|
||||
newClients = append(newClients, any(c))
|
||||
inbounds[inboundIndex].Settings = string(modifiedSettings)
|
||||
}
|
||||
settings["clients"] = newClients
|
||||
modifiedSettings, err := json.MarshalIndent(settings, "", " ")
|
||||
|
||||
modelClients, err := s.GetClients(inbounds[inboundIndex])
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
inbounds[inbound_index].Settings = string(modifiedSettings)
|
||||
}
|
||||
|
||||
// Add client traffic row for all clients which has email
|
||||
modelClients, err := s.GetClients(inbounds[inbound_index])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, modelClient := range modelClients {
|
||||
if len(modelClient.Email) > 0 {
|
||||
for _, modelClient := range modelClients {
|
||||
if len(modelClient.Email) == 0 {
|
||||
continue
|
||||
}
|
||||
var count int64
|
||||
tx.Model(xray.ClientTraffic{}).
|
||||
Where("inbound_id = ? AND email = ?", inbounds[inbound_index].Id, modelClient.Email).
|
||||
Count(&count)
|
||||
if err := tx.Model(xray.ClientTraffic{}).
|
||||
Where("inbound_id = ? AND email = ?", inbounds[inboundIndex].Id, modelClient.Email).
|
||||
Count(&count).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
s.AddClientStat(tx, inbounds[inbound_index].Id, &modelClient)
|
||||
if err := s.AddClientStat(tx, inbounds[inboundIndex].Id, &modelClient); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tx.Save(inbounds)
|
||||
if err := tx.Save(inbounds).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Where("inbound_id = 0").Delete(xray.ClientTraffic{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove orphaned traffics
|
||||
tx.Where("inbound_id = 0").Delete(xray.ClientTraffic{})
|
||||
var externalProxy []struct {
|
||||
Id int
|
||||
Port int
|
||||
StreamSettings []byte
|
||||
}
|
||||
if err := tx.Raw(`select id, port, stream_settings
|
||||
from inbounds
|
||||
WHERE protocol in ('vmess','vless','trojan')
|
||||
AND json_extract(stream_settings, '$.security') = 'tls'
|
||||
AND json_extract(stream_settings, '$.tlsSettings.settings.domains') IS NOT NULL`).Scan(&externalProxy).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Migrate old MultiDomain to External Proxy
|
||||
var externalProxy []struct {
|
||||
Id int
|
||||
Port int
|
||||
StreamSettings []byte
|
||||
}
|
||||
err = tx.Raw(`select id, port, stream_settings
|
||||
from inbounds
|
||||
WHERE protocol in ('vmess','vless','trojan')
|
||||
AND json_extract(stream_settings, '$.security') = 'tls'
|
||||
AND json_extract(stream_settings, '$.tlsSettings.settings.domains') IS NOT NULL`).Scan(&externalProxy).Error
|
||||
if err != nil || len(externalProxy) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ep := range externalProxy {
|
||||
var reverses any
|
||||
var stream map[string]any
|
||||
json.Unmarshal(ep.StreamSettings, &stream)
|
||||
if tlsSettings, ok := stream["tlsSettings"].(map[string]any); ok {
|
||||
if settings, ok := tlsSettings["settings"].(map[string]any); ok {
|
||||
if domains, ok := settings["domains"].([]any); ok {
|
||||
for _, domain := range domains {
|
||||
if domainMap, ok := domain.(map[string]any); ok {
|
||||
for _, ep := range externalProxy {
|
||||
var reverses any
|
||||
var stream map[string]any
|
||||
if err := json.Unmarshal(ep.StreamSettings, &stream); err != nil {
|
||||
return err
|
||||
}
|
||||
if tlsSettings, ok := stream["tlsSettings"].(map[string]any); ok {
|
||||
if settings, ok := tlsSettings["settings"].(map[string]any); ok {
|
||||
if domains, ok := settings["domains"].([]any); ok {
|
||||
for _, domain := range domains {
|
||||
domainMap, ok := domain.(map[string]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid tls domain settings format for inbound %d", ep.Id)
|
||||
}
|
||||
domainName, ok := domainMap["domain"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid tls domain name for inbound %d", ep.Id)
|
||||
}
|
||||
domainMap["forceTls"] = "same"
|
||||
domainMap["port"] = ep.Port
|
||||
domainMap["dest"] = domainMap["domain"].(string)
|
||||
domainMap["dest"] = domainName
|
||||
delete(domainMap, "domain")
|
||||
}
|
||||
}
|
||||
reverses = settings["domains"]
|
||||
delete(settings, "domains")
|
||||
}
|
||||
reverses = settings["domains"]
|
||||
delete(settings, "domains")
|
||||
}
|
||||
stream["externalProxy"] = reverses
|
||||
newStream, err := json.MarshalIndent(stream, " ", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(model.Inbound{}).Where("id = ?", ep.Id).Update("stream_settings", newStream).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
stream["externalProxy"] = reverses
|
||||
newStream, _ := json.MarshalIndent(stream, " ", " ")
|
||||
tx.Model(model.Inbound{}).Where("id = ?", ep.Id).Update("stream_settings", newStream)
|
||||
|
||||
return tx.Raw(`UPDATE inbounds
|
||||
SET tag = REPLACE(tag, '0.0.0.0:', '')
|
||||
WHERE INSTR(tag, '0.0.0.0:') > 0;`).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Raw(`UPDATE inbounds
|
||||
SET tag = REPLACE(tag, '0.0.0.0:', '')
|
||||
WHERE INSTR(tag, '0.0.0.0:') > 0;`).Error
|
||||
if err != nil {
|
||||
return
|
||||
if dbErr := db.Exec(`VACUUM "main"`).Error; dbErr != nil {
|
||||
logger.Warningf("VACUUM failed: %v", dbErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InboundService) MigrateDB() {
|
||||
s.MigrationRequirements()
|
||||
func (s *InboundService) MigrateDB() error {
|
||||
if err := s.MigrationRequirements(); err != nil {
|
||||
return err
|
||||
}
|
||||
s.MigrationRemoveOrphanedTraffics()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InboundService) GetOnlineClients() []string {
|
||||
|
|
|
|||
|
|
@ -194,3 +194,61 @@ func TestUpdateInboundClient_DoesNotUpdateOtherInboundTraffic(t *testing.T) {
|
|||
t.Fatalf("expected renamed email to stay isolated to inbound1, got %d rows in inbound2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationRequirements_RollsBackOnAddClientStatFailure(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
|
||||
svc := &InboundService{}
|
||||
inbound := model.Inbound{
|
||||
UserId: 1,
|
||||
Port: 12001,
|
||||
Protocol: model.VLESS,
|
||||
Tag: "rollback-test",
|
||||
Up: 10,
|
||||
Down: 20,
|
||||
Settings: mustMarshalInboundSettings(t, model.Client{
|
||||
ID: "client-rollback",
|
||||
Email: "rollback@example.com",
|
||||
Enable: true,
|
||||
TotalGB: 100,
|
||||
ExpiryTime: 200,
|
||||
}),
|
||||
}
|
||||
if err := database.GetDB().Create(&inbound).Error; err != nil {
|
||||
t.Fatalf("create inbound failed: %v", err)
|
||||
}
|
||||
|
||||
if err := database.GetDB().Exec(`
|
||||
CREATE TRIGGER fail_client_traffic_insert
|
||||
BEFORE INSERT ON client_traffics
|
||||
BEGIN
|
||||
SELECT RAISE(FAIL, 'boom');
|
||||
END;
|
||||
`).Error; err != nil {
|
||||
t.Fatalf("create trigger failed: %v", err)
|
||||
}
|
||||
|
||||
err := svc.MigrationRequirements()
|
||||
if err == nil {
|
||||
t.Fatalf("expected migration requirements to return an error when client traffic insert fails")
|
||||
}
|
||||
|
||||
var refreshed model.Inbound
|
||||
if err := database.GetDB().First(&refreshed, inbound.Id).Error; err != nil {
|
||||
t.Fatalf("reload inbound failed: %v", err)
|
||||
}
|
||||
if refreshed.AllTime != 0 {
|
||||
t.Fatalf("expected inbound all_time rollback to keep 0, got %d", refreshed.AllTime)
|
||||
}
|
||||
|
||||
var traffic xray.ClientTraffic
|
||||
err = database.GetDB().
|
||||
Where("inbound_id = ? AND email = ?", inbound.Id, "rollback@example.com").
|
||||
First(&traffic).Error
|
||||
if err == nil {
|
||||
t.Fatalf("expected client traffic insert to roll back, but row exists: %+v", traffic)
|
||||
}
|
||||
if !database.IsNotFound(err) {
|
||||
t.Fatalf("reload client traffic failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1022,7 +1022,9 @@ func (s *ServerService) ImportDB(file multipart.File) error {
|
|||
return common.NewErrorf("Error migrating db: %v", err)
|
||||
}
|
||||
|
||||
s.inboundService.MigrateDB()
|
||||
if err := s.inboundService.MigrateDB(); err != nil {
|
||||
return common.NewErrorf("Error finalizing imported db: %v", err)
|
||||
}
|
||||
|
||||
// Start Xray
|
||||
if err = s.RestartXrayService(); err != nil {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/mhsanaei/3x-ui/v2/util/reflect_util"
|
||||
"github.com/mhsanaei/3x-ui/v2/web/entity"
|
||||
"github.com/mhsanaei/3x-ui/v2/xray"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
//go:embed config.json
|
||||
|
|
@ -497,19 +498,13 @@ func getXrayTemplateConfigFromDB() (string, error) {
|
|||
// saveXrayTemplateConfigToDB writes xrayTemplateConfig directly to the database.
|
||||
func saveXrayTemplateConfigToDB(value string) error {
|
||||
db := database.GetDB()
|
||||
setting := &model.Setting{}
|
||||
err := db.Model(model.Setting{}).Where("`key` = ?", "xrayTemplateConfig").First(setting).Error
|
||||
if database.IsNotFound(err) {
|
||||
return db.Create(&model.Setting{
|
||||
Key: "xrayTemplateConfig",
|
||||
Value: value,
|
||||
}).Error
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setting.Value = value
|
||||
return db.Save(setting).Error
|
||||
return db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
DoUpdates: clause.Assignments(map[string]any{"value": value}),
|
||||
}).Create(&model.Setting{
|
||||
Key: "xrayTemplateConfig",
|
||||
Value: value,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// SettingService provides business logic for application settings management.
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/mhsanaei/3x-ui/v2/config"
|
||||
"github.com/mhsanaei/3x-ui/v2/database"
|
||||
"github.com/mhsanaei/3x-ui/v2/database/model"
|
||||
)
|
||||
|
||||
func setupTestSettings(t *testing.T) func() {
|
||||
|
|
@ -102,6 +104,31 @@ func TestSettingServiceSetAndGetString(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSaveXrayTemplateConfigToDB_UpdatesSingleRow(t *testing.T) {
|
||||
setupTestSettings(t)
|
||||
setupTestDB(t)
|
||||
|
||||
if err := saveXrayTemplateConfigToDB(`{"version":1}`); err != nil {
|
||||
t.Fatalf("first save failed: %v", err)
|
||||
}
|
||||
if err := saveXrayTemplateConfigToDB(`{"version":2}`); err != nil {
|
||||
t.Fatalf("second save failed: %v", err)
|
||||
}
|
||||
|
||||
var settings []model.Setting
|
||||
if err := database.GetDB().
|
||||
Where("key = ?", "xrayTemplateConfig").
|
||||
Find(&settings).Error; err != nil {
|
||||
t.Fatalf("query settings failed: %v", err)
|
||||
}
|
||||
if len(settings) != 1 {
|
||||
t.Fatalf("expected exactly one xrayTemplateConfig row, got %d", len(settings))
|
||||
}
|
||||
if settings[0].Value != `{"version":2}` {
|
||||
t.Fatalf("expected latest config value to be persisted, got %s", settings[0].Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSettingsDeletesFile(t *testing.T) {
|
||||
setupTestSettings(t)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue