diff --git a/config/config.go b/config/config.go index 812071b4..64844f68 100644 --- a/config/config.go +++ b/config/config.go @@ -200,36 +200,46 @@ type DBConfig struct { // GetDBConfigFromJSON reads all MariaDB connection settings from the JSON config file. func GetDBConfigFromJSON() DBConfig { - readString := func(data []byte, nestedGroup, flatKey string) string { - var settings map[string]any - if err := json.Unmarshal(data, &settings); err != nil { - return "" - } - // Nested format + data, err := os.ReadFile(GetSettingPath()) + if err != nil { + return DBConfig{Type: "sqlite", Host: "127.0.0.1", Port: "3306", Name: "3xui"} + } + + var settings map[string]any + if err := json.Unmarshal(data, &settings); err != nil { + return DBConfig{Type: "sqlite", Host: "127.0.0.1", Port: "3306", Name: "3xui"} + } + + // readString extracts a value from either nested (group.key) or flat format + readString := func(nestedGroup, flatKey string) string { if group, ok := settings[nestedGroup].(map[string]any); ok { if v, ok := group[flatKey].(string); ok { return v } } - // Flat format if v, ok := settings[flatKey].(string); ok { return v } return "" } - data, err := os.ReadFile(GetSettingPath()) - if err != nil { - return DBConfig{Type: "sqlite", Host: "127.0.0.1", Port: "3306", Name: "3xui"} + // Read dbType from the same parsed settings + dbType := "sqlite" + if other, ok := settings["other"].(map[string]any); ok { + if t, ok := other["dbType"].(string); ok && t != "" { + dbType = t + } + } else if t, ok := settings["dbType"].(string); ok && t != "" { + dbType = t } return DBConfig{ - Type: GetDBTypeFromJSON(), - Host: readString(data, "other", "dbHost"), - Port: readString(data, "other", "dbPort"), - User: readString(data, "other", "dbUser"), - Password: readString(data, "other", "dbPassword"), - Name: readString(data, "other", "dbName"), + Type: dbType, + Host: readString("other", "dbHost"), + Port: readString("other", "dbPort"), + User: readString("other", "dbUser"), + Password: readString("other", "dbPassword"), + Name: readString("other", "dbName"), } } diff --git a/database/db.go b/database/db.go index f26d0ba3..74812669 100644 --- a/database/db.go +++ b/database/db.go @@ -5,7 +5,6 @@ package database import ( "bytes" "errors" - "fmt" "io" "io/fs" "log" @@ -18,6 +17,7 @@ import ( "github.com/mhsanaei/3x-ui/v2/util/crypto" "github.com/mhsanaei/3x-ui/v2/xray" + mysql2 "github.com/go-sql-driver/mysql" "gorm.io/driver/mysql" "gorm.io/driver/sqlite" "gorm.io/gorm" @@ -124,7 +124,12 @@ func runSeeders(isUsersEmpty bool) error { if !slices.Contains(seedersHistory, "RemoveClientTrafficEmailUnique") { // Drop the old unique index on client_traffics.email to allow // the same email across multiple inbounds - db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email") + dbType := config.GetDBTypeFromJSON() + if dbType == "mariadb" { + db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email ON client_traffics") + } else { + db.Exec("DROP INDEX IF EXISTS idx_client_traffics_email") + } uniqueSeeder := &model.HistoryOfSeeders{ SeederName: "RemoveClientTrafficEmailUnique", } @@ -147,6 +152,8 @@ func isTableEmpty(tableName string) (bool, error) { // InitDB sets up the database connection, migrates models, and runs seeders. // It reads the dbType from the JSON config to determine whether to use SQLite or MariaDB. func InitDB() error { + CloseDB() // close any existing connection before re-initializing + dbType := config.GetDBTypeFromJSON() var err error @@ -178,6 +185,8 @@ func InitDB() error { // InitDBWithPath is a convenience function for tests and migrations that need // to open a specific SQLite file. func InitDBWithPath(dbPath string) error { + CloseDB() // close any existing connection before re-initializing + if err := initSQLite(dbPath); err != nil { return err } @@ -221,12 +230,28 @@ func initSQLite(dbPath string) error { return nil } +// buildMariaDBDSN constructs a MariaDB DSN from the given config using +// go-sql-driver/mysql's Config to properly escape special characters in credentials. +func buildMariaDBDSN(dbConfig config.DBConfig) string { + cfg := mysql2.Config{ + User: dbConfig.User, + Passwd: dbConfig.Password, + Net: "tcp", + Addr: dbConfig.Host + ":" + dbConfig.Port, + DBName: dbConfig.Name, + Params: map[string]string{ + "charset": "utf8mb4", + "parseTime": "True", + "loc": "Local", + }, + } + return cfg.FormatDSN() +} + // initMariaDB opens a MariaDB connection and runs model migrations. func initMariaDB() error { dbConfig := config.GetDBConfigFromJSON() - - dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", - dbConfig.User, dbConfig.Password, dbConfig.Host, dbConfig.Port, dbConfig.Name) + dsn := buildMariaDBDSN(dbConfig) var gormLogger logger.Interface if config.IsDebug() { diff --git a/database/migrate.go b/database/migrate.go index c64c37c1..e041e214 100644 --- a/database/migrate.go +++ b/database/migrate.go @@ -48,8 +48,7 @@ func openSQLite(dbPath string) (*gorm.DB, error) { // openMariaDB opens a MariaDB connection for migration. func openMariaDB() (*gorm.DB, error) { dbConfig := config.GetDBConfigFromJSON() - dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", - dbConfig.User, dbConfig.Password, dbConfig.Host, dbConfig.Port, dbConfig.Name) + dsn := buildMariaDBDSN(dbConfig) return gorm.Open(mysql.Open(dsn), &gorm.Config{Logger: logger.Discard}) } diff --git a/database/migrate_test.go b/database/migrate_test.go new file mode 100644 index 00000000..90ed7ed0 --- /dev/null +++ b/database/migrate_test.go @@ -0,0 +1,181 @@ +package database + +import ( + "path/filepath" + "testing" + + "github.com/mhsanaei/3x-ui/v2/database/model" + "github.com/mhsanaei/3x-ui/v2/xray" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// openTestSQLiteDB opens an in-memory or file-based SQLite database for testing. +func openTestSQLiteDB(t *testing.T, dbPath string) *gorm.DB { + t.Helper() + gdb, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{Logger: logger.Discard}) + if err != nil { + t.Fatalf("failed to open test SQLite DB: %v", err) + } + t.Cleanup(func() { + sqlDB, _ := gdb.DB() + if sqlDB != nil { + sqlDB.Close() + } + }) + return gdb +} + +// createTestTables runs AutoMigrate on the given DB for all models. +func createTestTables(t *testing.T, gdb *gorm.DB) { + t.Helper() + models := []any{ + &model.User{}, + &model.Inbound{}, + &model.OutboundTraffics{}, + &model.Setting{}, + &model.InboundClientIps{}, + &xray.ClientTraffic{}, + &model.HistoryOfSeeders{}, + } + for _, m := range models { + if err := gdb.AutoMigrate(m); err != nil { + t.Fatalf("AutoMigrate failed: %v", err) + } + } +} + +func TestMigrateAllTables_EmoprySource(t *testing.T) { + srcDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "src.db")) + dstDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "dst.db")) + + createTestTables(t, srcDB) + createTestTables(t, dstDB) + + err := migrateAllTables(srcDB, dstDB) + if err != nil { + t.Fatalf("migrateAllTables on empty source should succeed: %v", err) + } + + // Verify destination is still empty + for _, name := range tableNames() { + var count int64 + dstDB.Table(name).Count(&count) + if count != 0 { + t.Errorf("table %s should be empty, got %d rows", name, count) + } + } +} + +func TestMigrateAllTables_WithData(t *testing.T) { + srcDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "src.db")) + dstDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "dst.db")) + + createTestTables(t, srcDB) + createTestTables(t, dstDB) + + // Insert test data into source + srcDB.Create(&model.User{Username: "testuser", Password: "testpass", Role: "admin"}) + srcDB.Create(&model.Setting{Key: "testkey", Value: "testvalue"}) + + err := migrateAllTables(srcDB, dstDB) + if err != nil { + t.Fatalf("migrateAllTables failed: %v", err) + } + + // Verify data was copied + var userCount int64 + dstDB.Table("users").Count(&userCount) + if userCount != 1 { + t.Errorf("expected 1 user in dst, got %d", userCount) + } + + var settingCount int64 + dstDB.Table("settings").Count(&settingCount) + if settingCount != 1 { + t.Errorf("expected 1 setting in dst, got %d", settingCount) + } +} + +func TestMigrateAllTables_OverwritesExisting(t *testing.T) { + srcDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "src.db")) + dstDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "dst.db")) + + createTestTables(t, srcDB) + createTestTables(t, dstDB) + + // Insert existing data in destination that should be cleared + dstDB.Create(&model.User{Username: "olduser", Password: "oldpass", Role: "admin"}) + dstDB.Create(&model.Setting{Key: "oldkey", Value: "oldvalue"}) + + // Insert new data in source + srcDB.Create(&model.User{Username: "newuser", Password: "newpass", Role: "admin"}) + + err := migrateAllTables(srcDB, dstDB) + if err != nil { + t.Fatalf("migrateAllTables failed: %v", err) + } + + // Verify old data was replaced + var userCount int64 + dstDB.Table("users").Count(&userCount) + if userCount != 1 { + t.Errorf("expected 1 user in dst after overwrite, got %d", userCount) + } + + var user model.User + dstDB.Table("users").First(&user) + if user.Username != "newuser" { + t.Errorf("expected username 'newuser', got '%s'", user.Username) + } + + // Settings should be empty since source has no settings + var settingCount int64 + dstDB.Table("settings").Count(&settingCount) + if settingCount != 0 { + t.Errorf("expected 0 settings in dst after overwrite, got %d", settingCount) + } +} + +func TestMigrateTable_Generic(t *testing.T) { + srcDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "src.db")) + dstDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "dst.db")) + + createTestTables(t, srcDB) + createTestTables(t, dstDB) + + // Insert test users + srcDB.Create(&model.User{Username: "user1", Password: "pass1", Role: "admin"}) + srcDB.Create(&model.User{Username: "user2", Password: "pass2", Role: "admin"}) + + count, err := migrateTable[model.User](srcDB, dstDB, "users") + if err != nil { + t.Fatalf("migrateTable failed: %v", err) + } + if count != 2 { + t.Errorf("expected 2 rows migrated, got %d", count) + } + + var dstCount int64 + dstDB.Table("users").Count(&dstCount) + if dstCount != 2 { + t.Errorf("expected 2 users in dst, got %d", dstCount) + } +} + +func TestMigrateTable_EmptyTable(t *testing.T) { + srcDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "src.db")) + dstDB := openTestSQLiteDB(t, filepath.Join(t.TempDir(), "dst.db")) + + createTestTables(t, srcDB) + createTestTables(t, dstDB) + + count, err := migrateTable[model.User](srcDB, dstDB, "users") + if err != nil { + t.Fatalf("migrateTable on empty table should succeed: %v", err) + } + if count != 0 { + t.Errorf("expected 0 rows migrated, got %d", count) + } +} diff --git a/main.go b/main.go index 2e58c2a9..3ba5602b 100644 --- a/main.go +++ b/main.go @@ -401,24 +401,41 @@ func migrateDb() { fmt.Println("Migration done!") } -// migrateDbBetweenDrivers migrates data between SQLite and MariaDB based on the configured dbType. -func migrateDbBetweenDrivers() { - dbType := config.GetDBTypeFromJSON() - switch dbType { - case "mariadb": +// migrateDbBetweenDrivers migrates data between SQLite and MariaDB. +// The direction can be specified via --direction flag, otherwise it falls back to dbType from config. +func migrateDbBetweenDrivers(direction string) { + switch direction { + case "sqlite-to-mariadb": fmt.Println("Migrating data from SQLite to MariaDB...") if err := database.MigrateSQLiteToMariaDB(); err != nil { log.Fatal("Migration failed: ", err) } fmt.Println("Migration to MariaDB completed successfully.") - case "sqlite": + case "mariadb-to-sqlite": fmt.Println("Migrating data from MariaDB to SQLite...") if err := database.MigrateMariaDBToSQLite(); err != nil { log.Fatal("Migration failed: ", err) } fmt.Println("Migration to SQLite completed successfully.") default: - log.Fatalf("Unknown dbType: %s", dbType) + // Fall back to inferring from dbType config + dbType := config.GetDBTypeFromJSON() + switch dbType { + case "mariadb": + fmt.Println("Migrating data from SQLite to MariaDB...") + if err := database.MigrateSQLiteToMariaDB(); err != nil { + log.Fatal("Migration failed: ", err) + } + fmt.Println("Migration to MariaDB completed successfully.") + case "sqlite": + fmt.Println("Migrating data from MariaDB to SQLite...") + if err := database.MigrateMariaDBToSQLite(); err != nil { + log.Fatal("Migration failed: ", err) + } + fmt.Println("Migration to SQLite completed successfully.") + default: + log.Fatalf("Unknown dbType: %s", dbType) + } } } @@ -474,12 +491,18 @@ func main() { var dbUser string var dbPassword string var dbName string + var showDbType bool settingCmd.StringVar(&dbTypeFlag, "dbType", "", "Set database type (sqlite or mariadb)") settingCmd.StringVar(&dbHost, "dbHost", "", "Set MariaDB host") settingCmd.StringVar(&dbPort, "dbPort", "", "Set MariaDB port") settingCmd.StringVar(&dbUser, "dbUser", "", "Set MariaDB username") settingCmd.StringVar(&dbPassword, "dbPassword", "", "Set MariaDB password") settingCmd.StringVar(&dbName, "dbName", "", "Set MariaDB database name") + settingCmd.BoolVar(&showDbType, "showDbType", false, "Print current database type and exit") + + migrateDbCmd := flag.NewFlagSet("migrate-db", flag.ExitOnError) + var migrateDirection string + migrateDbCmd.StringVar(&migrateDirection, "direction", "", "Migration direction: sqlite-to-mariadb or mariadb-to-sqlite") // Allow dbPassword to be passed via env var to avoid leaking it in process args if p := os.Getenv("XUI_DB_PASSWORD"); p != "" { @@ -514,13 +537,22 @@ func main() { case "migrate": migrateDb() case "migrate-db": - migrateDbBetweenDrivers() + err := migrateDbCmd.Parse(os.Args[2:]) + if err != nil { + fmt.Println(err) + return + } + migrateDbBetweenDrivers(migrateDirection) case "setting": err := settingCmd.Parse(os.Args[2:]) if err != nil { fmt.Println(err) return } + if showDbType { + fmt.Println(config.GetDBTypeFromJSON()) + return + } if reset { resetSetting() } else { diff --git a/web/entity/entity.go b/web/entity/entity.go index 43c9b292..c163591a 100644 --- a/web/entity/entity.go +++ b/web/entity/entity.go @@ -109,7 +109,7 @@ type AllSetting struct { DBHost string `json:"dbHost" form:"dbHost"` DBPort string `json:"dbPort" form:"dbPort"` DBUser string `json:"dbUser" form:"dbUser"` - DBPassword string `json:"dbPassword" form:"dbPassword"` + DBPassword string `json:"-" form:"dbPassword"` DBName string `json:"dbName" form:"dbName"` // Registration settings @@ -184,5 +184,27 @@ func (s *AllSetting) CheckValid() error { return common.NewError("time location not exist:", s.TimeLocation) } + // Validate database settings + if s.DBType != "" && s.DBType != "sqlite" && s.DBType != "mariadb" { + return common.NewError("db type must be sqlite or mariadb, got:", s.DBType) + } + if s.DBType == "mariadb" { + if s.DBHost == "" { + return common.NewError("db host is required for MariaDB") + } + if s.DBPort != "" { + port := 0 + for _, c := range s.DBPort { + if c < '0' || c > '9' { + return common.NewError("db port is not a valid number:", s.DBPort) + } + port = port*10 + int(c-'0') + } + if port <= 0 || port > math.MaxUint16 { + return common.NewError("db port is not a valid port:", s.DBPort) + } + } + } + return nil } diff --git a/web/service/setting.go b/web/service/setting.go index 21776655..35cdfc24 100644 --- a/web/service/setting.go +++ b/web/service/setting.go @@ -998,6 +998,10 @@ func (s *SettingService) UpdateAllSetting(allSetting *entity.AllSetting) error { fieldV := v.FieldByName(field.Name) settings[key] = fmt.Sprint(fieldV.Interface()) } + // DBPassword uses json:"-" to avoid leaking to frontend, handle it via form tag + if allSetting.DBPassword != "" { + settings["dbPassword"] = allSetting.DBPassword + } return saveSettings(settings) } diff --git a/x-ui.sh b/x-ui.sh index 6f8f76e0..1e80b10c 100644 --- a/x-ui.sh +++ b/x-ui.sh @@ -2202,15 +2202,10 @@ show_usage() { └────────────────────────────────────────────────────────────────┘" } -# Read dbType from /etc/x-ui/x-ui.json +# Read dbType from /etc/x-ui/x-ui.json using the Go binary read_json_dbtype() { - local json_path="/etc/x-ui/x-ui.json" - if [ ! -f "$json_path" ]; then - echo "sqlite" - return - fi - # Try nested format first (other.dbType) - local db_type=$(grep -o '"dbType"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" | head -1 | sed 's/.*"dbType"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/') + local db_type + db_type=$(${xui_folder}/x-ui setting -showDbType 2>/dev/null) if [ -z "$db_type" ]; then echo "sqlite" else @@ -2224,9 +2219,15 @@ db_show_status() { echo -e "${green}当前数据库类型: ${current_type}${plain}" if [ "$current_type" = "mariadb" ]; then local json_path="/etc/x-ui/x-ui.json" - local host=$(grep -o '"dbHost"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | head -1 | sed 's/.*"dbHost"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/') - local port=$(grep -o '"dbPort"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | head -1 | sed 's/.*"dbPort"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/') - local dbname=$(grep -o '"dbName"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | head -1 | sed 's/.*"dbName"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/') + if command -v jq >/dev/null 2>&1; then + local host=$(jq -r '.other.dbHost // "127.0.0.1"' "$json_path" 2>/dev/null) + local port=$(jq -r '.other.dbPort // "3306"' "$json_path" 2>/dev/null) + local dbname=$(jq -r '.other.dbName // "3xui"' "$json_path" 2>/dev/null) + else + local host=$(grep -o '"dbHost"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | tail -1 | sed 's/.*"\([^"]*\)"$/\1/') + local port=$(grep -o '"dbPort"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | tail -1 | sed 's/.*"\([^"]*\)"$/\1/') + local dbname=$(grep -o '"dbName"[[:space:]]*:[[:space:]]*"[^"]*"' "$json_path" 2>/dev/null | tail -1 | sed 's/.*"\([^"]*\)"$/\1/') + fi echo -e "${green}MariaDB 主机: ${host:-127.0.0.1}:${port:-3306}${plain}" echo -e "${green}数据库名: ${dbname:-3xui}${plain}" fi @@ -2268,17 +2269,17 @@ db_switch_to_mariadb() { db_name=${db_name:-3xui} echo -e "${green}正在配置 MariaDB 连接...${plain}" - XUI_DB_PASSWORD="$db_pass" ${xui_folder}/x-ui setting -dbType mariadb -dbHost "$db_host" -dbPort "$db_port" -dbUser "$db_user" -dbName "$db_name" >/dev/null 2>&1 + XUI_DB_PASSWORD="$db_pass" ${xui_folder}/x-ui setting -dbHost "$db_host" -dbPort "$db_port" -dbUser "$db_user" -dbName "$db_name" >/dev/null 2>&1 echo -e "${green}正在迁移数据从 SQLite 到 MariaDB...${plain}" - ${xui_folder}/x-ui migrate-db + ${xui_folder}/x-ui migrate-db -direction sqlite-to-mariadb if [ $? -eq 0 ]; then echo -e "${green}数据库切换成功,正在重启面板...${plain}" + ${xui_folder}/x-ui setting -dbType mariadb >/dev/null 2>&1 restart else - echo -e "${red}数据迁移失败,正在回滚到 SQLite...${plain}" - ${xui_folder}/x-ui setting -dbType sqlite >/dev/null 2>&1 + echo -e "${red}数据迁移失败,保持 SQLite 不变${plain}" restart fi } @@ -2293,14 +2294,15 @@ db_switch_to_sqlite() { fi echo -e "${green}正在迁移数据从 MariaDB 到 SQLite...${plain}" - ${xui_folder}/x-ui setting -dbType sqlite >/dev/null 2>&1 - ${xui_folder}/x-ui migrate-db + ${xui_folder}/x-ui migrate-db -direction mariadb-to-sqlite if [ $? -eq 0 ]; then echo -e "${green}数据库切换成功,正在重启面板...${plain}" + ${xui_folder}/x-ui setting -dbType sqlite >/dev/null 2>&1 restart else - echo -e "${red}数据迁移失败${plain}" + echo -e "${red}数据迁移失败,保持 MariaDB 不变${plain}" + db_menu fi }