feat: refactor InitDB to be driver-agnostic with SQLite/MariaDB dispatch

This commit is contained in:
Sora39831 2026-04-03 09:26:42 +08:00
parent e3d1a94d89
commit 283892c548
2 changed files with 83 additions and 18 deletions

View file

@ -5,6 +5,7 @@ package database
import (
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"log"
@ -17,6 +18,7 @@ import (
"github.com/mhsanaei/3x-ui/v2/util/crypto"
"github.com/mhsanaei/3x-ui/v2/xray"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@ -135,7 +137,57 @@ func isTableEmpty(tableName string) (bool, error) {
}
// InitDB sets up the database connection, migrates models, and runs seeders.
func InitDB(dbPath string) error {
// It reads the dbType from the JSON config to determine whether to use SQLite or MariaDB.
func InitDB() error {
dbType := config.GetDBTypeFromJSON()
var err error
switch dbType {
case "mariadb":
err = initMariaDB()
default:
err = initSQLite(config.GetDBPath())
}
if err != nil {
return err
}
if err := initModels(); err != nil {
return err
}
if err := initUser(); err != nil {
return err
}
isUsersEmpty, err := isTableEmpty("users")
if err != nil {
return err
}
return runSeeders(isUsersEmpty)
}
// InitDBWithPath is a convenience function for tests and migrations that need
// to open a specific SQLite file.
func InitDBWithPath(dbPath string) error {
if err := initSQLite(dbPath); err != nil {
return err
}
if err := initModels(); err != nil {
return err
}
if err := initUser(); err != nil {
return err
}
isUsersEmpty, err := isTableEmpty("users")
if err != nil {
return err
}
return runSeeders(isUsersEmpty)
}
// initSQLite opens a SQLite database connection and runs model migrations.
func initSQLite(dbPath string) error {
dir := path.Dir(dbPath)
err := os.MkdirAll(dir, fs.ModePerm)
if err != nil {
@ -158,19 +210,33 @@ func InitDB(dbPath string) error {
return err
}
if err := initModels(); err != nil {
return err
return nil
}
// initMariaDB opens a MariaDB connection and runs model migrations.
func initMariaDB() error {
dbConfig := config.GetDBConfigFromJSON()
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
dbConfig.User, dbConfig.Password, dbConfig.Host, dbConfig.Port, dbConfig.Name)
var gormLogger logger.Interface
if config.IsDebug() {
gormLogger = logger.Default
} else {
gormLogger = logger.Discard
}
isUsersEmpty, err := isTableEmpty("users")
var err error
c := &gorm.Config{
Logger: gormLogger,
}
db, err = gorm.Open(mysql.Open(dsn), c)
if err != nil {
return err
}
if err := initUser(); err != nil {
return err
}
return runSeeders(isUsersEmpty)
return nil
}
// CloseDB closes the database connection if it exists.
@ -207,13 +273,12 @@ func IsSQLiteDB(file io.ReaderAt) (bool, error) {
}
// Checkpoint performs a WAL checkpoint on the SQLite database to ensure data consistency.
// For MariaDB, this is a no-op.
func Checkpoint() error {
// Update WAL
err := db.Exec("PRAGMA wal_checkpoint;").Error
if err != nil {
return err
if config.GetDBTypeFromJSON() != "sqlite" {
return nil
}
return nil
return db.Exec("PRAGMA wal_checkpoint;").Error
}
// ValidateSQLiteDB opens the provided sqlite DB path with a throw-away connection

View file

@ -14,7 +14,7 @@ func setupTestDB(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XUI_DEBUG", "")
dbPath := filepath.Join(tmpDir, "test.db")
if err := InitDB(dbPath); err != nil {
if err := InitDBWithPath(dbPath); err != nil {
t.Fatalf("InitDB failed: %v", err)
}
t.Cleanup(func() {
@ -25,7 +25,7 @@ func setupTestDB(t *testing.T) {
func TestIsSQLiteDB_ValidFile(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "valid.db")
if err := InitDB(dbPath); err != nil {
if err := InitDBWithPath(dbPath); err != nil {
t.Fatalf("InitDB failed: %v", err)
}
defer CloseDB()
@ -137,13 +137,13 @@ func TestInitDB_Idempotent(t *testing.T) {
dbPath := filepath.Join(tmpDir, "idempotent.db")
// First init
if err := InitDB(dbPath); err != nil {
if err := InitDBWithPath(dbPath); err != nil {
t.Fatalf("first InitDB failed: %v", err)
}
CloseDB()
// Second init on the same file should not fail
if err := InitDB(dbPath); err != nil {
if err := InitDBWithPath(dbPath); err != nil {
t.Fatalf("second InitDB failed: %v", err)
}
defer CloseDB()
@ -159,7 +159,7 @@ func TestInitDB_Idempotent(t *testing.T) {
func TestValidateSQLiteDB_ValidDB(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "valid.db")
if err := InitDB(dbPath); err != nil {
if err := InitDBWithPath(dbPath); err != nil {
t.Fatalf("InitDB failed: %v", err)
}
CloseDB()