mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-07 13:44:24 +00:00
fix: harden migration and setting writes
This commit is contained in:
parent
e298996d77
commit
cc6d3daa3a
10 changed files with 301 additions and 160 deletions
|
|
@ -91,18 +91,24 @@ func runSeeders(isUsersEmpty bool) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if empty && isUsersEmpty {
|
return db.Transaction(func(tx *gorm.DB) error {
|
||||||
hashSeeder := &model.HistoryOfSeeders{
|
if empty && isUsersEmpty {
|
||||||
SeederName: "UserPasswordHash",
|
hashSeeder := &model.HistoryOfSeeders{
|
||||||
|
SeederName: "UserPasswordHash",
|
||||||
|
}
|
||||||
|
return tx.Create(hashSeeder).Error
|
||||||
}
|
}
|
||||||
return db.Create(hashSeeder).Error
|
|
||||||
} else {
|
|
||||||
var seedersHistory []string
|
var seedersHistory []string
|
||||||
db.Model(&model.HistoryOfSeeders{}).Pluck("seeder_name", &seedersHistory)
|
if err := tx.Model(&model.HistoryOfSeeders{}).Pluck("seeder_name", &seedersHistory).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if !slices.Contains(seedersHistory, "UserPasswordHash") && !isUsersEmpty {
|
if !slices.Contains(seedersHistory, "UserPasswordHash") && !isUsersEmpty {
|
||||||
var users []model.User
|
var users []model.User
|
||||||
db.Find(&users)
|
if err := tx.Find(&users).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
hashedPassword, err := crypto.HashPasswordAsBcrypt(user.Password)
|
hashedPassword, err := crypto.HashPasswordAsBcrypt(user.Password)
|
||||||
|
|
@ -110,13 +116,15 @@ func runSeeders(isUsersEmpty bool) error {
|
||||||
log.Printf("Error hashing password for user '%s': %v", user.Username, err)
|
log.Printf("Error hashing password for user '%s': %v", user.Username, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
db.Model(&user).Update("password", hashedPassword)
|
if err := tx.Model(&user).Update("password", hashedPassword).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hashSeeder := &model.HistoryOfSeeders{
|
hashSeeder := &model.HistoryOfSeeders{
|
||||||
SeederName: "UserPasswordHash",
|
SeederName: "UserPasswordHash",
|
||||||
}
|
}
|
||||||
if err := db.Create(hashSeeder).Error; err != nil {
|
if err := tx.Create(hashSeeder).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -125,21 +133,25 @@ func runSeeders(isUsersEmpty bool) error {
|
||||||
// Drop the old unique index on client_traffics.email to allow
|
// Drop the old unique index on client_traffics.email to allow
|
||||||
// the same email across multiple inbounds
|
// the same email across multiple inbounds
|
||||||
dbType := config.GetDBTypeFromJSON()
|
dbType := config.GetDBTypeFromJSON()
|
||||||
|
var execErr error
|
||||||
if dbType == "mariadb" {
|
if dbType == "mariadb" {
|
||||||
db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email ON client_traffics")
|
execErr = tx.Exec("DROP INDEX IF EXISTS idx_client_traffics_email ON client_traffics").Error
|
||||||
} else {
|
} else {
|
||||||
db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email")
|
execErr = tx.Exec("DROP INDEX IF EXISTS idx_client_traffics_email").Error
|
||||||
|
}
|
||||||
|
if execErr != nil {
|
||||||
|
return execErr
|
||||||
}
|
}
|
||||||
uniqueSeeder := &model.HistoryOfSeeders{
|
uniqueSeeder := &model.HistoryOfSeeders{
|
||||||
SeederName: "RemoveClientTrafficEmailUnique",
|
SeederName: "RemoveClientTrafficEmailUnique",
|
||||||
}
|
}
|
||||||
if err := db.Create(uniqueSeeder).Error; err != nil {
|
if err := tx.Create(uniqueSeeder).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// isTableEmpty returns true if the named table contains zero rows.
|
// isTableEmpty returns true if the named table contains zero rows.
|
||||||
|
|
|
||||||
|
|
@ -212,3 +212,47 @@ func TestInitUser_OnlyOnce(t *testing.T) {
|
||||||
t.Errorf("expected 1 user, got %d", count)
|
t.Errorf("expected 1 user, got %d", count)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunSeeders_DoesNotRecordHistoryWhenPasswordUpdateFails(t *testing.T) {
|
||||||
|
setupTestDB(t)
|
||||||
|
|
||||||
|
if err := db.Exec("DELETE FROM history_of_seeders").Error; err != nil {
|
||||||
|
t.Fatalf("clear seeders history failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Exec(`
|
||||||
|
CREATE TRIGGER fail_user_password_update
|
||||||
|
BEFORE UPDATE OF password ON users
|
||||||
|
BEGIN
|
||||||
|
SELECT RAISE(FAIL, 'boom');
|
||||||
|
END;
|
||||||
|
`).Error; err != nil {
|
||||||
|
t.Fatalf("create trigger failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := runSeeders(false)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected runSeeders to fail when user password update fails")
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
if err := db.Model(&model.HistoryOfSeeders{}).
|
||||||
|
Where("seeder_name = ?", "UserPasswordHash").
|
||||||
|
Count(&count).Error; err != nil {
|
||||||
|
t.Fatalf("count seeder history failed: %v", err)
|
||||||
|
}
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatalf("expected no UserPasswordHash history row after failed seeder, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingKey_IsUnique(t *testing.T) {
|
||||||
|
setupTestDB(t)
|
||||||
|
|
||||||
|
if err := db.Create(&model.Setting{Key: "dup", Value: "one"}).Error; err != nil {
|
||||||
|
t.Fatalf("first insert failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Create(&model.Setting{Key: "dup", Value: "two"}).Error; err == nil {
|
||||||
|
t.Fatal("expected duplicate setting key insert to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,7 @@ func (i *Inbound) GenXrayInboundConfig() *xray.InboundConfig {
|
||||||
// Setting stores key-value configuration settings for the 3x-ui panel.
|
// Setting stores key-value configuration settings for the 3x-ui panel.
|
||||||
type Setting struct {
|
type Setting struct {
|
||||||
Id int `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
|
Id int `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
|
||||||
Key string `json:"key" form:"key"`
|
Key string `json:"key" form:"key" gorm:"uniqueIndex"`
|
||||||
Value string `json:"value" form:"value"`
|
Value string `json:"value" form:"value"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,6 @@
|
||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import "testing"
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGenXrayInboundConfig_EmptyListen(t *testing.T) {
|
func TestGenXrayInboundConfig_EmptyListen(t *testing.T) {
|
||||||
in := &Inbound{
|
in := &Inbound{
|
||||||
|
|
|
||||||
4
main.go
4
main.go
|
|
@ -411,7 +411,9 @@ func migrateDb() {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
fmt.Println("Start migrating database...")
|
fmt.Println("Start migrating database...")
|
||||||
inboundService.MigrateDB()
|
if err := inboundService.MigrateDB(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
fmt.Println("Migration done!")
|
fmt.Println("Migration done!")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2333,167 +2333,170 @@ func (s *InboundService) SearchInbounds(query string) ([]*model.Inbound, error)
|
||||||
return inbounds, nil
|
return inbounds, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) MigrationRequirements() {
|
func (s *InboundService) MigrationRequirements() error {
|
||||||
db := database.GetDB()
|
db := database.GetDB()
|
||||||
tx := db.Begin()
|
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||||
var err error
|
if err := tx.Exec(`
|
||||||
defer func() {
|
UPDATE inbounds
|
||||||
if err == nil {
|
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
|
||||||
tx.Commit()
|
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
|
||||||
if dbErr := db.Exec(`VACUUM "main"`).Error; dbErr != nil {
|
`).Error; err != nil {
|
||||||
logger.Warningf("VACUUM failed: %v", dbErr)
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
if err := tx.Exec(`
|
||||||
tx.Rollback()
|
UPDATE client_traffics
|
||||||
|
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
|
||||||
|
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
|
||||||
|
`).Error; err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
// Calculate and backfill all_time from up+down for inbounds and clients
|
var inbounds []*model.Inbound
|
||||||
err = tx.Exec(`
|
err := tx.Model(model.Inbound{}).Where("protocol IN (?)", []string{"vmess", "vless", "trojan"}).Find(&inbounds).Error
|
||||||
UPDATE inbounds
|
if err != nil && err != gorm.ErrRecordNotFound {
|
||||||
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
|
return err
|
||||||
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
|
}
|
||||||
`).Error
|
for inboundIndex := range inbounds {
|
||||||
if err != nil {
|
settings := map[string]any{}
|
||||||
return
|
if err := json.Unmarshal([]byte(inbounds[inboundIndex].Settings), &settings); err != nil {
|
||||||
}
|
return err
|
||||||
err = tx.Exec(`
|
}
|
||||||
UPDATE client_traffics
|
clients, ok := settings["clients"].([]any)
|
||||||
SET all_time = IFNULL(up, 0) + IFNULL(down, 0)
|
if ok {
|
||||||
WHERE IFNULL(all_time, 0) = 0 AND (IFNULL(up, 0) + IFNULL(down, 0)) > 0
|
var newClients []any
|
||||||
`).Error
|
for clientIndex := range clients {
|
||||||
|
c, ok := clients[clientIndex].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid client settings format for inbound %d", inbounds[inboundIndex].Id)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if _, ok := c["email"]; !ok {
|
||||||
return
|
c["email"] = ""
|
||||||
}
|
}
|
||||||
|
if _, ok := c["tgId"]; ok {
|
||||||
// Fix inbounds based problems
|
tgId := c["tgId"]
|
||||||
var inbounds []*model.Inbound
|
if tgIdStr, ok2 := tgId.(string); ok2 {
|
||||||
err = tx.Model(model.Inbound{}).Where("protocol IN (?)", []string{"vmess", "vless", "trojan"}).Find(&inbounds).Error
|
tgIdInt64, err := strconv.ParseInt(strings.ReplaceAll(tgIdStr, " ", ""), 10, 64)
|
||||||
if err != nil && err != gorm.ErrRecordNotFound {
|
if err == nil {
|
||||||
return
|
c["tgId"] = tgIdInt64
|
||||||
}
|
}
|
||||||
for inbound_index := range inbounds {
|
|
||||||
settings := map[string]any{}
|
|
||||||
json.Unmarshal([]byte(inbounds[inbound_index].Settings), &settings)
|
|
||||||
clients, ok := settings["clients"].([]any)
|
|
||||||
if ok {
|
|
||||||
// Fix Client configuration problems
|
|
||||||
var newClients []any
|
|
||||||
for client_index := range clients {
|
|
||||||
c := clients[client_index].(map[string]any)
|
|
||||||
|
|
||||||
// Add email='' if it is not exists
|
|
||||||
if _, ok := c["email"]; !ok {
|
|
||||||
c["email"] = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert string tgId to int64
|
|
||||||
if _, ok := c["tgId"]; ok {
|
|
||||||
var tgId any = c["tgId"]
|
|
||||||
if tgIdStr, ok2 := tgId.(string); ok2 {
|
|
||||||
tgIdInt64, err := strconv.ParseInt(strings.ReplaceAll(tgIdStr, " ", ""), 10, 64)
|
|
||||||
if err == nil {
|
|
||||||
c["tgId"] = tgIdInt64
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
if _, ok := c["flow"]; ok && c["flow"] == "xtls-rprx-direct" {
|
||||||
|
|
||||||
// Remove "flow": "xtls-rprx-direct"
|
|
||||||
if _, ok := c["flow"]; ok {
|
|
||||||
if c["flow"] == "xtls-rprx-direct" {
|
|
||||||
c["flow"] = ""
|
c["flow"] = ""
|
||||||
}
|
}
|
||||||
|
if _, ok := c["created_at"]; !ok {
|
||||||
|
c["created_at"] = time.Now().Unix() * 1000
|
||||||
|
}
|
||||||
|
c["updated_at"] = time.Now().Unix() * 1000
|
||||||
|
newClients = append(newClients, c)
|
||||||
}
|
}
|
||||||
// Backfill created_at and updated_at
|
settings["clients"] = newClients
|
||||||
if _, ok := c["created_at"]; !ok {
|
modifiedSettings, err := json.MarshalIndent(settings, "", " ")
|
||||||
c["created_at"] = time.Now().Unix() * 1000
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
c["updated_at"] = time.Now().Unix() * 1000
|
inbounds[inboundIndex].Settings = string(modifiedSettings)
|
||||||
newClients = append(newClients, any(c))
|
|
||||||
}
|
}
|
||||||
settings["clients"] = newClients
|
|
||||||
modifiedSettings, err := json.MarshalIndent(settings, "", " ")
|
modelClients, err := s.GetClients(inbounds[inboundIndex])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
for _, modelClient := range modelClients {
|
||||||
inbounds[inbound_index].Settings = string(modifiedSettings)
|
if len(modelClient.Email) == 0 {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
// Add client traffic row for all clients which has email
|
|
||||||
modelClients, err := s.GetClients(inbounds[inbound_index])
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, modelClient := range modelClients {
|
|
||||||
if len(modelClient.Email) > 0 {
|
|
||||||
var count int64
|
var count int64
|
||||||
tx.Model(xray.ClientTraffic{}).
|
if err := tx.Model(xray.ClientTraffic{}).
|
||||||
Where("inbound_id = ? AND email = ?", inbounds[inbound_index].Id, modelClient.Email).
|
Where("inbound_id = ? AND email = ?", inbounds[inboundIndex].Id, modelClient.Email).
|
||||||
Count(&count)
|
Count(&count).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
s.AddClientStat(tx, inbounds[inbound_index].Id, &modelClient)
|
if err := s.AddClientStat(tx, inbounds[inboundIndex].Id, &modelClient); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
if err := tx.Save(inbounds).Error; err != nil {
|
||||||
tx.Save(inbounds)
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Where("inbound_id = 0").Delete(xray.ClientTraffic{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Remove orphaned traffics
|
var externalProxy []struct {
|
||||||
tx.Where("inbound_id = 0").Delete(xray.ClientTraffic{})
|
Id int
|
||||||
|
Port int
|
||||||
|
StreamSettings []byte
|
||||||
|
}
|
||||||
|
if err := tx.Raw(`select id, port, stream_settings
|
||||||
|
from inbounds
|
||||||
|
WHERE protocol in ('vmess','vless','trojan')
|
||||||
|
AND json_extract(stream_settings, '$.security') = 'tls'
|
||||||
|
AND json_extract(stream_settings, '$.tlsSettings.settings.domains') IS NOT NULL`).Scan(&externalProxy).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Migrate old MultiDomain to External Proxy
|
for _, ep := range externalProxy {
|
||||||
var externalProxy []struct {
|
var reverses any
|
||||||
Id int
|
var stream map[string]any
|
||||||
Port int
|
if err := json.Unmarshal(ep.StreamSettings, &stream); err != nil {
|
||||||
StreamSettings []byte
|
return err
|
||||||
}
|
}
|
||||||
err = tx.Raw(`select id, port, stream_settings
|
if tlsSettings, ok := stream["tlsSettings"].(map[string]any); ok {
|
||||||
from inbounds
|
if settings, ok := tlsSettings["settings"].(map[string]any); ok {
|
||||||
WHERE protocol in ('vmess','vless','trojan')
|
if domains, ok := settings["domains"].([]any); ok {
|
||||||
AND json_extract(stream_settings, '$.security') = 'tls'
|
for _, domain := range domains {
|
||||||
AND json_extract(stream_settings, '$.tlsSettings.settings.domains') IS NOT NULL`).Scan(&externalProxy).Error
|
domainMap, ok := domain.(map[string]any)
|
||||||
if err != nil || len(externalProxy) == 0 {
|
if !ok {
|
||||||
return
|
return fmt.Errorf("invalid tls domain settings format for inbound %d", ep.Id)
|
||||||
}
|
}
|
||||||
|
domainName, ok := domainMap["domain"].(string)
|
||||||
for _, ep := range externalProxy {
|
if !ok {
|
||||||
var reverses any
|
return fmt.Errorf("invalid tls domain name for inbound %d", ep.Id)
|
||||||
var stream map[string]any
|
}
|
||||||
json.Unmarshal(ep.StreamSettings, &stream)
|
|
||||||
if tlsSettings, ok := stream["tlsSettings"].(map[string]any); ok {
|
|
||||||
if settings, ok := tlsSettings["settings"].(map[string]any); ok {
|
|
||||||
if domains, ok := settings["domains"].([]any); ok {
|
|
||||||
for _, domain := range domains {
|
|
||||||
if domainMap, ok := domain.(map[string]any); ok {
|
|
||||||
domainMap["forceTls"] = "same"
|
domainMap["forceTls"] = "same"
|
||||||
domainMap["port"] = ep.Port
|
domainMap["port"] = ep.Port
|
||||||
domainMap["dest"] = domainMap["domain"].(string)
|
domainMap["dest"] = domainName
|
||||||
delete(domainMap, "domain")
|
delete(domainMap, "domain")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
reverses = settings["domains"]
|
||||||
|
delete(settings, "domains")
|
||||||
}
|
}
|
||||||
reverses = settings["domains"]
|
}
|
||||||
delete(settings, "domains")
|
stream["externalProxy"] = reverses
|
||||||
|
newStream, err := json.MarshalIndent(stream, " ", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Model(model.Inbound{}).Where("id = ?", ep.Id).Update("stream_settings", newStream).Error; err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stream["externalProxy"] = reverses
|
|
||||||
newStream, _ := json.MarshalIndent(stream, " ", " ")
|
return tx.Raw(`UPDATE inbounds
|
||||||
tx.Model(model.Inbound{}).Where("id = ?", ep.Id).Update("stream_settings", newStream)
|
SET tag = REPLACE(tag, '0.0.0.0:', '')
|
||||||
|
WHERE INSTR(tag, '0.0.0.0:') > 0;`).Error
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Raw(`UPDATE inbounds
|
if dbErr := db.Exec(`VACUUM "main"`).Error; dbErr != nil {
|
||||||
SET tag = REPLACE(tag, '0.0.0.0:', '')
|
logger.Warningf("VACUUM failed: %v", dbErr)
|
||||||
WHERE INSTR(tag, '0.0.0.0:') > 0;`).Error
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) MigrateDB() {
|
func (s *InboundService) MigrateDB() error {
|
||||||
s.MigrationRequirements()
|
if err := s.MigrationRequirements(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
s.MigrationRemoveOrphanedTraffics()
|
s.MigrationRemoveOrphanedTraffics()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *InboundService) GetOnlineClients() []string {
|
func (s *InboundService) GetOnlineClients() []string {
|
||||||
|
|
|
||||||
|
|
@ -194,3 +194,61 @@ func TestUpdateInboundClient_DoesNotUpdateOtherInboundTraffic(t *testing.T) {
|
||||||
t.Fatalf("expected renamed email to stay isolated to inbound1, got %d rows in inbound2", got)
|
t.Fatalf("expected renamed email to stay isolated to inbound1, got %d rows in inbound2", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigrationRequirements_RollsBackOnAddClientStatFailure(t *testing.T) {
|
||||||
|
setupTestDB(t)
|
||||||
|
|
||||||
|
svc := &InboundService{}
|
||||||
|
inbound := model.Inbound{
|
||||||
|
UserId: 1,
|
||||||
|
Port: 12001,
|
||||||
|
Protocol: model.VLESS,
|
||||||
|
Tag: "rollback-test",
|
||||||
|
Up: 10,
|
||||||
|
Down: 20,
|
||||||
|
Settings: mustMarshalInboundSettings(t, model.Client{
|
||||||
|
ID: "client-rollback",
|
||||||
|
Email: "rollback@example.com",
|
||||||
|
Enable: true,
|
||||||
|
TotalGB: 100,
|
||||||
|
ExpiryTime: 200,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
if err := database.GetDB().Create(&inbound).Error; err != nil {
|
||||||
|
t.Fatalf("create inbound failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := database.GetDB().Exec(`
|
||||||
|
CREATE TRIGGER fail_client_traffic_insert
|
||||||
|
BEFORE INSERT ON client_traffics
|
||||||
|
BEGIN
|
||||||
|
SELECT RAISE(FAIL, 'boom');
|
||||||
|
END;
|
||||||
|
`).Error; err != nil {
|
||||||
|
t.Fatalf("create trigger failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.MigrationRequirements()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected migration requirements to return an error when client traffic insert fails")
|
||||||
|
}
|
||||||
|
|
||||||
|
var refreshed model.Inbound
|
||||||
|
if err := database.GetDB().First(&refreshed, inbound.Id).Error; err != nil {
|
||||||
|
t.Fatalf("reload inbound failed: %v", err)
|
||||||
|
}
|
||||||
|
if refreshed.AllTime != 0 {
|
||||||
|
t.Fatalf("expected inbound all_time rollback to keep 0, got %d", refreshed.AllTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
var traffic xray.ClientTraffic
|
||||||
|
err = database.GetDB().
|
||||||
|
Where("inbound_id = ? AND email = ?", inbound.Id, "rollback@example.com").
|
||||||
|
First(&traffic).Error
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected client traffic insert to roll back, but row exists: %+v", traffic)
|
||||||
|
}
|
||||||
|
if !database.IsNotFound(err) {
|
||||||
|
t.Fatalf("reload client traffic failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1022,7 +1022,9 @@ func (s *ServerService) ImportDB(file multipart.File) error {
|
||||||
return common.NewErrorf("Error migrating db: %v", err)
|
return common.NewErrorf("Error migrating db: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.inboundService.MigrateDB()
|
if err := s.inboundService.MigrateDB(); err != nil {
|
||||||
|
return common.NewErrorf("Error finalizing imported db: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Start Xray
|
// Start Xray
|
||||||
if err = s.RestartXrayService(); err != nil {
|
if err = s.RestartXrayService(); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"github.com/mhsanaei/3x-ui/v2/util/reflect_util"
|
"github.com/mhsanaei/3x-ui/v2/util/reflect_util"
|
||||||
"github.com/mhsanaei/3x-ui/v2/web/entity"
|
"github.com/mhsanaei/3x-ui/v2/web/entity"
|
||||||
"github.com/mhsanaei/3x-ui/v2/xray"
|
"github.com/mhsanaei/3x-ui/v2/xray"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed config.json
|
//go:embed config.json
|
||||||
|
|
@ -497,19 +498,13 @@ func getXrayTemplateConfigFromDB() (string, error) {
|
||||||
// saveXrayTemplateConfigToDB writes xrayTemplateConfig directly to the database.
|
// saveXrayTemplateConfigToDB writes xrayTemplateConfig directly to the database.
|
||||||
func saveXrayTemplateConfigToDB(value string) error {
|
func saveXrayTemplateConfigToDB(value string) error {
|
||||||
db := database.GetDB()
|
db := database.GetDB()
|
||||||
setting := &model.Setting{}
|
return db.Clauses(clause.OnConflict{
|
||||||
err := db.Model(model.Setting{}).Where("`key` = ?", "xrayTemplateConfig").First(setting).Error
|
Columns: []clause.Column{{Name: "key"}},
|
||||||
if database.IsNotFound(err) {
|
DoUpdates: clause.Assignments(map[string]any{"value": value}),
|
||||||
return db.Create(&model.Setting{
|
}).Create(&model.Setting{
|
||||||
Key: "xrayTemplateConfig",
|
Key: "xrayTemplateConfig",
|
||||||
Value: value,
|
Value: value,
|
||||||
}).Error
|
}).Error
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
setting.Value = value
|
|
||||||
return db.Save(setting).Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SettingService provides business logic for application settings management.
|
// SettingService provides business logic for application settings management.
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mhsanaei/3x-ui/v2/config"
|
"github.com/mhsanaei/3x-ui/v2/config"
|
||||||
|
"github.com/mhsanaei/3x-ui/v2/database"
|
||||||
|
"github.com/mhsanaei/3x-ui/v2/database/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupTestSettings(t *testing.T) func() {
|
func setupTestSettings(t *testing.T) func() {
|
||||||
|
|
@ -102,6 +104,31 @@ func TestSettingServiceSetAndGetString(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSaveXrayTemplateConfigToDB_UpdatesSingleRow(t *testing.T) {
|
||||||
|
setupTestSettings(t)
|
||||||
|
setupTestDB(t)
|
||||||
|
|
||||||
|
if err := saveXrayTemplateConfigToDB(`{"version":1}`); err != nil {
|
||||||
|
t.Fatalf("first save failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := saveXrayTemplateConfigToDB(`{"version":2}`); err != nil {
|
||||||
|
t.Fatalf("second save failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var settings []model.Setting
|
||||||
|
if err := database.GetDB().
|
||||||
|
Where("key = ?", "xrayTemplateConfig").
|
||||||
|
Find(&settings).Error; err != nil {
|
||||||
|
t.Fatalf("query settings failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(settings) != 1 {
|
||||||
|
t.Fatalf("expected exactly one xrayTemplateConfig row, got %d", len(settings))
|
||||||
|
}
|
||||||
|
if settings[0].Value != `{"version":2}` {
|
||||||
|
t.Fatalf("expected latest config value to be persisted, got %s", settings[0].Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestResetSettingsDeletesFile(t *testing.T) {
|
func TestResetSettingsDeletesFile(t *testing.T) {
|
||||||
setupTestSettings(t)
|
setupTestSettings(t)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue