From 3850e2f070bb3c1227b78dcd1e8a34a89a9e9906 Mon Sep 17 00:00:00 2001 From: Ali Golzar <57574919+aliglzr@users.noreply.github.com> Date: Wed, 21 May 2025 13:34:38 +0330 Subject: [PATCH] feat: Add MySQL database support (#3024) * feat: Add MySQL database support - Add MySQL database support with environment-based configuration - Fix MySQL compatibility issue with 'key' column name - Maintain SQLite as default database - Add proper validation for MySQL configuration - Test and verify compatibility with existing database - Replaced raw SQL queries using JSON_EACH functions with standard GORM queries - Modified functions to handle JSON parsing in Go code instead of database since JSON_EACH is not available on MySQL or MariaDB: - getAllEmails() - GetClientTrafficByID() - getFallbackMaster() - MigrationRemoveOrphanedTraffics() The system now supports both MySQL and SQLite databases, with SQLite remaining as the default option. MySQL connection is only used when explicitly configured through environment variables. * refactor: prefix env variables of database with XUI_ to support direct environment usage without .env file All database configuration environment variables now start with the XUI_ prefix to avoid conflicts and allow configuration via system-level environment variables, not just the .env file. --- .env.example | 8 +++ config/config.go | 49 ++++++++++++++++++ database/db.go | 38 ++++++++++---- database/model/model.go | 2 +- go.mod | 2 + go.sum | 5 ++ main.go | 1 + sub/subService.go | 74 +++++++++++++++++++++------ web/service/inbound.go | 108 +++++++++++++++++++++++++++++++--------- web/service/setting.go | 4 +- 10 files changed, 239 insertions(+), 52 deletions(-) create mode 100644 .env.example diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..228fde76 --- /dev/null +++ b/.env.example @@ -0,0 +1,8 @@ +XUI_DB_CONNECTION=sqlite + +# If DB connection is "mysql" +# XUI_DB_HOST=127.0.0.1 +# XUI_DB_PORT=3306 +# XUI_DB_DATABASE=xui +# XUI_DB_USERNAME=root +# XUI_DB_PASSWORD= \ No newline at end of file diff --git a/config/config.go b/config/config.go index 70be5ae6..942de7fb 100644 --- a/config/config.go +++ b/config/config.go @@ -3,6 +3,7 @@ package config import ( _ "embed" "fmt" + "log" "os" "strings" ) @@ -62,7 +63,55 @@ func GetDBFolderPath() string { return dbFolderPath } +// DatabaseConfig holds the database configuration +type DatabaseConfig struct { + Connection string + Host string + Port string + Database string + Username string + Password string +} + +// GetDatabaseConfig returns the database configuration from environment variables +func GetDatabaseConfig() (*DatabaseConfig, error) { + config := &DatabaseConfig{ + Connection: strings.ToLower(os.Getenv("XUI_DB_CONNECTION")), + Host: os.Getenv("XUI_DB_HOST"), + Port: os.Getenv("XUI_DB_PORT"), + Database: os.Getenv("XUI_DB_DATABASE"), + Username: os.Getenv("XUI_DB_USERNAME"), + Password: os.Getenv("XUI_DB_PASSWORD"), + } + + if config.Connection == "mysql" { + if config.Host == "" || config.Database == "" || config.Username == "" { + return nil, fmt.Errorf("missing required MySQL configuration: host, database, and username are required") + } + if config.Port == "" { + config.Port = "3306" + } + } + + return config, nil +} + func GetDBPath() string { + config, err := GetDatabaseConfig() + if err != nil { + log.Fatalf("Error getting database config: %v", err) + } + + if config.Connection == "mysql" { + return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8&parseTime=True&loc=Local", + config.Username, + config.Password, + config.Host, + config.Port, + config.Database) + } + + // Connection is sqlite return fmt.Sprintf("%s/%s.db", GetDBFolderPath(), GetName()) } diff --git a/database/db.go b/database/db.go index c72d28cf..05fc64cb 100644 --- a/database/db.go +++ b/database/db.go @@ -9,14 +9,15 @@ import ( "path" "slices" + "gorm.io/driver/mysql" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "x-ui/config" "x-ui/database/model" "x-ui/util/crypto" "x-ui/xray" - - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" ) var db *gorm.DB @@ -114,12 +115,22 @@ func isTableEmpty(tableName string) (bool, error) { } func InitDB(dbPath string) error { - dir := path.Dir(dbPath) - err := os.MkdirAll(dir, fs.ModePerm) + dbConfig, err := config.GetDatabaseConfig() if err != nil { return err } + if dbConfig.Connection != "mysql" { + // Connection is sqlite + // Need to create the directory if it doesn't exist + + dir := path.Dir(dbPath) + err = os.MkdirAll(dir, fs.ModePerm) + if err != nil { + return err + } + } + var gormLogger logger.Interface if config.IsDebug() { @@ -131,9 +142,18 @@ func InitDB(dbPath string) error { c := &gorm.Config{ Logger: gormLogger, } - db, err = gorm.Open(sqlite.Open(dbPath), c) - if err != nil { - return err + + if dbConfig.Connection == "mysql" { + db, err = gorm.Open(mysql.Open(dbPath), c) + if err != nil { + return err + } + } else { + // Connection is sqlite + db, err = gorm.Open(sqlite.Open(dbPath), c) + if err != nil { + return err + } } if err := initModels(); err != nil { diff --git a/database/model/model.go b/database/model/model.go index 2e7095d3..767e12a9 100644 --- a/database/model/model.go +++ b/database/model/model.go @@ -86,7 +86,7 @@ func (i *Inbound) GenXrayInboundConfig() *xray.InboundConfig { type Setting struct { Id int `json:"id" form:"id" gorm:"primaryKey;autoIncrement"` - Key string `json:"key" form:"key"` + Key string `json:"key" form:"key" gorm:"column:key"` Value string `json:"value" form:"value"` } diff --git a/go.mod b/go.mod index 1f545c64..d178a45c 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( golang.org/x/crypto v0.38.0 golang.org/x/text v0.25.0 google.golang.org/grpc v1.72.1 + gorm.io/driver/mysql v1.5.7 gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.25.12 ) @@ -41,6 +42,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.26.0 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/pprof v0.0.0-20250501235452-c0086092b71a // indirect diff --git a/go.sum b/go.sum index c6ae9c41..752c3807 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k= github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= @@ -261,8 +263,11 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I= gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk= diff --git a/main.go b/main.go index dcbd3c9e..a19e4760 100644 --- a/main.go +++ b/main.go @@ -105,6 +105,7 @@ func runWebServer() { default: server.Stop() subServer.Stop() + database.CloseDB() log.Println("Shutting down servers.") return } diff --git a/sub/subService.go b/sub/subService.go index 9f26c0e0..1c74d103 100644 --- a/sub/subService.go +++ b/sub/subService.go @@ -107,18 +107,30 @@ func (s *SubService) GetSubs(subId string, host string) ([]string, string, error func (s *SubService) getInboundsBySubId(subId string) ([]*model.Inbound, error) { db := database.GetDB() var inbounds []*model.Inbound - err := db.Model(model.Inbound{}).Preload("ClientStats").Where(`id in ( - SELECT DISTINCT inbounds.id - FROM inbounds, - JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client - WHERE - protocol in ('vmess','vless','trojan','shadowsocks') - AND JSON_EXTRACT(client.value, '$.subId') = ? AND enable = ? - )`, subId, true).Find(&inbounds).Error + err := db.Model(model.Inbound{}). + Preload("ClientStats"). + Where("protocol IN ? AND enable = ?", []string{"vmess", "vless", "trojan", "shadowsocks"}, true). + Find(&inbounds).Error if err != nil { return nil, err } - return inbounds, nil + + // Filter inbounds that have clients with matching subId + var filteredInbounds []*model.Inbound + for _, inbound := range inbounds { + clients, err := s.inboundService.GetClients(inbound) + if err != nil { + continue + } + for _, client := range clients { + if client.SubID == subId { + filteredInbounds = append(filteredInbounds, inbound) + break + } + } + } + + return filteredInbounds, nil } func (s *SubService) getClientTraffics(traffics []xray.ClientTraffic, email string) xray.ClientTraffic { @@ -132,25 +144,55 @@ func (s *SubService) getClientTraffics(traffics []xray.ClientTraffic, email stri func (s *SubService) getFallbackMaster(dest string, streamSettings string) (string, int, string, error) { db := database.GetDB() - var inbound *model.Inbound - err := db.Model(model.Inbound{}). - Where("JSON_TYPE(settings, '$.fallbacks') = 'array'"). - Where("EXISTS (SELECT * FROM json_each(settings, '$.fallbacks') WHERE json_extract(value, '$.dest') = ?)", dest). - Find(&inbound).Error + var inbounds []*model.Inbound + err := db.Model(model.Inbound{}).Find(&inbounds).Error if err != nil { return "", 0, "", err } + // Find inbound with matching fallback dest + var masterInbound *model.Inbound + for _, inbound := range inbounds { + var settings map[string]any + err := json.Unmarshal([]byte(inbound.Settings), &settings) + if err != nil { + continue + } + + fallbacks, ok := settings["fallbacks"].([]any) + if !ok { + continue + } + + for _, fallback := range fallbacks { + f, ok := fallback.(map[string]any) + if !ok { + continue + } + if fallbackDest, ok := f["dest"].(string); ok && fallbackDest == dest { + masterInbound = inbound + break + } + } + if masterInbound != nil { + break + } + } + + if masterInbound == nil { + return "", 0, "", fmt.Errorf("no inbound found with fallback dest: %s", dest) + } + var stream map[string]any json.Unmarshal([]byte(streamSettings), &stream) var masterStream map[string]any - json.Unmarshal([]byte(inbound.StreamSettings), &masterStream) + json.Unmarshal([]byte(masterInbound.StreamSettings), &masterStream) stream["security"] = masterStream["security"] stream["tlsSettings"] = masterStream["tlsSettings"] stream["externalProxy"] = masterStream["externalProxy"] modifiedStream, _ := json.MarshalIndent(stream, "", " ") - return inbound.Listen, inbound.Port, string(modifiedStream), nil + return masterInbound.Listen, masterInbound.Port, string(modifiedStream), nil } func (s *SubService) getLink(inbound *model.Inbound, email string) string { diff --git a/web/service/inbound.go b/web/service/inbound.go index f2646dbb..17a65b65 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -87,15 +87,24 @@ func (s *InboundService) GetClients(inbound *model.Inbound) ([]model.Client, err func (s *InboundService) getAllEmails() ([]string, error) { db := database.GetDB() - var emails []string - err := db.Raw(` - SELECT JSON_EXTRACT(client.value, '$.email') - FROM inbounds, - JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client - `).Scan(&emails).Error + var inbounds []*model.Inbound + err := db.Model(model.Inbound{}).Find(&inbounds).Error if err != nil { return nil, err } + + var emails []string + for _, inbound := range inbounds { + clients, err := s.GetClients(inbound) + if err != nil { + continue + } + for _, client := range clients { + if client.Email != "" { + emails = append(emails, client.Email) + } + } + } return emails, nil } @@ -1120,14 +1129,46 @@ func (s *InboundService) GetInboundTags() (string, error) { func (s *InboundService) MigrationRemoveOrphanedTraffics() { db := database.GetDB() - db.Exec(` - DELETE FROM client_traffics - WHERE email NOT IN ( - SELECT JSON_EXTRACT(client.value, '$.email') - FROM inbounds, - JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client - ) - `) + + // Get all inbounds + var inbounds []*model.Inbound + err := db.Model(model.Inbound{}).Find(&inbounds).Error + if err != nil { + logger.Error("Failed to get inbounds:", err) + return + } + + // Collect all valid emails from inbounds + validEmails := make(map[string]bool) + for _, inbound := range inbounds { + clients, err := s.GetClients(inbound) + if err != nil { + continue + } + for _, client := range clients { + if client.Email != "" { + validEmails[client.Email] = true + } + } + } + + // Get all client traffics + var traffics []xray.ClientTraffic + err = db.Model(xray.ClientTraffic{}).Find(&traffics).Error + if err != nil { + logger.Error("Failed to get client traffics:", err) + return + } + + // Delete traffics with emails not in validEmails + for _, traffic := range traffics { + if !validEmails[traffic.Email] { + err = db.Delete(&traffic).Error + if err != nil { + logger.Error("Failed to delete orphaned traffic:", err) + } + } + } } func (s *InboundService) AddClientStat(tx *gorm.DB, inboundId int, client *model.Client) error { @@ -1789,19 +1830,38 @@ func (s *InboundService) GetClientTrafficByID(id string) ([]xray.ClientTraffic, db := database.GetDB() var traffics []xray.ClientTraffic - err := db.Model(xray.ClientTraffic{}).Where(`email IN( - SELECT JSON_EXTRACT(client.value, '$.email') as email - FROM inbounds, - JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client - WHERE - JSON_EXTRACT(client.value, '$.id') in (?) - )`, id).Find(&traffics).Error - + // First get all inbounds + var inbounds []*model.Inbound + err := db.Model(model.Inbound{}).Find(&inbounds).Error if err != nil { - logger.Debug(err) return nil, err } - return traffics, err + + // Collect all emails that match the ID + var targetEmails []string + for _, inbound := range inbounds { + clients, err := s.GetClients(inbound) + if err != nil { + continue + } + for _, client := range clients { + if client.ID == id && client.Email != "" { + targetEmails = append(targetEmails, client.Email) + } + } + } + // Get traffics for the collected emails + if len(targetEmails) > 0 { + err = db.Model(xray.ClientTraffic{}). + Where("email IN ?", targetEmails). + Find(&traffics).Error + if err != nil { + logger.Debug(err) + return nil, err + } + } + + return traffics, nil } func (s *InboundService) SearchClientTraffic(query string) (traffic *xray.ClientTraffic, err error) { diff --git a/web/service/setting.go b/web/service/setting.go index 62d66c11..8fe3ec01 100644 --- a/web/service/setting.go +++ b/web/service/setting.go @@ -88,7 +88,7 @@ func (s *SettingService) GetDefaultJsonConfig() (any, error) { func (s *SettingService) GetAllSetting() (*entity.AllSetting, error) { db := database.GetDB() settings := make([]*model.Setting, 0) - err := db.Model(model.Setting{}).Not("key = ?", "xrayTemplateConfig").Find(&settings).Error + err := db.Model(model.Setting{}).Not("`key` = ?", "xrayTemplateConfig").Find(&settings).Error if err != nil { return nil, err } @@ -173,7 +173,7 @@ func (s *SettingService) ResetSettings() error { func (s *SettingService) getSetting(key string) (*model.Setting, error) { db := database.GetDB() setting := &model.Setting{} - err := db.Model(model.Setting{}).Where("key = ?", key).First(setting).Error + err := db.Model(model.Setting{}).Where("`key` = ?", key).First(setting).Error if err != nil { return nil, err }