fix: harden migration and setting writes

This commit is contained in:
Sora39831 2026-04-07 02:12:02 +08:00
parent e298996d77
commit cc6d3daa3a
10 changed files with 301 additions and 160 deletions

View file

@ -91,18 +91,24 @@ func runSeeders(isUsersEmpty bool) error {
return err
}
if empty && isUsersEmpty {
hashSeeder := &model.HistoryOfSeeders{
SeederName: "UserPasswordHash",
return db.Transaction(func(tx *gorm.DB) error {
if empty && isUsersEmpty {
hashSeeder := &model.HistoryOfSeeders{
SeederName: "UserPasswordHash",
}
return tx.Create(hashSeeder).Error
}
return db.Create(hashSeeder).Error
} else {
var seedersHistory []string
db.Model(&model.HistoryOfSeeders{}).Pluck("seeder_name", &seedersHistory)
if err := tx.Model(&model.HistoryOfSeeders{}).Pluck("seeder_name", &seedersHistory).Error; err != nil {
return err
}
if !slices.Contains(seedersHistory, "UserPasswordHash") && !isUsersEmpty {
var users []model.User
db.Find(&users)
if err := tx.Find(&users).Error; err != nil {
return err
}
for _, user := range users {
hashedPassword, err := crypto.HashPasswordAsBcrypt(user.Password)
@ -110,13 +116,15 @@ func runSeeders(isUsersEmpty bool) error {
log.Printf("Error hashing password for user '%s': %v", user.Username, err)
return err
}
db.Model(&user).Update("password", hashedPassword)
if err := tx.Model(&user).Update("password", hashedPassword).Error; err != nil {
return err
}
}
hashSeeder := &model.HistoryOfSeeders{
SeederName: "UserPasswordHash",
}
if err := db.Create(hashSeeder).Error; err != nil {
if err := tx.Create(hashSeeder).Error; err != nil {
return err
}
}
@ -125,21 +133,25 @@ func runSeeders(isUsersEmpty bool) error {
// Drop the old unique index on client_traffics.email to allow
// the same email across multiple inbounds
dbType := config.GetDBTypeFromJSON()
var execErr error
if dbType == "mariadb" {
db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email ON client_traffics")
execErr = tx.Exec("DROP INDEX IF EXISTS idx_client_traffics_email ON client_traffics").Error
} else {
db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email")
execErr = tx.Exec("DROP INDEX IF EXISTS idx_client_traffics_email").Error
}
if execErr != nil {
return execErr
}
uniqueSeeder := &model.HistoryOfSeeders{
SeederName: "RemoveClientTrafficEmailUnique",
}
if err := db.Create(uniqueSeeder).Error; err != nil {
if err := tx.Create(uniqueSeeder).Error; err != nil {
return err
}
}
}
return nil
return nil
})
}
// isTableEmpty returns true if the named table contains zero rows.

View file

@ -212,3 +212,47 @@ func TestInitUser_OnlyOnce(t *testing.T) {
t.Errorf("expected 1 user, got %d", count)
}
}
func TestRunSeeders_DoesNotRecordHistoryWhenPasswordUpdateFails(t *testing.T) {
setupTestDB(t)
if err := db.Exec("DELETE FROM history_of_seeders").Error; err != nil {
t.Fatalf("clear seeders history failed: %v", err)
}
if err := db.Exec(`
CREATE TRIGGER fail_user_password_update
BEFORE UPDATE OF password ON users
BEGIN
SELECT RAISE(FAIL, 'boom');
END;
`).Error; err != nil {
t.Fatalf("create trigger failed: %v", err)
}
err := runSeeders(false)
if err == nil {
t.Fatalf("expected runSeeders to fail when user password update fails")
}
var count int64
if err := db.Model(&model.HistoryOfSeeders{}).
Where("seeder_name = ?", "UserPasswordHash").
Count(&count).Error; err != nil {
t.Fatalf("count seeder history failed: %v", err)
}
if count != 0 {
t.Fatalf("expected no UserPasswordHash history row after failed seeder, got %d", count)
}
}
func TestSettingKey_IsUnique(t *testing.T) {
setupTestDB(t)
if err := db.Create(&model.Setting{Key: "dup", Value: "one"}).Error; err != nil {
t.Fatalf("first insert failed: %v", err)
}
if err := db.Create(&model.Setting{Key: "dup", Value: "two"}).Error; err == nil {
t.Fatal("expected duplicate setting key insert to fail")
}
}

View file

@ -102,7 +102,7 @@ func (i *Inbound) GenXrayInboundConfig() *xray.InboundConfig {
// Setting stores key-value configuration settings for the 3x-ui panel.
type Setting struct {
Id int `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
Key string `json:"key" form:"key"`
Key string `json:"key" form:"key" gorm:"uniqueIndex"`
Value string `json:"value" form:"value"`
}

View file

@ -1,8 +1,6 @@
package model
import (
"testing"
)
import "testing"
func TestGenXrayInboundConfig_EmptyListen(t *testing.T) {
in := &Inbound{

View file

@ -411,7 +411,9 @@ func migrateDb() {
log.Fatal(err)
}
fmt.Println("Start migrating database...")
inboundService.MigrateDB()
if err := inboundService.MigrateDB(); err != nil {
log.Fatal(err)
}
fmt.Println("Migration done!")
}

View file

@ -2333,167 +2333,170 @@ func (s *InboundService) SearchInbounds(query string) ([]*model.Inbound, error)
return inbounds, nil
}
func (s *InboundService) MigrationRequirements() {
func (s *InboundService) MigrationRequirements() error {
db := database.GetDB()
tx := db.Begin()
var err error
defer func() {
if err == nil {
tx.Commit()
if dbErr := db.Exec(`VACUUM "main"`).Error; dbErr != nil {
logger.Warningf("VACUUM failed: %v", dbErr)
}
} else {
tx.Rollback()
if err := db.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(`
UPDATE inbounds
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
`).Error; err != nil {
return err
}
if err := tx.Exec(`
UPDATE client_traffics
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
`).Error; err != nil {
return err
}
}()
// Calculate and backfill all_time from up+down for inbounds and clients
err = tx.Exec(`
UPDATE inbounds
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
`).Error
if err != nil {
return
}
err = tx.Exec(`
UPDATE client_traffics
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
`).Error
var inbounds []*model.Inbound
err := tx.Model(model.Inbound{}).Where("protocol IN (?)", []string{"vmess", "vless", "trojan"}).Find(&inbounds).Error
if err != nil && err != gorm.ErrRecordNotFound {
return err
}
for inboundIndex := range inbounds {
settings := map[string]any{}
if err := json.Unmarshal([]byte(inbounds[inboundIndex].Settings), &settings); err != nil {
return err
}
clients, ok := settings["clients"].([]any)
if ok {
var newClients []any
for clientIndex := range clients {
c, ok := clients[clientIndex].(map[string]any)
if !ok {
return fmt.Errorf("invalid client settings format for inbound %d", inbounds[inboundIndex].Id)
}
if err != nil {
return
}
// Fix inbounds based problems
var inbounds []*model.Inbound
err = tx.Model(model.Inbound{}).Where("protocol IN (?)", []string{"vmess", "vless", "trojan"}).Find(&inbounds).Error
if err != nil && err != gorm.ErrRecordNotFound {
return
}
for inbound_index := range inbounds {
settings := map[string]any{}
json.Unmarshal([]byte(inbounds[inbound_index].Settings), &settings)
clients, ok := settings["clients"].([]any)
if ok {
// Fix Client configuration problems
var newClients []any
for client_index := range clients {
c := clients[client_index].(map[string]any)
// Add email='' if it is not exists
if _, ok := c["email"]; !ok {
c["email"] = ""
}
// Convert string tgId to int64
if _, ok := c["tgId"]; ok {
var tgId any = c["tgId"]
if tgIdStr, ok2 := tgId.(string); ok2 {
tgIdInt64, err := strconv.ParseInt(strings.ReplaceAll(tgIdStr, " ", ""), 10, 64)
if err == nil {
c["tgId"] = tgIdInt64
if _, ok := c["email"]; !ok {
c["email"] = ""
}
if _, ok := c["tgId"]; ok {
tgId := c["tgId"]
if tgIdStr, ok2 := tgId.(string); ok2 {
tgIdInt64, err := strconv.ParseInt(strings.ReplaceAll(tgIdStr, " ", ""), 10, 64)
if err == nil {
c["tgId"] = tgIdInt64
}
}
}
}
// Remove "flow": "xtls-rprx-direct"
if _, ok := c["flow"]; ok {
if c["flow"] == "xtls-rprx-direct" {
if _, ok := c["flow"]; ok && c["flow"] == "xtls-rprx-direct" {
c["flow"] = ""
}
if _, ok := c["created_at"]; !ok {
c["created_at"] = time.Now().Unix() * 1000
}
c["updated_at"] = time.Now().Unix() * 1000
newClients = append(newClients, c)
}
// Backfill created_at and updated_at
if _, ok := c["created_at"]; !ok {
c["created_at"] = time.Now().Unix() * 1000
settings["clients"] = newClients
modifiedSettings, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return err
}
c["updated_at"] = time.Now().Unix() * 1000
newClients = append(newClients, any(c))
inbounds[inboundIndex].Settings = string(modifiedSettings)
}
settings["clients"] = newClients
modifiedSettings, err := json.MarshalIndent(settings, "", " ")
modelClients, err := s.GetClients(inbounds[inboundIndex])
if err != nil {
return
return err
}
inbounds[inbound_index].Settings = string(modifiedSettings)
}
// Add client traffic row for all clients which has email
modelClients, err := s.GetClients(inbounds[inbound_index])
if err != nil {
return
}
for _, modelClient := range modelClients {
if len(modelClient.Email) > 0 {
for _, modelClient := range modelClients {
if len(modelClient.Email) == 0 {
continue
}
var count int64
tx.Model(xray.ClientTraffic{}).
Where("inbound_id = ? AND email = ?", inbounds[inbound_index].Id, modelClient.Email).
Count(&count)
if err := tx.Model(xray.ClientTraffic{}).
Where("inbound_id = ? AND email = ?", inbounds[inboundIndex].Id, modelClient.Email).
Count(&count).Error; err != nil {
return err
}
if count == 0 {
s.AddClientStat(tx, inbounds[inbound_index].Id, &modelClient)
if err := s.AddClientStat(tx, inbounds[inboundIndex].Id, &modelClient); err != nil {
return err
}
}
}
}
}
tx.Save(inbounds)
if err := tx.Save(inbounds).Error; err != nil {
return err
}
if err := tx.Where("inbound_id = 0").Delete(xray.ClientTraffic{}).Error; err != nil {
return err
}
// Remove orphaned traffics
tx.Where("inbound_id = 0").Delete(xray.ClientTraffic{})
var externalProxy []struct {
Id int
Port int
StreamSettings []byte
}
if err := tx.Raw(`select id, port, stream_settings
from inbounds
WHERE protocol in ('vmess','vless','trojan')
AND json_extract(stream_settings, '$.security') = 'tls'
AND json_extract(stream_settings, '$.tlsSettings.settings.domains') IS NOT NULL`).Scan(&externalProxy).Error; err != nil {
return err
}
// Migrate old MultiDomain to External Proxy
var externalProxy []struct {
Id int
Port int
StreamSettings []byte
}
err = tx.Raw(`select id, port, stream_settings
from inbounds
WHERE protocol in ('vmess','vless','trojan')
AND json_extract(stream_settings, '$.security') = 'tls'
AND json_extract(stream_settings, '$.tlsSettings.settings.domains') IS NOT NULL`).Scan(&externalProxy).Error
if err != nil || len(externalProxy) == 0 {
return
}
for _, ep := range externalProxy {
var reverses any
var stream map[string]any
json.Unmarshal(ep.StreamSettings, &stream)
if tlsSettings, ok := stream["tlsSettings"].(map[string]any); ok {
if settings, ok := tlsSettings["settings"].(map[string]any); ok {
if domains, ok := settings["domains"].([]any); ok {
for _, domain := range domains {
if domainMap, ok := domain.(map[string]any); ok {
for _, ep := range externalProxy {
var reverses any
var stream map[string]any
if err := json.Unmarshal(ep.StreamSettings, &stream); err != nil {
return err
}
if tlsSettings, ok := stream["tlsSettings"].(map[string]any); ok {
if settings, ok := tlsSettings["settings"].(map[string]any); ok {
if domains, ok := settings["domains"].([]any); ok {
for _, domain := range domains {
domainMap, ok := domain.(map[string]any)
if !ok {
return fmt.Errorf("invalid tls domain settings format for inbound %d", ep.Id)
}
domainName, ok := domainMap["domain"].(string)
if !ok {
return fmt.Errorf("invalid tls domain name for inbound %d", ep.Id)
}
domainMap["forceTls"] = "same"
domainMap["port"] = ep.Port
domainMap["dest"] = domainMap["domain"].(string)
domainMap["dest"] = domainName
delete(domainMap, "domain")
}
}
reverses = settings["domains"]
delete(settings, "domains")
}
reverses = settings["domains"]
delete(settings, "domains")
}
stream["externalProxy"] = reverses
newStream, err := json.MarshalIndent(stream, " ", " ")
if err != nil {
return err
}
if err := tx.Model(model.Inbound{}).Where("id = ?", ep.Id).Update("stream_settings", newStream).Error; err != nil {
return err
}
}
stream["externalProxy"] = reverses
newStream, _ := json.MarshalIndent(stream, " ", " ")
tx.Model(model.Inbound{}).Where("id = ?", ep.Id).Update("stream_settings", newStream)
return tx.Raw(`UPDATE inbounds
SET tag = REPLACE(tag, '0.0.0.0:', '')
WHERE INSTR(tag, '0.0.0.0:') > 0;`).Error
}); err != nil {
return err
}
err = tx.Raw(`UPDATE inbounds
SET tag = REPLACE(tag, '0.0.0.0:', '')
WHERE INSTR(tag, '0.0.0.0:') > 0;`).Error
if err != nil {
return
if dbErr := db.Exec(`VACUUM "main"`).Error; dbErr != nil {
logger.Warningf("VACUUM failed: %v", dbErr)
}
return nil
}
func (s *InboundService) MigrateDB() {
s.MigrationRequirements()
func (s *InboundService) MigrateDB() error {
if err := s.MigrationRequirements(); err != nil {
return err
}
s.MigrationRemoveOrphanedTraffics()
return nil
}
func (s *InboundService) GetOnlineClients() []string {

View file

@ -194,3 +194,61 @@ func TestUpdateInboundClient_DoesNotUpdateOtherInboundTraffic(t *testing.T) {
t.Fatalf("expected renamed email to stay isolated to inbound1, got %d rows in inbound2", got)
}
}
func TestMigrationRequirements_RollsBackOnAddClientStatFailure(t *testing.T) {
setupTestDB(t)
svc := &InboundService{}
inbound := model.Inbound{
UserId: 1,
Port: 12001,
Protocol: model.VLESS,
Tag: "rollback-test",
Up: 10,
Down: 20,
Settings: mustMarshalInboundSettings(t, model.Client{
ID: "client-rollback",
Email: "rollback@example.com",
Enable: true,
TotalGB: 100,
ExpiryTime: 200,
}),
}
if err := database.GetDB().Create(&inbound).Error; err != nil {
t.Fatalf("create inbound failed: %v", err)
}
if err := database.GetDB().Exec(`
CREATE TRIGGER fail_client_traffic_insert
BEFORE INSERT ON client_traffics
BEGIN
SELECT RAISE(FAIL, 'boom');
END;
`).Error; err != nil {
t.Fatalf("create trigger failed: %v", err)
}
err := svc.MigrationRequirements()
if err == nil {
t.Fatalf("expected migration requirements to return an error when client traffic insert fails")
}
var refreshed model.Inbound
if err := database.GetDB().First(&refreshed, inbound.Id).Error; err != nil {
t.Fatalf("reload inbound failed: %v", err)
}
if refreshed.AllTime != 0 {
t.Fatalf("expected inbound all_time rollback to keep 0, got %d", refreshed.AllTime)
}
var traffic xray.ClientTraffic
err = database.GetDB().
Where("inbound_id = ? AND email = ?", inbound.Id, "rollback@example.com").
First(&traffic).Error
if err == nil {
t.Fatalf("expected client traffic insert to roll back, but row exists: %+v", traffic)
}
if !database.IsNotFound(err) {
t.Fatalf("reload client traffic failed: %v", err)
}
}

View file

@ -1022,7 +1022,9 @@ func (s *ServerService) ImportDB(file multipart.File) error {
return common.NewErrorf("Error migrating db: %v", err)
}
s.inboundService.MigrateDB()
if err := s.inboundService.MigrateDB(); err != nil {
return common.NewErrorf("Error finalizing imported db: %v", err)
}
// Start Xray
if err = s.RestartXrayService(); err != nil {

View file

@ -21,6 +21,7 @@ import (
"github.com/mhsanaei/3x-ui/v2/util/reflect_util"
"github.com/mhsanaei/3x-ui/v2/web/entity"
"github.com/mhsanaei/3x-ui/v2/xray"
"gorm.io/gorm/clause"
)
//go:embed config.json
@ -497,19 +498,13 @@ func getXrayTemplateConfigFromDB() (string, error) {
// saveXrayTemplateConfigToDB writes xrayTemplateConfig directly to the database.
func saveXrayTemplateConfigToDB(value string) error {
db := database.GetDB()
setting := &model.Setting{}
err := db.Model(model.Setting{}).Where("`key` = ?", "xrayTemplateConfig").First(setting).Error
if database.IsNotFound(err) {
return db.Create(&model.Setting{
Key: "xrayTemplateConfig",
Value: value,
}).Error
}
if err != nil {
return err
}
setting.Value = value
return db.Save(setting).Error
return db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.Assignments(map[string]any{"value": value}),
}).Create(&model.Setting{
Key: "xrayTemplateConfig",
Value: value,
}).Error
}
// SettingService provides business logic for application settings management.

View file

@ -6,6 +6,8 @@ import (
"testing"
"github.com/mhsanaei/3x-ui/v2/config"
"github.com/mhsanaei/3x-ui/v2/database"
"github.com/mhsanaei/3x-ui/v2/database/model"
)
func setupTestSettings(t *testing.T) func() {
@ -102,6 +104,31 @@ func TestSettingServiceSetAndGetString(t *testing.T) {
}
}
func TestSaveXrayTemplateConfigToDB_UpdatesSingleRow(t *testing.T) {
setupTestSettings(t)
setupTestDB(t)
if err := saveXrayTemplateConfigToDB(`{"version":1}`); err != nil {
t.Fatalf("first save failed: %v", err)
}
if err := saveXrayTemplateConfigToDB(`{"version":2}`); err != nil {
t.Fatalf("second save failed: %v", err)
}
var settings []model.Setting
if err := database.GetDB().
Where("key = ?", "xrayTemplateConfig").
Find(&settings).Error; err != nil {
t.Fatalf("query settings failed: %v", err)
}
if len(settings) != 1 {
t.Fatalf("expected exactly one xrayTemplateConfig row, got %d", len(settings))
}
if settings[0].Value != `{"version":2}` {
t.Fatalf("expected latest config value to be persisted, got %s", settings[0].Value)
}
}
func TestResetSettingsDeletesFile(t *testing.T) {
setupTestSettings(t)