Refactor database initialization

This commit is contained in:
mhsanaei 2024-07-14 01:22:02 +02:00
parent 60cb328698
commit dfe0bbd371

View file

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"io" "io"
"io/fs" "io/fs"
"log"
"os" "os"
"path" "path"
@ -18,54 +19,51 @@ import (
var db *gorm.DB var db *gorm.DB
var initializers = []func() error{ const (
initUser, defaultUsername = "admin"
initInbound, defaultPassword = "admin"
initOutbound, defaultSecret = ""
initSetting, )
initInboundClientIps,
initClientTraffic, func initModels() error {
models := []interface{}{
&model.User{},
&model.Inbound{},
&model.OutboundTraffics{},
&model.Setting{},
&model.InboundClientIps{},
&xray.ClientTraffic{},
}
for _, model := range models {
if err := db.AutoMigrate(model); err != nil {
log.Printf("Error auto migrating model: %v", err)
return err
}
}
return nil
} }
func initUser() error { func initUser() error {
err := db.AutoMigrate(&model.User{}) empty, err := isTableEmpty("users")
if err != nil { if err != nil {
log.Printf("Error checking if users table is empty: %v", err)
return err return err
} }
var count int64 if empty {
err = db.Model(&model.User{}).Count(&count).Error
if err != nil {
return err
}
if count == 0 {
user := &model.User{ user := &model.User{
Username: "admin", Username: defaultUsername,
Password: "admin", Password: defaultPassword,
LoginSecret: "", LoginSecret: defaultSecret,
} }
return db.Create(user).Error return db.Create(user).Error
} }
return nil return nil
} }
func initInbound() error { func isTableEmpty(tableName string) (bool, error) {
return db.AutoMigrate(&model.Inbound{}) var count int64
} err := db.Table(tableName).Count(&count).Error
return count == 0, err
func initOutbound() error {
return db.AutoMigrate(&model.OutboundTraffics{})
}
func initSetting() error {
return db.AutoMigrate(&model.Setting{})
}
func initInboundClientIps() error {
return db.AutoMigrate(&model.InboundClientIps{})
}
func initClientTraffic() error {
return db.AutoMigrate(&xray.ClientTraffic{})
} }
func InitDB(dbPath string) error { func InitDB(dbPath string) error {
@ -91,15 +89,27 @@ func InitDB(dbPath string) error {
return err return err
} }
for _, initialize := range initializers { if err := initModels(); err != nil {
if err := initialize(); err != nil {
return err return err
} }
if err := initUser(); err != nil {
return err
} }
return nil return nil
} }
func CloseDB() error {
if db != nil {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
return nil
}
func GetDB() *gorm.DB { func GetDB() *gorm.DB {
return db return db
} }