diff --git a/database/db.go b/database/db.go index 48b5538c..e6bfb519 100644 --- a/database/db.go +++ b/database/db.go @@ -5,6 +5,7 @@ package database import ( "bytes" "errors" + "fmt" "io" "io/fs" "log" @@ -17,6 +18,7 @@ import ( "github.com/mhsanaei/3x-ui/v2/util/crypto" "github.com/mhsanaei/3x-ui/v2/xray" + "gorm.io/driver/mysql" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -135,7 +137,57 @@ func isTableEmpty(tableName string) (bool, error) { } // InitDB sets up the database connection, migrates models, and runs seeders. -func InitDB(dbPath string) error { +// It reads the dbType from the JSON config to determine whether to use SQLite or MariaDB. +func InitDB() error { + dbType := config.GetDBTypeFromJSON() + + var err error + switch dbType { + case "mariadb": + err = initMariaDB() + default: + err = initSQLite(config.GetDBPath()) + } + if err != nil { + return err + } + + if err := initModels(); err != nil { + return err + } + + if err := initUser(); err != nil { + return err + } + + isUsersEmpty, err := isTableEmpty("users") + if err != nil { + return err + } + return runSeeders(isUsersEmpty) +} + +// InitDBWithPath is a convenience function for tests and migrations that need +// to open a specific SQLite file. +func InitDBWithPath(dbPath string) error { + if err := initSQLite(dbPath); err != nil { + return err + } + if err := initModels(); err != nil { + return err + } + if err := initUser(); err != nil { + return err + } + isUsersEmpty, err := isTableEmpty("users") + if err != nil { + return err + } + return runSeeders(isUsersEmpty) +} + +// initSQLite opens a SQLite database connection and runs model migrations. +func initSQLite(dbPath string) error { dir := path.Dir(dbPath) err := os.MkdirAll(dir, fs.ModePerm) if err != nil { @@ -158,19 +210,33 @@ func InitDB(dbPath string) error { return err } - if err := initModels(); err != nil { - return err + return nil +} + +// initMariaDB opens a MariaDB connection and runs model migrations. +func initMariaDB() error { + dbConfig := config.GetDBConfigFromJSON() + + dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", + dbConfig.User, dbConfig.Password, dbConfig.Host, dbConfig.Port, dbConfig.Name) + + var gormLogger logger.Interface + if config.IsDebug() { + gormLogger = logger.Default + } else { + gormLogger = logger.Discard } - isUsersEmpty, err := isTableEmpty("users") + var err error + c := &gorm.Config{ + Logger: gormLogger, + } + db, err = gorm.Open(mysql.Open(dsn), c) if err != nil { return err } - if err := initUser(); err != nil { - return err - } - return runSeeders(isUsersEmpty) + return nil } // CloseDB closes the database connection if it exists. @@ -207,13 +273,12 @@ func IsSQLiteDB(file io.ReaderAt) (bool, error) { } // Checkpoint performs a WAL checkpoint on the SQLite database to ensure data consistency. +// For MariaDB, this is a no-op. func Checkpoint() error { - // Update WAL - err := db.Exec("PRAGMA wal_checkpoint;").Error - if err != nil { - return err + if config.GetDBTypeFromJSON() != "sqlite" { + return nil } - return nil + return db.Exec("PRAGMA wal_checkpoint;").Error } // ValidateSQLiteDB opens the provided sqlite DB path with a throw-away connection diff --git a/database/db_test.go b/database/db_test.go index 5bd42d7e..9cebdd20 100644 --- a/database/db_test.go +++ b/database/db_test.go @@ -14,7 +14,7 @@ func setupTestDB(t *testing.T) { tmpDir := t.TempDir() t.Setenv("XUI_DEBUG", "") dbPath := filepath.Join(tmpDir, "test.db") - if err := InitDB(dbPath); err != nil { + if err := InitDBWithPath(dbPath); err != nil { t.Fatalf("InitDB failed: %v", err) } t.Cleanup(func() { @@ -25,7 +25,7 @@ func setupTestDB(t *testing.T) { func TestIsSQLiteDB_ValidFile(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "valid.db") - if err := InitDB(dbPath); err != nil { + if err := InitDBWithPath(dbPath); err != nil { t.Fatalf("InitDB failed: %v", err) } defer CloseDB() @@ -137,13 +137,13 @@ func TestInitDB_Idempotent(t *testing.T) { dbPath := filepath.Join(tmpDir, "idempotent.db") // First init - if err := InitDB(dbPath); err != nil { + if err := InitDBWithPath(dbPath); err != nil { t.Fatalf("first InitDB failed: %v", err) } CloseDB() // Second init on the same file should not fail - if err := InitDB(dbPath); err != nil { + if err := InitDBWithPath(dbPath); err != nil { t.Fatalf("second InitDB failed: %v", err) } defer CloseDB() @@ -159,7 +159,7 @@ func TestInitDB_Idempotent(t *testing.T) { func TestValidateSQLiteDB_ValidDB(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "valid.db") - if err := InitDB(dbPath); err != nil { + if err := InitDBWithPath(dbPath); err != nil { t.Fatalf("InitDB failed: %v", err) } CloseDB()