From cc6d3daa3a3a6e951c1db855c8808ecb102a12c9 Mon Sep 17 00:00:00 2001 From: Sora39831 <540587985@qq.com> Date: Tue, 7 Apr 2026 02:12:02 +0800 Subject: [PATCH] fix: harden migration and setting writes --- database/db.go | 40 ++++-- database/db_test.go | 44 ++++++ database/model/model.go | 2 +- database/model/model_test.go | 4 +- main.go | 4 +- web/service/inbound.go | 257 ++++++++++++++++++----------------- web/service/inbound_test.go | 58 ++++++++ web/service/server.go | 4 +- web/service/setting.go | 21 ++- web/service/setting_test.go | 27 ++++ 10 files changed, 301 insertions(+), 160 deletions(-) diff --git a/database/db.go b/database/db.go index fd1da3fa..e0c50551 100644 --- a/database/db.go +++ b/database/db.go @@ -91,18 +91,24 @@ func runSeeders(isUsersEmpty bool) error { return err } - if empty && isUsersEmpty { - hashSeeder := &model.HistoryOfSeeders{ - SeederName: "UserPasswordHash", + return db.Transaction(func(tx *gorm.DB) error { + if empty && isUsersEmpty { + hashSeeder := &model.HistoryOfSeeders{ + SeederName: "UserPasswordHash", + } + return tx.Create(hashSeeder).Error } - return db.Create(hashSeeder).Error - } else { + var seedersHistory []string - db.Model(&model.HistoryOfSeeders{}).Pluck("seeder_name", &seedersHistory) + if err := tx.Model(&model.HistoryOfSeeders{}).Pluck("seeder_name", &seedersHistory).Error; err != nil { + return err + } if !slices.Contains(seedersHistory, "UserPasswordHash") && !isUsersEmpty { var users []model.User - db.Find(&users) + if err := tx.Find(&users).Error; err != nil { + return err + } for _, user := range users { hashedPassword, err := crypto.HashPasswordAsBcrypt(user.Password) @@ -110,13 +116,15 @@ func runSeeders(isUsersEmpty bool) error { log.Printf("Error hashing password for user '%s': %v", user.Username, err) return err } - db.Model(&user).Update("password", hashedPassword) + if err := tx.Model(&user).Update("password", hashedPassword).Error; err != nil { + return err + } } hashSeeder := &model.HistoryOfSeeders{ SeederName: "UserPasswordHash", } - if err := db.Create(hashSeeder).Error; err != nil { + if err := tx.Create(hashSeeder).Error; err != nil { return err } } @@ -125,21 +133,25 @@ func runSeeders(isUsersEmpty bool) error { // Drop the old unique index on client_traffics.email to allow // the same email across multiple inbounds dbType := config.GetDBTypeFromJSON() + var execErr error if dbType == "mariadb" { - db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email ON client_traffics") + execErr = tx.Exec("DROP INDEX IF EXISTS idx_client_traffics_email ON client_traffics").Error } else { - db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email") + execErr = tx.Exec("DROP INDEX IF EXISTS idx_client_traffics_email").Error + } + if execErr != nil { + return execErr } uniqueSeeder := &model.HistoryOfSeeders{ SeederName: "RemoveClientTrafficEmailUnique", } - if err := db.Create(uniqueSeeder).Error; err != nil { + if err := tx.Create(uniqueSeeder).Error; err != nil { return err } } - } - return nil + return nil + }) } // isTableEmpty returns true if the named table contains zero rows. diff --git a/database/db_test.go b/database/db_test.go index 9cebdd20..a51b5c64 100644 --- a/database/db_test.go +++ b/database/db_test.go @@ -212,3 +212,47 @@ func TestInitUser_OnlyOnce(t *testing.T) { t.Errorf("expected 1 user, got %d", count) } } + +func TestRunSeeders_DoesNotRecordHistoryWhenPasswordUpdateFails(t *testing.T) { + setupTestDB(t) + + if err := db.Exec("DELETE FROM history_of_seeders").Error; err != nil { + t.Fatalf("clear seeders history failed: %v", err) + } + + if err := db.Exec(` + CREATE TRIGGER fail_user_password_update + BEFORE UPDATE OF password ON users + BEGIN + SELECT RAISE(FAIL, 'boom'); + END; + `).Error; err != nil { + t.Fatalf("create trigger failed: %v", err) + } + + err := runSeeders(false) + if err == nil { + t.Fatalf("expected runSeeders to fail when user password update fails") + } + + var count int64 + if err := db.Model(&model.HistoryOfSeeders{}). + Where("seeder_name = ?", "UserPasswordHash"). + Count(&count).Error; err != nil { + t.Fatalf("count seeder history failed: %v", err) + } + if count != 0 { + t.Fatalf("expected no UserPasswordHash history row after failed seeder, got %d", count) + } +} + +func TestSettingKey_IsUnique(t *testing.T) { + setupTestDB(t) + + if err := db.Create(&model.Setting{Key: "dup", Value: "one"}).Error; err != nil { + t.Fatalf("first insert failed: %v", err) + } + if err := db.Create(&model.Setting{Key: "dup", Value: "two"}).Error; err == nil { + t.Fatal("expected duplicate setting key insert to fail") + } +} diff --git a/database/model/model.go b/database/model/model.go index 98252f62..36ba1b0f 100644 --- a/database/model/model.go +++ b/database/model/model.go @@ -102,7 +102,7 @@ func (i *Inbound) GenXrayInboundConfig() *xray.InboundConfig { // Setting stores key-value configuration settings for the 3x-ui panel. type Setting struct { Id int `json:"id" form:"id" gorm:"primaryKey;autoIncrement"` - Key string `json:"key" form:"key"` + Key string `json:"key" form:"key" gorm:"uniqueIndex"` Value string `json:"value" form:"value"` } diff --git a/database/model/model_test.go b/database/model/model_test.go index 17c9839c..e0d0a2ab 100644 --- a/database/model/model_test.go +++ b/database/model/model_test.go @@ -1,8 +1,6 @@ package model -import ( - "testing" -) +import "testing" func TestGenXrayInboundConfig_EmptyListen(t *testing.T) { in := &Inbound{ diff --git a/main.go b/main.go index 329c5130..ec47b742 100644 --- a/main.go +++ b/main.go @@ -411,7 +411,9 @@ func migrateDb() { log.Fatal(err) } fmt.Println("Start migrating database...") - inboundService.MigrateDB() + if err := inboundService.MigrateDB(); err != nil { + log.Fatal(err) + } fmt.Println("Migration done!") } diff --git a/web/service/inbound.go b/web/service/inbound.go index 83362259..68c1cc42 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -2333,167 +2333,170 @@ func (s *InboundService) SearchInbounds(query string) ([]*model.Inbound, error) return inbounds, nil } -func (s *InboundService) MigrationRequirements() { +func (s *InboundService) MigrationRequirements() error { db := database.GetDB() - tx := db.Begin() - var err error - defer func() { - if err == nil { - tx.Commit() - if dbErr := db.Exec(`VACUUM "main"`).Error; dbErr != nil { - logger.Warningf("VACUUM failed: %v", dbErr) - } - } else { - tx.Rollback() + if err := db.Transaction(func(tx *gorm.DB) error { + if err := tx.Exec(` + UPDATE inbounds + SET all_time = IFNULL(up, 0) + IFNULL(down, 0) + WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0 + `).Error; err != nil { + return err + } + if err := tx.Exec(` + UPDATE client_traffics + SET all_time = IFNULL(up, 0) + IFNULL(down, 0) + WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0 + `).Error; err != nil { + return err } - }() - // Calculate and backfill all_time from up+down for inbounds and clients - err = tx.Exec(` - UPDATE inbounds - SET all_time = IFNULL(up, 0) + IFNULL(down, 0) - WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0 - `).Error - if err != nil { - return - } - err = tx.Exec(` - UPDATE client_traffics - SET all_time = IFNULL(up, 0) + IFNULL(down, 0) - WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0 - `).Error + var inbounds []*model.Inbound + err := tx.Model(model.Inbound{}).Where("protocol IN (?)", []string{"vmess", "vless", "trojan"}).Find(&inbounds).Error + if err != nil && err != gorm.ErrRecordNotFound { + return err + } + for inboundIndex := range inbounds { + settings := map[string]any{} + if err := json.Unmarshal([]byte(inbounds[inboundIndex].Settings), &settings); err != nil { + return err + } + clients, ok := settings["clients"].([]any) + if ok { + var newClients []any + for clientIndex := range clients { + c, ok := clients[clientIndex].(map[string]any) + if !ok { + return fmt.Errorf("invalid client settings format for inbound %d", inbounds[inboundIndex].Id) + } - if err != nil { - return - } - - // Fix inbounds based problems - var inbounds []*model.Inbound - err = tx.Model(model.Inbound{}).Where("protocol IN (?)", []string{"vmess", "vless", "trojan"}).Find(&inbounds).Error - if err != nil && err != gorm.ErrRecordNotFound { - return - } - for inbound_index := range inbounds { - settings := map[string]any{} - json.Unmarshal([]byte(inbounds[inbound_index].Settings), &settings) - clients, ok := settings["clients"].([]any) - if ok { - // Fix Client configuration problems - var newClients []any - for client_index := range clients { - c := clients[client_index].(map[string]any) - - // Add email='' if it is not exists - if _, ok := c["email"]; !ok { - c["email"] = "" - } - - // Convert string tgId to int64 - if _, ok := c["tgId"]; ok { - var tgId any = c["tgId"] - if tgIdStr, ok2 := tgId.(string); ok2 { - tgIdInt64, err := strconv.ParseInt(strings.ReplaceAll(tgIdStr, " ", ""), 10, 64) - if err == nil { - c["tgId"] = tgIdInt64 + if _, ok := c["email"]; !ok { + c["email"] = "" + } + if _, ok := c["tgId"]; ok { + tgId := c["tgId"] + if tgIdStr, ok2 := tgId.(string); ok2 { + tgIdInt64, err := strconv.ParseInt(strings.ReplaceAll(tgIdStr, " ", ""), 10, 64) + if err == nil { + c["tgId"] = tgIdInt64 + } } } - } - - // Remove "flow": "xtls-rprx-direct" - if _, ok := c["flow"]; ok { - if c["flow"] == "xtls-rprx-direct" { + if _, ok := c["flow"]; ok && c["flow"] == "xtls-rprx-direct" { c["flow"] = "" } + if _, ok := c["created_at"]; !ok { + c["created_at"] = time.Now().Unix() * 1000 + } + c["updated_at"] = time.Now().Unix() * 1000 + newClients = append(newClients, c) } - // Backfill created_at and updated_at - if _, ok := c["created_at"]; !ok { - c["created_at"] = time.Now().Unix() * 1000 + settings["clients"] = newClients + modifiedSettings, err := json.MarshalIndent(settings, "", " ") + if err != nil { + return err } - c["updated_at"] = time.Now().Unix() * 1000 - newClients = append(newClients, any(c)) + inbounds[inboundIndex].Settings = string(modifiedSettings) } - settings["clients"] = newClients - modifiedSettings, err := json.MarshalIndent(settings, "", " ") + + modelClients, err := s.GetClients(inbounds[inboundIndex]) if err != nil { - return + return err } - - inbounds[inbound_index].Settings = string(modifiedSettings) - } - - // Add client traffic row for all clients which has email - modelClients, err := s.GetClients(inbounds[inbound_index]) - if err != nil { - return - } - for _, modelClient := range modelClients { - if len(modelClient.Email) > 0 { + for _, modelClient := range modelClients { + if len(modelClient.Email) == 0 { + continue + } var count int64 - tx.Model(xray.ClientTraffic{}). - Where("inbound_id = ? AND email = ?", inbounds[inbound_index].Id, modelClient.Email). - Count(&count) + if err := tx.Model(xray.ClientTraffic{}). + Where("inbound_id = ? AND email = ?", inbounds[inboundIndex].Id, modelClient.Email). + Count(&count).Error; err != nil { + return err + } if count == 0 { - s.AddClientStat(tx, inbounds[inbound_index].Id, &modelClient) + if err := s.AddClientStat(tx, inbounds[inboundIndex].Id, &modelClient); err != nil { + return err + } } } } - } - tx.Save(inbounds) + if err := tx.Save(inbounds).Error; err != nil { + return err + } + if err := tx.Where("inbound_id = 0").Delete(xray.ClientTraffic{}).Error; err != nil { + return err + } - // Remove orphaned traffics - tx.Where("inbound_id = 0").Delete(xray.ClientTraffic{}) + var externalProxy []struct { + Id int + Port int + StreamSettings []byte + } + if err := tx.Raw(`select id, port, stream_settings + from inbounds + WHERE protocol in ('vmess','vless','trojan') + AND json_extract(stream_settings, '$.security') = 'tls' + AND json_extract(stream_settings, '$.tlsSettings.settings.domains') IS NOT NULL`).Scan(&externalProxy).Error; err != nil { + return err + } - // Migrate old MultiDomain to External Proxy - var externalProxy []struct { - Id int - Port int - StreamSettings []byte - } - err = tx.Raw(`select id, port, stream_settings - from inbounds - WHERE protocol in ('vmess','vless','trojan') - AND json_extract(stream_settings, '$.security') = 'tls' - AND json_extract(stream_settings, '$.tlsSettings.settings.domains') IS NOT NULL`).Scan(&externalProxy).Error - if err != nil || len(externalProxy) == 0 { - return - } - - for _, ep := range externalProxy { - var reverses any - var stream map[string]any - json.Unmarshal(ep.StreamSettings, &stream) - if tlsSettings, ok := stream["tlsSettings"].(map[string]any); ok { - if settings, ok := tlsSettings["settings"].(map[string]any); ok { - if domains, ok := settings["domains"].([]any); ok { - for _, domain := range domains { - if domainMap, ok := domain.(map[string]any); ok { + for _, ep := range externalProxy { + var reverses any + var stream map[string]any + if err := json.Unmarshal(ep.StreamSettings, &stream); err != nil { + return err + } + if tlsSettings, ok := stream["tlsSettings"].(map[string]any); ok { + if settings, ok := tlsSettings["settings"].(map[string]any); ok { + if domains, ok := settings["domains"].([]any); ok { + for _, domain := range domains { + domainMap, ok := domain.(map[string]any) + if !ok { + return fmt.Errorf("invalid tls domain settings format for inbound %d", ep.Id) + } + domainName, ok := domainMap["domain"].(string) + if !ok { + return fmt.Errorf("invalid tls domain name for inbound %d", ep.Id) + } domainMap["forceTls"] = "same" domainMap["port"] = ep.Port - domainMap["dest"] = domainMap["domain"].(string) + domainMap["dest"] = domainName delete(domainMap, "domain") } } + reverses = settings["domains"] + delete(settings, "domains") } - reverses = settings["domains"] - delete(settings, "domains") + } + stream["externalProxy"] = reverses + newStream, err := json.MarshalIndent(stream, " ", " ") + if err != nil { + return err + } + if err := tx.Model(model.Inbound{}).Where("id = ?", ep.Id).Update("stream_settings", newStream).Error; err != nil { + return err } } - stream["externalProxy"] = reverses - newStream, _ := json.MarshalIndent(stream, " ", " ") - tx.Model(model.Inbound{}).Where("id = ?", ep.Id).Update("stream_settings", newStream) + + return tx.Raw(`UPDATE inbounds + SET tag = REPLACE(tag, '0.0.0.0:', '') + WHERE INSTR(tag, '0.0.0.0:') > 0;`).Error + }); err != nil { + return err } - err = tx.Raw(`UPDATE inbounds - SET tag = REPLACE(tag, '0.0.0.0:', '') - WHERE INSTR(tag, '0.0.0.0:') > 0;`).Error - if err != nil { - return + if dbErr := db.Exec(`VACUUM "main"`).Error; dbErr != nil { + logger.Warningf("VACUUM failed: %v", dbErr) } + return nil } -func (s *InboundService) MigrateDB() { - s.MigrationRequirements() +func (s *InboundService) MigrateDB() error { + if err := s.MigrationRequirements(); err != nil { + return err + } s.MigrationRemoveOrphanedTraffics() + return nil } func (s *InboundService) GetOnlineClients() []string { diff --git a/web/service/inbound_test.go b/web/service/inbound_test.go index b78a0f7a..20dee7b5 100644 --- a/web/service/inbound_test.go +++ b/web/service/inbound_test.go @@ -194,3 +194,61 @@ func TestUpdateInboundClient_DoesNotUpdateOtherInboundTraffic(t *testing.T) { t.Fatalf("expected renamed email to stay isolated to inbound1, got %d rows in inbound2", got) } } + +func TestMigrationRequirements_RollsBackOnAddClientStatFailure(t *testing.T) { + setupTestDB(t) + + svc := &InboundService{} + inbound := model.Inbound{ + UserId: 1, + Port: 12001, + Protocol: model.VLESS, + Tag: "rollback-test", + Up: 10, + Down: 20, + Settings: mustMarshalInboundSettings(t, model.Client{ + ID: "client-rollback", + Email: "rollback@example.com", + Enable: true, + TotalGB: 100, + ExpiryTime: 200, + }), + } + if err := database.GetDB().Create(&inbound).Error; err != nil { + t.Fatalf("create inbound failed: %v", err) + } + + if err := database.GetDB().Exec(` + CREATE TRIGGER fail_client_traffic_insert + BEFORE INSERT ON client_traffics + BEGIN + SELECT RAISE(FAIL, 'boom'); + END; + `).Error; err != nil { + t.Fatalf("create trigger failed: %v", err) + } + + err := svc.MigrationRequirements() + if err == nil { + t.Fatalf("expected migration requirements to return an error when client traffic insert fails") + } + + var refreshed model.Inbound + if err := database.GetDB().First(&refreshed, inbound.Id).Error; err != nil { + t.Fatalf("reload inbound failed: %v", err) + } + if refreshed.AllTime != 0 { + t.Fatalf("expected inbound all_time rollback to keep 0, got %d", refreshed.AllTime) + } + + var traffic xray.ClientTraffic + err = database.GetDB(). + Where("inbound_id = ? AND email = ?", inbound.Id, "rollback@example.com"). + First(&traffic).Error + if err == nil { + t.Fatalf("expected client traffic insert to roll back, but row exists: %+v", traffic) + } + if !database.IsNotFound(err) { + t.Fatalf("reload client traffic failed: %v", err) + } +} diff --git a/web/service/server.go b/web/service/server.go index 96c4174f..8cb895a3 100644 --- a/web/service/server.go +++ b/web/service/server.go @@ -1022,7 +1022,9 @@ func (s *ServerService) ImportDB(file multipart.File) error { return common.NewErrorf("Error migrating db: %v", err) } - s.inboundService.MigrateDB() + if err := s.inboundService.MigrateDB(); err != nil { + return common.NewErrorf("Error finalizing imported db: %v", err) + } // Start Xray if err = s.RestartXrayService(); err != nil { diff --git a/web/service/setting.go b/web/service/setting.go index ce855791..e5a3a3d5 100644 --- a/web/service/setting.go +++ b/web/service/setting.go @@ -21,6 +21,7 @@ import ( "github.com/mhsanaei/3x-ui/v2/util/reflect_util" "github.com/mhsanaei/3x-ui/v2/web/entity" "github.com/mhsanaei/3x-ui/v2/xray" + "gorm.io/gorm/clause" ) //go:embed config.json @@ -497,19 +498,13 @@ func getXrayTemplateConfigFromDB() (string, error) { // saveXrayTemplateConfigToDB writes xrayTemplateConfig directly to the database. func saveXrayTemplateConfigToDB(value string) error { db := database.GetDB() - setting := &model.Setting{} - err := db.Model(model.Setting{}).Where("`key` = ?", "xrayTemplateConfig").First(setting).Error - if database.IsNotFound(err) { - return db.Create(&model.Setting{ - Key: "xrayTemplateConfig", - Value: value, - }).Error - } - if err != nil { - return err - } - setting.Value = value - return db.Save(setting).Error + return db.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "key"}}, + DoUpdates: clause.Assignments(map[string]any{"value": value}), + }).Create(&model.Setting{ + Key: "xrayTemplateConfig", + Value: value, + }).Error } // SettingService provides business logic for application settings management. diff --git a/web/service/setting_test.go b/web/service/setting_test.go index 2a656840..93ed9d03 100644 --- a/web/service/setting_test.go +++ b/web/service/setting_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/mhsanaei/3x-ui/v2/config" + "github.com/mhsanaei/3x-ui/v2/database" + "github.com/mhsanaei/3x-ui/v2/database/model" ) func setupTestSettings(t *testing.T) func() { @@ -102,6 +104,31 @@ func TestSettingServiceSetAndGetString(t *testing.T) { } } +func TestSaveXrayTemplateConfigToDB_UpdatesSingleRow(t *testing.T) { + setupTestSettings(t) + setupTestDB(t) + + if err := saveXrayTemplateConfigToDB(`{"version":1}`); err != nil { + t.Fatalf("first save failed: %v", err) + } + if err := saveXrayTemplateConfigToDB(`{"version":2}`); err != nil { + t.Fatalf("second save failed: %v", err) + } + + var settings []model.Setting + if err := database.GetDB(). + Where("key = ?", "xrayTemplateConfig"). + Find(&settings).Error; err != nil { + t.Fatalf("query settings failed: %v", err) + } + if len(settings) != 1 { + t.Fatalf("expected exactly one xrayTemplateConfig row, got %d", len(settings)) + } + if settings[0].Value != `{"version":2}` { + t.Fatalf("expected latest config value to be persisted, got %s", settings[0].Value) + } +} + func TestResetSettingsDeletesFile(t *testing.T) { setupTestSettings(t)