mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-06 21:24:10 +00:00
fix: address code review issues for MariaDB support
- Close migration DB connections with defer to prevent leaks - Truncate destination tables before migration to avoid duplicates - Wrap migration in transaction for atomicity - Pass DB password via env var instead of CLI args to avoid process list exposure - Improve error messages for MariaDB export/import with alternatives - Update package doc to reflect dual DB support - DRY migration logic with shared migrateAllTables function
This commit is contained in:
parent
cf720b6c5b
commit
7f015ad27b
5 changed files with 72 additions and 70 deletions
|
|
@ -1,5 +1,5 @@
|
|||
// Package database provides database initialization, migration, and management utilities
|
||||
// for the 3x-ui panel using GORM with SQLite.
|
||||
// for the 3x-ui panel using GORM with SQLite or MariaDB.
|
||||
package database
|
||||
|
||||
import (
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ func tableNames() []string {
|
|||
}
|
||||
}
|
||||
|
||||
// openSQLite opens a read-only SQLite connection for migration.
|
||||
// openSQLite opens a SQLite connection for migration.
|
||||
func openSQLite(dbPath string) (*gorm.DB, error) {
|
||||
return gorm.Open(sqlite.Open(dbPath), &gorm.Config{Logger: logger.Discard})
|
||||
}
|
||||
|
|
@ -53,6 +53,18 @@ func openMariaDB() (*gorm.DB, error) {
|
|||
return gorm.Open(mysql.Open(dsn), &gorm.Config{Logger: logger.Discard})
|
||||
}
|
||||
|
||||
// closeDB safely closes the underlying SQL connection.
|
||||
func closeDB(gdb *gorm.DB) {
|
||||
if gdb == nil {
|
||||
return
|
||||
}
|
||||
sqlDB, err := gdb.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
// migrateTable copies all rows from src table to dst table using the given model slice.
|
||||
// It returns the number of rows migrated.
|
||||
func migrateTable[T any](src, dst *gorm.DB, tableName string) (int64, error) {
|
||||
|
|
@ -69,18 +81,63 @@ func migrateTable[T any](src, dst *gorm.DB, tableName string) (int64, error) {
|
|||
return int64(len(rows)), nil
|
||||
}
|
||||
|
||||
// migrateAllTables copies all data between two database connections within a transaction.
|
||||
func migrateAllTables(src, dst *gorm.DB) error {
|
||||
// Truncate destination tables and migrate within a transaction
|
||||
return dst.Transaction(func(tx *gorm.DB) error {
|
||||
// Clear destination tables in reverse dependency order
|
||||
for i := len(tableNames()) - 1; i >= 0; i-- {
|
||||
name := tableNames()[i]
|
||||
if err := tx.Exec("DELETE FROM " + name).Error; err != nil {
|
||||
return fmt.Errorf("failed to clear %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
total := int64(0)
|
||||
for _, name := range tableNames() {
|
||||
var count int64
|
||||
var err error
|
||||
switch name {
|
||||
case "users":
|
||||
count, err = migrateTable[model.User](src, tx, name)
|
||||
case "inbounds":
|
||||
count, err = migrateTable[model.Inbound](src, tx, name)
|
||||
case "outbound_traffics":
|
||||
count, err = migrateTable[model.OutboundTraffics](src, tx, name)
|
||||
case "settings":
|
||||
count, err = migrateTable[model.Setting](src, tx, name)
|
||||
case "inbound_client_ips":
|
||||
count, err = migrateTable[model.InboundClientIps](src, tx, name)
|
||||
case "client_traffics":
|
||||
count, err = migrateTable[xray.ClientTraffic](src, tx, name)
|
||||
case "history_of_seeders":
|
||||
count, err = migrateTable[model.HistoryOfSeeders](src, tx, name)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("migration failed for %s: %w", name, err)
|
||||
}
|
||||
total += count
|
||||
log.Printf("Migrated %d rows from %s", count, name)
|
||||
}
|
||||
log.Printf("Migration complete: %d total rows", total)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// MigrateSQLiteToMariaDB copies all data from the SQLite database to MariaDB.
|
||||
// The SQLite file is kept as a backup. The MariaDB tables are created via AutoMigrate first.
|
||||
// The SQLite file is kept as a backup.
|
||||
func MigrateSQLiteToMariaDB() error {
|
||||
srcDB, err := openSQLite(config.GetDBPath())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open SQLite source: %w", err)
|
||||
}
|
||||
defer closeDB(srcDB)
|
||||
|
||||
dstDB, err := openMariaDB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open MariaDB destination: %w", err)
|
||||
}
|
||||
defer closeDB(dstDB)
|
||||
|
||||
// AutoMigrate all tables on destination
|
||||
for _, m := range allModels() {
|
||||
|
|
@ -89,38 +146,7 @@ func MigrateSQLiteToMariaDB() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Clear the seeder history so seeders run fresh on the new DB
|
||||
dstDB.Exec("DELETE FROM history_of_seeders")
|
||||
|
||||
// Migrate each table
|
||||
total := int64(0)
|
||||
for _, name := range tableNames() {
|
||||
var count int64
|
||||
switch name {
|
||||
case "users":
|
||||
count, err = migrateTable[model.User](srcDB, dstDB, name)
|
||||
case "inbounds":
|
||||
count, err = migrateTable[model.Inbound](srcDB, dstDB, name)
|
||||
case "outbound_traffics":
|
||||
count, err = migrateTable[model.OutboundTraffics](srcDB, dstDB, name)
|
||||
case "settings":
|
||||
count, err = migrateTable[model.Setting](srcDB, dstDB, name)
|
||||
case "inbound_client_ips":
|
||||
count, err = migrateTable[model.InboundClientIps](srcDB, dstDB, name)
|
||||
case "client_traffics":
|
||||
count, err = migrateTable[xray.ClientTraffic](srcDB, dstDB, name)
|
||||
case "history_of_seeders":
|
||||
count, err = migrateTable[model.HistoryOfSeeders](srcDB, dstDB, name)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("migration failed for %s: %w", name, err)
|
||||
}
|
||||
total += count
|
||||
log.Printf("Migrated %d rows from %s", count, name)
|
||||
}
|
||||
|
||||
log.Printf("SQLite to MariaDB migration complete: %d total rows", total)
|
||||
return nil
|
||||
return migrateAllTables(srcDB, dstDB)
|
||||
}
|
||||
|
||||
// MigrateMariaDBToSQLite copies all data from MariaDB to the SQLite database.
|
||||
|
|
@ -130,11 +156,13 @@ func MigrateMariaDBToSQLite() error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to open MariaDB source: %w", err)
|
||||
}
|
||||
defer closeDB(srcDB)
|
||||
|
||||
dstDB, err := openSQLite(config.GetDBPath())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open SQLite destination: %w", err)
|
||||
}
|
||||
defer closeDB(dstDB)
|
||||
|
||||
// AutoMigrate all tables on destination
|
||||
for _, m := range allModels() {
|
||||
|
|
@ -143,36 +171,5 @@ func MigrateMariaDBToSQLite() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Clear the seeder history so seeders run fresh on the new DB
|
||||
dstDB.Exec("DELETE FROM history_of_seeders")
|
||||
|
||||
// Migrate each table
|
||||
total := int64(0)
|
||||
for _, name := range tableNames() {
|
||||
var count int64
|
||||
switch name {
|
||||
case "users":
|
||||
count, err = migrateTable[model.User](srcDB, dstDB, name)
|
||||
case "inbounds":
|
||||
count, err = migrateTable[model.Inbound](srcDB, dstDB, name)
|
||||
case "outbound_traffics":
|
||||
count, err = migrateTable[model.OutboundTraffics](srcDB, dstDB, name)
|
||||
case "settings":
|
||||
count, err = migrateTable[model.Setting](srcDB, dstDB, name)
|
||||
case "inbound_client_ips":
|
||||
count, err = migrateTable[model.InboundClientIps](srcDB, dstDB, name)
|
||||
case "client_traffics":
|
||||
count, err = migrateTable[xray.ClientTraffic](srcDB, dstDB, name)
|
||||
case "history_of_seeders":
|
||||
count, err = migrateTable[model.HistoryOfSeeders](srcDB, dstDB, name)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("migration failed for %s: %w", name, err)
|
||||
}
|
||||
total += count
|
||||
log.Printf("Migrated %d rows from %s", count, name)
|
||||
}
|
||||
|
||||
log.Printf("MariaDB to SQLite migration complete: %d total rows", total)
|
||||
return nil
|
||||
return migrateAllTables(srcDB, dstDB)
|
||||
}
|
||||
|
|
|
|||
5
main.go
5
main.go
|
|
@ -481,6 +481,11 @@ func main() {
|
|||
settingCmd.StringVar(&dbPassword, "dbPassword", "", "Set MariaDB password")
|
||||
settingCmd.StringVar(&dbName, "dbName", "", "Set MariaDB database name")
|
||||
|
||||
// Allow dbPassword to be passed via env var to avoid leaking it in process args
|
||||
if p := os.Getenv("XUI_DB_PASSWORD"); p != "" {
|
||||
dbPassword = p
|
||||
}
|
||||
|
||||
oldUsage := flag.Usage
|
||||
flag.Usage = func() {
|
||||
oldUsage()
|
||||
|
|
|
|||
|
|
@ -883,7 +883,7 @@ func (s *ServerService) GetConfigJson() (any, error) {
|
|||
|
||||
func (s *ServerService) GetDb() ([]byte, error) {
|
||||
if config.GetDBTypeFromJSON() == "mariadb" {
|
||||
return nil, common.NewError("Database export is not supported for MariaDB")
|
||||
return nil, common.NewError("database export is not supported when using MariaDB — use mysqldump instead")
|
||||
}
|
||||
// Update by manually trigger a checkpoint operation
|
||||
err := database.Checkpoint()
|
||||
|
|
@ -908,7 +908,7 @@ func (s *ServerService) GetDb() ([]byte, error) {
|
|||
|
||||
func (s *ServerService) ImportDB(file multipart.File) error {
|
||||
if config.GetDBTypeFromJSON() == "mariadb" {
|
||||
return common.NewError("Database import is not supported for MariaDB")
|
||||
return common.NewError("database import is not supported when using MariaDB — use mysql restore instead")
|
||||
}
|
||||
// Check if the file is a SQLite database
|
||||
isValidDb, err := database.IsSQLiteDB(file)
|
||||
|
|
|
|||
2
x-ui.sh
2
x-ui.sh
|
|
@ -2268,7 +2268,7 @@ db_switch_to_mariadb() {
|
|||
db_name=${db_name:-3xui}
|
||||
|
||||
echo -e "${green}正在配置 MariaDB 连接...${plain}"
|
||||
${xui_folder}/x-ui setting -dbType mariadb -dbHost "$db_host" -dbPort "$db_port" -dbUser "$db_user" -dbPassword "$db_pass" -dbName "$db_name" >/dev/null 2>&1
|
||||
XUI_DB_PASSWORD="$db_pass" ${xui_folder}/x-ui setting -dbType mariadb -dbHost "$db_host" -dbPort "$db_port" -dbUser "$db_user" -dbName "$db_name" >/dev/null 2>&1
|
||||
|
||||
echo -e "${green}正在迁移数据从 SQLite 到 MariaDB...${plain}"
|
||||
${xui_folder}/x-ui migrate-db
|
||||
|
|
|
|||
Loading…
Reference in a new issue