mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-07 13:44:24 +00:00
feat: add node config and startup validation
This commit is contained in:
parent
735e1d8ef0
commit
36826706ec
3 changed files with 256 additions and 0 deletions
107
config/config.go
107
config/config.go
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -31,6 +32,20 @@ const (
|
||||||
Error LogLevel = "error"
|
Error LogLevel = "error"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type NodeRole string
|
||||||
|
|
||||||
|
const (
|
||||||
|
NodeRoleMaster NodeRole = "master"
|
||||||
|
NodeRoleWorker NodeRole = "worker"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NodeConfig struct {
|
||||||
|
Role NodeRole
|
||||||
|
NodeID string
|
||||||
|
SyncIntervalSeconds int
|
||||||
|
TrafficFlushSeconds int
|
||||||
|
}
|
||||||
|
|
||||||
// GetVersion returns the version string of the 3x-ui application.
|
// GetVersion returns the version string of the 3x-ui application.
|
||||||
func GetVersion() string {
|
func GetVersion() string {
|
||||||
return strings.TrimSpace(version)
|
return strings.TrimSpace(version)
|
||||||
|
|
@ -106,6 +121,14 @@ func GetSettingPath() string {
|
||||||
return fmt.Sprintf("%s/%s.json", GetDBFolderPath(), GetName())
|
return fmt.Sprintf("%s/%s.json", GetDBFolderPath(), GetName())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetSharedCachePath() string {
|
||||||
|
return filepath.Join(GetDBFolderPath(), "shared-cache.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetTrafficPendingPath() string {
|
||||||
|
return filepath.Join(GetDBFolderPath(), "traffic-pending.json")
|
||||||
|
}
|
||||||
|
|
||||||
// GetLogFolder returns the path to the log folder based on environment variables or platform defaults.
|
// GetLogFolder returns the path to the log folder based on environment variables or platform defaults.
|
||||||
func GetLogFolder() string {
|
func GetLogFolder() string {
|
||||||
logFolderPath := os.Getenv("XUI_LOG_FOLDER")
|
logFolderPath := os.Getenv("XUI_LOG_FOLDER")
|
||||||
|
|
@ -125,6 +148,14 @@ var settingGroupAliases = map[string][]string{
|
||||||
"dbUser": {"databaseConnection", "other"},
|
"dbUser": {"databaseConnection", "other"},
|
||||||
"dbPassword": {"databaseConnection", "other"},
|
"dbPassword": {"databaseConnection", "other"},
|
||||||
"dbName": {"databaseConnection", "other"},
|
"dbName": {"databaseConnection", "other"},
|
||||||
|
"nodeRole": {"other"},
|
||||||
|
"nodeId": {"other"},
|
||||||
|
"syncInterval": {
|
||||||
|
"other",
|
||||||
|
},
|
||||||
|
"trafficFlushInterval": {
|
||||||
|
"other",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func readGroupedString(settings map[string]any, key string) string {
|
func readGroupedString(settings map[string]any, key string) string {
|
||||||
|
|
@ -143,6 +174,37 @@ func readGroupedString(settings map[string]any, key string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readGroupedInt(settings map[string]any, key string, fallback int) int {
|
||||||
|
readInt := func(value any) (int, bool) {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case float64:
|
||||||
|
return int(v), true
|
||||||
|
case int:
|
||||||
|
return v, true
|
||||||
|
case string:
|
||||||
|
i, err := strconv.Atoi(v)
|
||||||
|
if err == nil {
|
||||||
|
return i, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if groups, ok := settingGroupAliases[key]; ok {
|
||||||
|
for _, groupName := range groups {
|
||||||
|
if group, ok := settings[groupName].(map[string]any); ok {
|
||||||
|
if value, ok := readInt(group[key]); ok {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value, ok := readInt(settings[key]); ok {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
func settingsLayoutMeta() map[string]any {
|
func settingsLayoutMeta() map[string]any {
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"layout": "按模块-用途来归类",
|
"layout": "按模块-用途来归类",
|
||||||
|
|
@ -250,6 +312,51 @@ func GetDBConfigFromJSON() DBConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetNodeConfigFromJSON() NodeConfig {
|
||||||
|
data, err := os.ReadFile(GetSettingPath())
|
||||||
|
if err != nil {
|
||||||
|
return NodeConfig{Role: NodeRoleMaster, SyncIntervalSeconds: 30, TrafficFlushSeconds: 10}
|
||||||
|
}
|
||||||
|
|
||||||
|
var settings map[string]any
|
||||||
|
if err := json.Unmarshal(data, &settings); err != nil {
|
||||||
|
return NodeConfig{Role: NodeRoleMaster, SyncIntervalSeconds: 30, TrafficFlushSeconds: 10}
|
||||||
|
}
|
||||||
|
|
||||||
|
role := readGroupedString(settings, "nodeRole")
|
||||||
|
if role == "" {
|
||||||
|
role = string(NodeRoleMaster)
|
||||||
|
}
|
||||||
|
|
||||||
|
return NodeConfig{
|
||||||
|
Role: NodeRole(role),
|
||||||
|
NodeID: readGroupedString(settings, "nodeId"),
|
||||||
|
SyncIntervalSeconds: readGroupedInt(settings, "syncInterval", 30),
|
||||||
|
TrafficFlushSeconds: readGroupedInt(settings, "trafficFlushInterval", 10),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateNodeConfig(nodeCfg NodeConfig, dbCfg DBConfig) error {
|
||||||
|
switch nodeCfg.Role {
|
||||||
|
case NodeRoleMaster, NodeRoleWorker:
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid nodeRole %q", nodeCfg.Role)
|
||||||
|
}
|
||||||
|
if nodeCfg.Role == NodeRoleWorker && nodeCfg.NodeID == "" {
|
||||||
|
return fmt.Errorf("worker mode requires nodeId")
|
||||||
|
}
|
||||||
|
if nodeCfg.Role == NodeRoleWorker && dbCfg.Type != "mariadb" {
|
||||||
|
return fmt.Errorf("worker mode requires mariadb")
|
||||||
|
}
|
||||||
|
if nodeCfg.SyncIntervalSeconds <= 0 {
|
||||||
|
return fmt.Errorf("syncInterval must be positive")
|
||||||
|
}
|
||||||
|
if nodeCfg.TrafficFlushSeconds <= 0 {
|
||||||
|
return fmt.Errorf("trafficFlushInterval must be positive")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// WriteSettingToJSON writes a single setting key to the JSON config file.
|
// WriteSettingToJSON writes a single setting key to the JSON config file.
|
||||||
// It reads the existing file, updates the value, and writes back.
|
// It reads the existing file, updates the value, and writes back.
|
||||||
func WriteSettingToJSON(key, value string) error {
|
func WriteSettingToJSON(key, value string) error {
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package config
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
@ -132,6 +133,72 @@ func TestGetLogFolderCustom(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeTestSettingsFile(t *testing.T, settings map[string]any) {
|
||||||
|
t.Helper()
|
||||||
|
data, err := json.MarshalIndent(settings, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarshalIndent error: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(GetSettingPath(), data, 0644); err != nil {
|
||||||
|
t.Fatalf("WriteFile error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNodeConfigFromJSONDefaults(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
t.Setenv("XUI_DB_FOLDER", tmpDir)
|
||||||
|
writeTestSettingsFile(t, map[string]any{})
|
||||||
|
|
||||||
|
cfg := GetNodeConfigFromJSON()
|
||||||
|
if cfg.Role != NodeRoleMaster {
|
||||||
|
t.Fatalf("expected default role %q, got %q", NodeRoleMaster, cfg.Role)
|
||||||
|
}
|
||||||
|
if cfg.NodeID != "" {
|
||||||
|
t.Fatalf("expected empty default node id, got %q", cfg.NodeID)
|
||||||
|
}
|
||||||
|
if cfg.SyncIntervalSeconds != 30 {
|
||||||
|
t.Fatalf("expected default sync interval 30, got %d", cfg.SyncIntervalSeconds)
|
||||||
|
}
|
||||||
|
if cfg.TrafficFlushSeconds != 10 {
|
||||||
|
t.Fatalf("expected default traffic flush interval 10, got %d", cfg.TrafficFlushSeconds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateNodeConfigWorkerRequiresNodeID(t *testing.T) {
|
||||||
|
err := ValidateNodeConfig(NodeConfig{
|
||||||
|
Role: NodeRoleWorker,
|
||||||
|
SyncIntervalSeconds: 30,
|
||||||
|
TrafficFlushSeconds: 10,
|
||||||
|
}, DBConfig{Type: "mariadb"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected worker without node id to fail validation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateNodeConfigWorkerRequiresMariaDB(t *testing.T) {
|
||||||
|
err := ValidateNodeConfig(NodeConfig{
|
||||||
|
Role: NodeRoleWorker,
|
||||||
|
NodeID: "worker-1",
|
||||||
|
SyncIntervalSeconds: 30,
|
||||||
|
TrafficFlushSeconds: 10,
|
||||||
|
}, DBConfig{Type: "sqlite"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected worker on sqlite to fail validation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSharedRuntimeFilePaths(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
t.Setenv("XUI_DB_FOLDER", tmpDir)
|
||||||
|
|
||||||
|
if got := GetSharedCachePath(); got != filepath.Join(tmpDir, "shared-cache.json") {
|
||||||
|
t.Fatalf("unexpected shared cache path: %s", got)
|
||||||
|
}
|
||||||
|
if got := GetTrafficPendingPath(); got != filepath.Join(tmpDir, "traffic-pending.json") {
|
||||||
|
t.Fatalf("unexpected traffic pending path: %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetDBConfigFromJSONSupportsModulePurposeLayout(t *testing.T) {
|
func TestGetDBConfigFromJSONSupportsModulePurposeLayout(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
t.Setenv("XUI_DB_FOLDER", tmpDir)
|
t.Setenv("XUI_DB_FOLDER", tmpDir)
|
||||||
|
|
|
||||||
82
main.go
82
main.go
|
|
@ -29,6 +29,12 @@ import (
|
||||||
func runWebServer() {
|
func runWebServer() {
|
||||||
log.Printf("Starting %v %v", config.GetName(), config.GetVersion())
|
log.Printf("Starting %v %v", config.GetName(), config.GetVersion())
|
||||||
|
|
||||||
|
dbCfg := config.GetDBConfigFromJSON()
|
||||||
|
nodeCfg := config.GetNodeConfigFromJSON()
|
||||||
|
if err := config.ValidateNodeConfig(nodeCfg, dbCfg); err != nil {
|
||||||
|
log.Fatalf("invalid node configuration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
switch config.GetLogLevel() {
|
switch config.GetLogLevel() {
|
||||||
case config.Debug:
|
case config.Debug:
|
||||||
logger.InitLogger(logging.DEBUG)
|
logger.InitLogger(logging.DEBUG)
|
||||||
|
|
@ -199,6 +205,11 @@ func showSetting(show bool) {
|
||||||
fmt.Println("port:", port)
|
fmt.Println("port:", port)
|
||||||
fmt.Println("webDomain:", webDomain)
|
fmt.Println("webDomain:", webDomain)
|
||||||
fmt.Println("webBasePath:", webBasePath)
|
fmt.Println("webBasePath:", webBasePath)
|
||||||
|
nodeCfg := config.GetNodeConfigFromJSON()
|
||||||
|
fmt.Println("nodeRole:", nodeCfg.Role)
|
||||||
|
fmt.Println("nodeId:", nodeCfg.NodeID)
|
||||||
|
fmt.Println("syncInterval:", nodeCfg.SyncIntervalSeconds)
|
||||||
|
fmt.Println("trafficFlushInterval:", nodeCfg.TrafficFlushSeconds)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -525,6 +536,10 @@ func main() {
|
||||||
var dbPassword string
|
var dbPassword string
|
||||||
var dbName string
|
var dbName string
|
||||||
var showDbType bool
|
var showDbType bool
|
||||||
|
var nodeRoleFlag string
|
||||||
|
var nodeIDFlag string
|
||||||
|
var syncIntervalFlag int
|
||||||
|
var trafficFlushIntervalFlag int
|
||||||
settingCmd.StringVar(&dbTypeFlag, "dbType", "", "Set database type (sqlite or mariadb)")
|
settingCmd.StringVar(&dbTypeFlag, "dbType", "", "Set database type (sqlite or mariadb)")
|
||||||
settingCmd.StringVar(&dbHost, "dbHost", "", "Set MariaDB host")
|
settingCmd.StringVar(&dbHost, "dbHost", "", "Set MariaDB host")
|
||||||
settingCmd.StringVar(&dbPort, "dbPort", "", "Set MariaDB port")
|
settingCmd.StringVar(&dbPort, "dbPort", "", "Set MariaDB port")
|
||||||
|
|
@ -532,6 +547,10 @@ func main() {
|
||||||
settingCmd.StringVar(&dbPassword, "dbPassword", "", "Set MariaDB password")
|
settingCmd.StringVar(&dbPassword, "dbPassword", "", "Set MariaDB password")
|
||||||
settingCmd.StringVar(&dbName, "dbName", "", "Set MariaDB database name")
|
settingCmd.StringVar(&dbName, "dbName", "", "Set MariaDB database name")
|
||||||
settingCmd.BoolVar(&showDbType, "showDbType", false, "Print current database type and exit")
|
settingCmd.BoolVar(&showDbType, "showDbType", false, "Print current database type and exit")
|
||||||
|
settingCmd.StringVar(&nodeRoleFlag, "nodeRole", "", "Set node role (master or worker)")
|
||||||
|
settingCmd.StringVar(&nodeIDFlag, "nodeId", "", "Set node identifier")
|
||||||
|
settingCmd.IntVar(&syncIntervalFlag, "syncInterval", 0, "Set shared sync interval in seconds")
|
||||||
|
settingCmd.IntVar(&trafficFlushIntervalFlag, "trafficFlushInterval", 0, "Set traffic flush interval in seconds")
|
||||||
|
|
||||||
migrateDbCmd := flag.NewFlagSet("migrate-db", flag.ExitOnError)
|
migrateDbCmd := flag.NewFlagSet("migrate-db", flag.ExitOnError)
|
||||||
var migrateDirection string
|
var migrateDirection string
|
||||||
|
|
@ -582,6 +601,22 @@ func main() {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
nodeRoleSet := false
|
||||||
|
nodeIDSet := false
|
||||||
|
syncIntervalSet := false
|
||||||
|
trafficFlushIntervalSet := false
|
||||||
|
settingCmd.Visit(func(f *flag.Flag) {
|
||||||
|
switch f.Name {
|
||||||
|
case "nodeRole":
|
||||||
|
nodeRoleSet = true
|
||||||
|
case "nodeId":
|
||||||
|
nodeIDSet = true
|
||||||
|
case "syncInterval":
|
||||||
|
syncIntervalSet = true
|
||||||
|
case "trafficFlushInterval":
|
||||||
|
trafficFlushIntervalSet = true
|
||||||
|
}
|
||||||
|
})
|
||||||
if showDbType {
|
if showDbType {
|
||||||
fmt.Println(config.GetDBTypeFromJSON())
|
fmt.Println(config.GetDBTypeFromJSON())
|
||||||
return
|
return
|
||||||
|
|
@ -648,6 +683,53 @@ func main() {
|
||||||
fmt.Println("dbName set to:", dbName)
|
fmt.Println("dbName set to:", dbName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if nodeRoleSet || nodeIDSet || syncIntervalSet || trafficFlushIntervalSet {
|
||||||
|
candidate := config.GetNodeConfigFromJSON()
|
||||||
|
if nodeRoleSet {
|
||||||
|
candidate.Role = config.NodeRole(nodeRoleFlag)
|
||||||
|
}
|
||||||
|
if nodeIDSet {
|
||||||
|
candidate.NodeID = nodeIDFlag
|
||||||
|
}
|
||||||
|
if syncIntervalSet {
|
||||||
|
candidate.SyncIntervalSeconds = syncIntervalFlag
|
||||||
|
}
|
||||||
|
if trafficFlushIntervalSet {
|
||||||
|
candidate.TrafficFlushSeconds = trafficFlushIntervalFlag
|
||||||
|
}
|
||||||
|
if err := config.ValidateNodeConfig(candidate, config.GetDBConfigFromJSON()); err != nil {
|
||||||
|
fmt.Println("Invalid node settings:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if nodeRoleSet {
|
||||||
|
if err := config.WriteSettingToJSON("nodeRole", nodeRoleFlag); err != nil {
|
||||||
|
fmt.Println("Failed to set nodeRole:", err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("nodeRole set to:", nodeRoleFlag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nodeIDSet {
|
||||||
|
if err := config.WriteSettingToJSON("nodeId", nodeIDFlag); err != nil {
|
||||||
|
fmt.Println("Failed to set nodeId:", err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("nodeId set to:", nodeIDFlag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if syncIntervalSet {
|
||||||
|
if err := config.WriteSettingToJSON("syncInterval", fmt.Sprintf("%d", syncIntervalFlag)); err != nil {
|
||||||
|
fmt.Println("Failed to set syncInterval:", err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("syncInterval set to:", syncIntervalFlag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if trafficFlushIntervalSet {
|
||||||
|
if err := config.WriteSettingToJSON("trafficFlushInterval", fmt.Sprintf("%d", trafficFlushIntervalFlag)); err != nil {
|
||||||
|
fmt.Println("Failed to set trafficFlushInterval:", err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("trafficFlushInterval set to:", trafficFlushIntervalFlag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
case "cert":
|
case "cert":
|
||||||
err := settingCmd.Parse(os.Args[2:])
|
err := settingCmd.Parse(os.Args[2:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue