fix: address all code review issues for MariaDB support

- Prevent DBPassword from leaking to frontend (json:"-")
- Make migration direction explicit via --direction flag, set dbType only after success
- Use driver-appropriate DROP INDEX IF EXISTS for SQLite vs MariaDB
- Build DSN with mysql.Config.FormatDSN() to prevent injection with special chars
- Close DB before re-initialization in InitDB
- Add migration tests (5 tests using SQLite in-memory DBs)
- Parse JSON once in GetDBConfigFromJSON instead of 7 times
- Use Go binary for dbType in shell script instead of fragile grep
- Add rollback on failure in db_switch_to_sqlite
- Validate DB settings in CheckValid
This commit is contained in:
Sora39831 2026-04-03 09:53:20 +08:00
parent 7f015ad27b
commit c94372a22c
8 changed files with 325 additions and 50 deletions

View file

@ -200,36 +200,46 @@ type DBConfig struct {
// GetDBConfigFromJSON reads all MariaDB connection settings from the JSON config file.
func GetDBConfigFromJSON() DBConfig {
readString := func(data []byte, nestedGroup, flatKey string) string {
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
return ""
}
// Nested format
data, err := os.ReadFile(GetSettingPath())
if err != nil {
return DBConfig{Type: "sqlite", Host: "127.0.0.1", Port: "3306", Name: "3xui"}
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
return DBConfig{Type: "sqlite", Host: "127.0.0.1", Port: "3306", Name: "3xui"}
}
// readString extracts a value from either nested (group.key) or flat format
readString := func(nestedGroup, flatKey string) string {
if group, ok := settings[nestedGroup].(map[string]any); ok {
if v, ok := group[flatKey].(string); ok {
return v
}
}
// Flat format
if v, ok := settings[flatKey].(string); ok {
return v
}
return ""
}
data, err := os.ReadFile(GetSettingPath())
if err != nil {
return DBConfig{Type: "sqlite", Host: "127.0.0.1", Port: "3306", Name: "3xui"}
// Read dbType from the same parsed settings
dbType := "sqlite"
if other, ok := settings["other"].(map[string]any); ok {
if t, ok := other["dbType"].(string); ok && t != "" {
dbType = t
}
} else if t, ok := settings["dbType"].(string); ok && t != "" {
dbType = t
}
return DBConfig{
Type: GetDBTypeFromJSON(),
Host: readString(data, "other", "dbHost"),
Port: readString(data, "other", "dbPort"),
User: readString(data, "other", "dbUser"),
Password: readString(data, "other", "dbPassword"),
Name: readString(data, "other", "dbName"),
Type: dbType,
Host: readString("other", "dbHost"),
Port: readString("other", "dbPort"),
User: readString("other", "dbUser"),
Password: readString("other", "dbPassword"),
Name: readString("other", "dbName"),
}
}

View file

@ -5,7 +5,6 @@ package database
import (
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"log"
@ -18,6 +17,7 @@ import (
"github.com/mhsanaei/3x-ui/v2/util/crypto"
"github.com/mhsanaei/3x-ui/v2/xray"
mysql2 "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
@ -124,7 +124,12 @@ func runSeeders(isUsersEmpty bool) error {
if !slices.Contains(seedersHistory, "RemoveClientTrafficEmailUnique") {
// Drop the old unique index on client_traffics.email to allow
// the same email across multiple inbounds
db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email")
dbType := config.GetDBTypeFromJSON()
if dbType == "mariadb" {
db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email ON client_traffics")
} else {
db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email")
}
uniqueSeeder := &model.HistoryOfSeeders{
SeederName: "RemoveClientTrafficEmailUnique",
}
@ -147,6 +152,8 @@ func isTableEmpty(tableName string) (bool, error) {
// InitDB sets up the database connection, migrates models, and runs seeders.
// It reads the dbType from the JSON config to determine whether to use SQLite or MariaDB.
func InitDB() error {
CloseDB() // close any existing connection before re-initializing
dbType := config.GetDBTypeFromJSON()
var err error
@ -178,6 +185,8 @@ func InitDB() error {
// InitDBWithPath is a convenience function for tests and migrations that need
// to open a specific SQLite file.
func InitDBWithPath(dbPath string) error {
CloseDB() // close any existing connection before re-initializing
if err := initSQLite(dbPath); err != nil {
return err
}
@ -221,12 +230,28 @@ func initSQLite(dbPath string) error {
return nil
}
// buildMariaDBDSN constructs a MariaDB DSN from the given config using
// go-sql-driver/mysql's Config to properly escape special characters in credentials.
func buildMariaDBDSN(dbConfig config.DBConfig) string {
cfg := mysql2.Config{
User: dbConfig.User,
Passwd: dbConfig.Password,
Net: "tcp",
Addr: dbConfig.Host + ":" + dbConfig.Port,
DBName: dbConfig.Name,
Params: map[string]string{
"charset": "utf8mb4",
"parseTime": "True",
"loc": "Local",
},
}
return cfg.FormatDSN()
}
// initMariaDB opens a MariaDB connection and runs model migrations.
func initMariaDB() error {
dbConfig := config.GetDBConfigFromJSON()
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
dbConfig.User, dbConfig.Password, dbConfig.Host, dbConfig.Port, dbConfig.Name)
dsn := buildMariaDBDSN(dbConfig)
var gormLogger logger.Interface
if config.IsDebug() {

View file

@ -48,8 +48,7 @@ func openSQLite(dbPath string) (*gorm.DB, error) {
// openMariaDB opens a MariaDB connection for migration.
func openMariaDB() (*gorm.DB, error) {
dbConfig := config.GetDBConfigFromJSON()
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
dbConfig.User, dbConfig.Password, dbConfig.Host, dbConfig.Port, dbConfig.Name)
dsn := buildMariaDBDSN(dbConfig)
return gorm.Open(mysql.Open(dsn), &gorm.Config{Logger: logger.Discard})
}

181
database/migrate_test.go Normal file
View file

@ -0,0 +1,181 @@
package database
import (
"path/filepath"
"testing"
"github.com/mhsanaei/3x-ui/v2/database/model"
"github.com/mhsanaei/3x-ui/v2/xray"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// openTestSQLiteDB opens an in-memory or file-based SQLite database for testing.
func openTestSQLiteDB(t *testing.T, dbPath string) *gorm.DB {
t.Helper()
gdb, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{Logger: logger.Discard})
if err != nil {
t.Fatalf("failed to open test SQLite DB: %v", err)
}
t.Cleanup(func() {
sqlDB, _ := gdb.DB()
if sqlDB != nil {
sqlDB.Close()
}
})
return gdb
}
// createTestTables runs AutoMigrate on the given DB for all models.
func createTestTables(t *testing.T, gdb *gorm.DB) {
t.Helper()
models := []any{
&model.User{},
&model.Inbound{},
&model.OutboundTraffics{},
&model.Setting{},
&model.InboundClientIps{},
&xray.ClientTraffic{},
&model.HistoryOfSeeders{},
}
for _, m := range models {
if err := gdb.AutoMigrate(m); err != nil {
t.Fatalf("AutoMigrate failed: %v", err)
}
}
}
func TestMigrateAllTables_EmoprySource(t *testing.T) {
srcDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "src.db"))
dstDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "dst.db"))
createTestTables(t, srcDB)
createTestTables(t, dstDB)
err := migrateAllTables(srcDB, dstDB)
if err != nil {
t.Fatalf("migrateAllTables on empty source should succeed: %v", err)
}
// Verify destination is still empty
for _, name := range tableNames() {
var count int64
dstDB.Table(name).Count(&count)
if count != 0 {
t.Errorf("table %s should be empty, got %d rows", name, count)
}
}
}
func TestMigrateAllTables_WithData(t *testing.T) {
srcDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "src.db"))
dstDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "dst.db"))
createTestTables(t, srcDB)
createTestTables(t, dstDB)
// Insert test data into source
srcDB.Create(&model.User{Username: "testuser", Password: "testpass", Role: "admin"})
srcDB.Create(&model.Setting{Key: "testkey", Value: "testvalue"})
err := migrateAllTables(srcDB, dstDB)
if err != nil {
t.Fatalf("migrateAllTables failed: %v", err)
}
// Verify data was copied
var userCount int64
dstDB.Table("users").Count(&userCount)
if userCount != 1 {
t.Errorf("expected 1 user in dst, got %d", userCount)
}
var settingCount int64
dstDB.Table("settings").Count(&settingCount)
if settingCount != 1 {
t.Errorf("expected 1 setting in dst, got %d", settingCount)
}
}
func TestMigrateAllTables_OverwritesExisting(t *testing.T) {
srcDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "src.db"))
dstDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "dst.db"))
createTestTables(t, srcDB)
createTestTables(t, dstDB)
// Insert existing data in destination that should be cleared
dstDB.Create(&model.User{Username: "olduser", Password: "oldpass", Role: "admin"})
dstDB.Create(&model.Setting{Key: "oldkey", Value: "oldvalue"})
// Insert new data in source
srcDB.Create(&model.User{Username: "newuser", Password: "newpass", Role: "admin"})
err := migrateAllTables(srcDB, dstDB)
if err != nil {
t.Fatalf("migrateAllTables failed: %v", err)
}
// Verify old data was replaced
var userCount int64
dstDB.Table("users").Count(&userCount)
if userCount != 1 {
t.Errorf("expected 1 user in dst after overwrite, got %d", userCount)
}
var user model.User
dstDB.Table("users").First(&user)
if user.Username != "newuser" {
t.Errorf("expected username 'newuser', got '%s'", user.Username)
}
// Settings should be empty since source has no settings
var settingCount int64
dstDB.Table("settings").Count(&settingCount)
if settingCount != 0 {
t.Errorf("expected 0 settings in dst after overwrite, got %d", settingCount)
}
}
func TestMigrateTable_Generic(t *testing.T) {
srcDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "src.db"))
dstDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "dst.db"))
createTestTables(t, srcDB)
createTestTables(t, dstDB)
// Insert test users
srcDB.Create(&model.User{Username: "user1", Password: "pass1", Role: "admin"})
srcDB.Create(&model.User{Username: "user2", Password: "pass2", Role: "admin"})
count, err := migrateTable[model.User](srcDB, dstDB, "users")
if err != nil {
t.Fatalf("migrateTable failed: %v", err)
}
if count != 2 {
t.Errorf("expected 2 rows migrated, got %d", count)
}
var dstCount int64
dstDB.Table("users").Count(&dstCount)
if dstCount != 2 {
t.Errorf("expected 2 users in dst, got %d", dstCount)
}
}
func TestMigrateTable_EmptyTable(t *testing.T) {
srcDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "src.db"))
dstDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "dst.db"))
createTestTables(t, srcDB)
createTestTables(t, dstDB)
count, err := migrateTable[model.User](srcDB, dstDB, "users")
if err != nil {
t.Fatalf("migrateTable on empty table should succeed: %v", err)
}
if count != 0 {
t.Errorf("expected 0 rows migrated, got %d", count)
}
}

48
main.go
View file

@ -401,24 +401,41 @@ func migrateDb() {
fmt.Println("Migration done!")
}
// migrateDbBetweenDrivers migrates data between SQLite and MariaDB based on the configured dbType.
func migrateDbBetweenDrivers() {
dbType := config.GetDBTypeFromJSON()
switch dbType {
case "mariadb":
// migrateDbBetweenDrivers migrates data between SQLite and MariaDB.
// The direction can be specified via --direction flag, otherwise it falls back to dbType from config.
func migrateDbBetweenDrivers(direction string) {
switch direction {
case "sqlite-to-mariadb":
fmt.Println("Migrating data from SQLite to MariaDB...")
if err := database.MigrateSQLiteToMariaDB(); err != nil {
log.Fatal("Migration failed: ", err)
}
fmt.Println("Migration to MariaDB completed successfully.")
case "sqlite":
case "mariadb-to-sqlite":
fmt.Println("Migrating data from MariaDB to SQLite...")
if err := database.MigrateMariaDBToSQLite(); err != nil {
log.Fatal("Migration failed: ", err)
}
fmt.Println("Migration to SQLite completed successfully.")
default:
log.Fatalf("Unknown dbType: %s", dbType)
// Fall back to inferring from dbType config
dbType := config.GetDBTypeFromJSON()
switch dbType {
case "mariadb":
fmt.Println("Migrating data from SQLite to MariaDB...")
if err := database.MigrateSQLiteToMariaDB(); err != nil {
log.Fatal("Migration failed: ", err)
}
fmt.Println("Migration to MariaDB completed successfully.")
case "sqlite":
fmt.Println("Migrating data from MariaDB to SQLite...")
if err := database.MigrateMariaDBToSQLite(); err != nil {
log.Fatal("Migration failed: ", err)
}
fmt.Println("Migration to SQLite completed successfully.")
default:
log.Fatalf("Unknown dbType: %s", dbType)
}
}
}
@ -474,12 +491,18 @@ func main() {
var dbUser string
var dbPassword string
var dbName string
var showDbType bool
settingCmd.StringVar(&dbTypeFlag, "dbType", "", "Set database type (sqlite or mariadb)")
settingCmd.StringVar(&dbHost, "dbHost", "", "Set MariaDB host")
settingCmd.StringVar(&dbPort, "dbPort", "", "Set MariaDB port")
settingCmd.StringVar(&dbUser, "dbUser", "", "Set MariaDB username")
settingCmd.StringVar(&dbPassword, "dbPassword", "", "Set MariaDB password")
settingCmd.StringVar(&dbName, "dbName", "", "Set MariaDB database name")
settingCmd.BoolVar(&showDbType, "showDbType", false, "Print current database type and exit")
migrateDbCmd := flag.NewFlagSet("migrate-db", flag.ExitOnError)
var migrateDirection string
migrateDbCmd.StringVar(&migrateDirection, "direction", "", "Migration direction: sqlite-to-mariadb or mariadb-to-sqlite")
// Allow dbPassword to be passed via env var to avoid leaking it in process args
if p := os.Getenv("XUI_DB_PASSWORD"); p != "" {
@ -514,13 +537,22 @@ func main() {
case "migrate":
migrateDb()
case "migrate-db":
migrateDbBetweenDrivers()
err := migrateDbCmd.Parse(os.Args[2:])
if err != nil {
fmt.Println(err)
return
}
migrateDbBetweenDrivers(migrateDirection)
case "setting":
err := settingCmd.Parse(os.Args[2:])
if err != nil {
fmt.Println(err)
return
}
if showDbType {
fmt.Println(config.GetDBTypeFromJSON())
return
}
if reset {
resetSetting()
} else {

View file

@ -109,7 +109,7 @@ type AllSetting struct {
DBHost string `json:"dbHost" form:"dbHost"`
DBPort string `json:"dbPort" form:"dbPort"`
DBUser string `json:"dbUser" form:"dbUser"`
DBPassword string `json:"dbPassword" form:"dbPassword"`
DBPassword string `json:"-" form:"dbPassword"`
DBName string `json:"dbName" form:"dbName"`
// Registration settings
@ -184,5 +184,27 @@ func (s *AllSetting) CheckValid() error {
return common.NewError("time location not exist:", s.TimeLocation)
}
// Validate database settings
if s.DBType != "" && s.DBType != "sqlite" && s.DBType != "mariadb" {
return common.NewError("db type must be sqlite or mariadb, got:", s.DBType)
}
if s.DBType == "mariadb" {
if s.DBHost == "" {
return common.NewError("db host is required for MariaDB")
}
if s.DBPort != "" {
port := 0
for _, c := range s.DBPort {
if c < '0' || c > '9' {
return common.NewError("db port is not a valid number:", s.DBPort)
}
port = port*10 + int(c-'0')
}
if port <= 0 || port > math.MaxUint16 {
return common.NewError("db port is not a valid port:", s.DBPort)
}
}
}
return nil
}

View file

@ -998,6 +998,10 @@ func (s *SettingService) UpdateAllSetting(allSetting *entity.AllSetting) error {
fieldV := v.FieldByName(field.Name)
settings[key] = fmt.Sprint(fieldV.Interface())
}
// DBPassword uses json:"-" to avoid leaking to frontend, handle it via form tag
if allSetting.DBPassword != "" {
settings["dbPassword"] = allSetting.DBPassword
}
return saveSettings(settings)
}

38
x-ui.sh
View file

@ -2202,15 +2202,10 @@ show_usage() {
└────────────────────────────────────────────────────────────────┘"
}
# Read dbType from /etc/x-ui/x-ui.json
# Read dbType from /etc/x-ui/x-ui.json using the Go binary
read_json_dbtype() {
local json_path="/etc/x-ui/x-ui.json"
if [ ! -f "$json_path" ]; then
echo "sqlite"
return
fi
# Try nested format first (other.dbType)
local db_type=$(grep -o '"dbType"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" | head -1 | sed 's/.*"dbType"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/')
local db_type
db_type=$(${xui_folder}/x-ui setting -showDbType 2>/dev/null)
if [ -z "$db_type" ]; then
echo "sqlite"
else
@ -2224,9 +2219,15 @@ db_show_status() {
echo -e "${green}当前数据库类型: ${current_type}${plain}"
if [ "$current_type" = "mariadb" ]; then
local json_path="/etc/x-ui/x-ui.json"
local host=$(grep -o '"dbHost"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | head -1 | sed 's/.*"dbHost"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/')
local port=$(grep -o '"dbPort"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | head -1 | sed 's/.*"dbPort"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/')
local dbname=$(grep -o '"dbName"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | head -1 | sed 's/.*"dbName"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/')
if command -v jq >/dev/null 2>&1; then
local host=$(jq -r '.other.dbHost // "127.0.0.1"' "$json_path" 2>/dev/null)
local port=$(jq -r '.other.dbPort // "3306"' "$json_path" 2>/dev/null)
local dbname=$(jq -r '.other.dbName // "3xui"' "$json_path" 2>/dev/null)
else
local host=$(grep -o '"dbHost"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | tail -1 | sed 's/.*"\([^"]*\)"$/\1/')
local port=$(grep -o '"dbPort"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | tail -1 | sed 's/.*"\([^"]*\)"$/\1/')
local dbname=$(grep -o '"dbName"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | tail -1 | sed 's/.*"\([^"]*\)"$/\1/')
fi
echo -e "${green}MariaDB 主机: ${host:-127.0.0.1}:${port:-3306}${plain}"
echo -e "${green}数据库名: ${dbname:-3xui}${plain}"
fi
@ -2268,17 +2269,17 @@ db_switch_to_mariadb() {
db_name=${db_name:-3xui}
echo -e "${green}正在配置 MariaDB 连接...${plain}"
XUI_DB_PASSWORD="$db_pass" ${xui_folder}/x-ui setting -dbType mariadb -dbHost "$db_host" -dbPort "$db_port" -dbUser "$db_user" -dbName "$db_name" >/dev/null 2>&1
XUI_DB_PASSWORD="$db_pass" ${xui_folder}/x-ui setting -dbHost "$db_host" -dbPort "$db_port" -dbUser "$db_user" -dbName "$db_name" >/dev/null 2>&1
echo -e "${green}正在迁移数据从 SQLite 到 MariaDB...${plain}"
${xui_folder}/x-ui migrate-db
${xui_folder}/x-ui migrate-db -direction sqlite-to-mariadb
if [ $? -eq 0 ]; then
echo -e "${green}数据库切换成功,正在重启面板...${plain}"
${xui_folder}/x-ui setting -dbType mariadb >/dev/null 2>&1
restart
else
echo -e "${red}数据迁移失败,正在回滚到 SQLite...${plain}"
${xui_folder}/x-ui setting -dbType sqlite >/dev/null 2>&1
echo -e "${red}数据迁移失败,保持 SQLite 不变${plain}"
restart
fi
}
@ -2293,14 +2294,15 @@ db_switch_to_sqlite() {
fi
echo -e "${green}正在迁移数据从 MariaDB 到 SQLite...${plain}"
${xui_folder}/x-ui setting -dbType sqlite >/dev/null 2>&1
${xui_folder}/x-ui migrate-db
${xui_folder}/x-ui migrate-db -direction mariadb-to-sqlite
if [ $? -eq 0 ]; then
echo -e "${green}数据库切换成功,正在重启面板...${plain}"
${xui_folder}/x-ui setting -dbType sqlite >/dev/null 2>&1
restart
else
echo -e "${red}数据迁移失败${plain}"
echo -e "${red}数据迁移失败,保持 MariaDB 不变${plain}"
db_menu
fi
}