diff --git a/database/db.go b/database/db.go index c75953f0..300a73c0 100644 --- a/database/db.go +++ b/database/db.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "io/fs" + "log" "os" "path" @@ -18,54 +19,51 @@ import ( var db *gorm.DB -var initializers = []func() error{ - initUser, - initInbound, - initOutbound, - initSetting, - initInboundClientIps, - initClientTraffic, +const ( + defaultUsername = "admin" + defaultPassword = "admin" + defaultSecret = "" +) + +func initModels() error { + models := []interface{}{ + &model.User{}, + &model.Inbound{}, + &model.OutboundTraffics{}, + &model.Setting{}, + &model.InboundClientIps{}, + &xray.ClientTraffic{}, + } + for _, model := range models { + if err := db.AutoMigrate(model); err != nil { + log.Printf("Error auto migrating model: %v", err) + return err + } + } + return nil } func initUser() error { - err := db.AutoMigrate(&model.User{}) + empty, err := isTableEmpty("users") if err != nil { + log.Printf("Error checking if users table is empty: %v", err) return err } - var count int64 - err = db.Model(&model.User{}).Count(&count).Error - if err != nil { - return err - } - if count == 0 { + if empty { user := &model.User{ - Username: "admin", - Password: "admin", - LoginSecret: "", + Username: defaultUsername, + Password: defaultPassword, + LoginSecret: defaultSecret, } return db.Create(user).Error } return nil } -func initInbound() error { - return db.AutoMigrate(&model.Inbound{}) -} - -func initOutbound() error { - return db.AutoMigrate(&model.OutboundTraffics{}) -} - -func initSetting() error { - return db.AutoMigrate(&model.Setting{}) -} - -func initInboundClientIps() error { - return db.AutoMigrate(&model.InboundClientIps{}) -} - -func initClientTraffic() error { - return db.AutoMigrate(&xray.ClientTraffic{}) +func isTableEmpty(tableName string) (bool, error) { + var count int64 + err := db.Table(tableName).Count(&count).Error + return count == 0, err } func InitDB(dbPath string) error { @@ -91,15 +89,27 @@ func InitDB(dbPath string) error { return err } - for _, initialize := range initializers { - if err := initialize(); err != nil { - return err - } + if err := initModels(); err != nil { + return err + } + if err := initUser(); err != nil { + return err } return nil } +func CloseDB() error { + if db != nil { + sqlDB, err := db.DB() + if err != nil { + return err + } + return sqlDB.Close() + } + return nil +} + func GetDB() *gorm.DB { return db }