mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-06 21:24:10 +00:00
feat: refactor InitDB to be driver-agnostic with SQLite/MariaDB dispatch
This commit is contained in:
parent
e3d1a94d89
commit
283892c548
2 changed files with 83 additions and 18 deletions
|
|
@ -5,6 +5,7 @@ package database
|
|||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
|
|
@ -17,6 +18,7 @@ import (
|
|||
"github.com/mhsanaei/3x-ui/v2/util/crypto"
|
||||
"github.com/mhsanaei/3x-ui/v2/xray"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
|
@ -135,7 +137,57 @@ func isTableEmpty(tableName string) (bool, error) {
|
|||
}
|
||||
|
||||
// InitDB sets up the database connection, migrates models, and runs seeders.
|
||||
func InitDB(dbPath string) error {
|
||||
// It reads the dbType from the JSON config to determine whether to use SQLite or MariaDB.
|
||||
func InitDB() error {
|
||||
dbType := config.GetDBTypeFromJSON()
|
||||
|
||||
var err error
|
||||
switch dbType {
|
||||
case "mariadb":
|
||||
err = initMariaDB()
|
||||
default:
|
||||
err = initSQLite(config.GetDBPath())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := initModels(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := initUser(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isUsersEmpty, err := isTableEmpty("users")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return runSeeders(isUsersEmpty)
|
||||
}
|
||||
|
||||
// InitDBWithPath is a convenience function for tests and migrations that need
|
||||
// to open a specific SQLite file.
|
||||
func InitDBWithPath(dbPath string) error {
|
||||
if err := initSQLite(dbPath); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := initModels(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := initUser(); err != nil {
|
||||
return err
|
||||
}
|
||||
isUsersEmpty, err := isTableEmpty("users")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return runSeeders(isUsersEmpty)
|
||||
}
|
||||
|
||||
// initSQLite opens a SQLite database connection and runs model migrations.
|
||||
func initSQLite(dbPath string) error {
|
||||
dir := path.Dir(dbPath)
|
||||
err := os.MkdirAll(dir, fs.ModePerm)
|
||||
if err != nil {
|
||||
|
|
@ -158,19 +210,33 @@ func InitDB(dbPath string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if err := initModels(); err != nil {
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// initMariaDB opens a MariaDB connection and runs model migrations.
|
||||
func initMariaDB() error {
|
||||
dbConfig := config.GetDBConfigFromJSON()
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
dbConfig.User, dbConfig.Password, dbConfig.Host, dbConfig.Port, dbConfig.Name)
|
||||
|
||||
var gormLogger logger.Interface
|
||||
if config.IsDebug() {
|
||||
gormLogger = logger.Default
|
||||
} else {
|
||||
gormLogger = logger.Discard
|
||||
}
|
||||
|
||||
isUsersEmpty, err := isTableEmpty("users")
|
||||
var err error
|
||||
c := &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
}
|
||||
db, err = gorm.Open(mysql.Open(dsn), c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := initUser(); err != nil {
|
||||
return err
|
||||
}
|
||||
return runSeeders(isUsersEmpty)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseDB closes the database connection if it exists.
|
||||
|
|
@ -207,13 +273,12 @@ func IsSQLiteDB(file io.ReaderAt) (bool, error) {
|
|||
}
|
||||
|
||||
// Checkpoint performs a WAL checkpoint on the SQLite database to ensure data consistency.
|
||||
// For MariaDB, this is a no-op.
|
||||
func Checkpoint() error {
|
||||
// Update WAL
|
||||
err := db.Exec("PRAGMA wal_checkpoint;").Error
|
||||
if err != nil {
|
||||
return err
|
||||
if config.GetDBTypeFromJSON() != "sqlite" {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return db.Exec("PRAGMA wal_checkpoint;").Error
|
||||
}
|
||||
|
||||
// ValidateSQLiteDB opens the provided sqlite DB path with a throw-away connection
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ func setupTestDB(t *testing.T) {
|
|||
tmpDir := t.TempDir()
|
||||
t.Setenv("XUI_DEBUG", "")
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
if err := InitDB(dbPath); err != nil {
|
||||
if err := InitDBWithPath(dbPath); err != nil {
|
||||
t.Fatalf("InitDB failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
|
|
@ -25,7 +25,7 @@ func setupTestDB(t *testing.T) {
|
|||
func TestIsSQLiteDB_ValidFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "valid.db")
|
||||
if err := InitDB(dbPath); err != nil {
|
||||
if err := InitDBWithPath(dbPath); err != nil {
|
||||
t.Fatalf("InitDB failed: %v", err)
|
||||
}
|
||||
defer CloseDB()
|
||||
|
|
@ -137,13 +137,13 @@ func TestInitDB_Idempotent(t *testing.T) {
|
|||
dbPath := filepath.Join(tmpDir, "idempotent.db")
|
||||
|
||||
// First init
|
||||
if err := InitDB(dbPath); err != nil {
|
||||
if err := InitDBWithPath(dbPath); err != nil {
|
||||
t.Fatalf("first InitDB failed: %v", err)
|
||||
}
|
||||
CloseDB()
|
||||
|
||||
// Second init on the same file should not fail
|
||||
if err := InitDB(dbPath); err != nil {
|
||||
if err := InitDBWithPath(dbPath); err != nil {
|
||||
t.Fatalf("second InitDB failed: %v", err)
|
||||
}
|
||||
defer CloseDB()
|
||||
|
|
@ -159,7 +159,7 @@ func TestInitDB_Idempotent(t *testing.T) {
|
|||
func TestValidateSQLiteDB_ValidDB(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "valid.db")
|
||||
if err := InitDB(dbPath); err != nil {
|
||||
if err := InitDBWithPath(dbPath); err != nil {
|
||||
t.Fatalf("InitDB failed: %v", err)
|
||||
}
|
||||
CloseDB()
|
||||
|
|
|
|||
Loading…
Reference in a new issue