From 36826706ec154df22e650df5169414127959a310 Mon Sep 17 00:00:00 2001 From: Sora39831 <540587985@qq.com> Date: Fri, 10 Apr 2026 11:13:52 +0800 Subject: [PATCH] feat: add node config and startup validation --- config/config.go | 107 ++++++++++++++++++++++++++++++++++++++++++ config/config_test.go | 67 ++++++++++++++++++++++++++ main.go | 82 ++++++++++++++++++++++++++++++++ 3 files changed, 256 insertions(+) diff --git a/config/config.go b/config/config.go index f17f81b5..2c8a8b85 100644 --- a/config/config.go +++ b/config/config.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "runtime" + "strconv" "strings" ) @@ -31,6 +32,20 @@ const ( Error LogLevel = "error" ) +type NodeRole string + +const ( + NodeRoleMaster NodeRole = "master" + NodeRoleWorker NodeRole = "worker" +) + +type NodeConfig struct { + Role NodeRole + NodeID string + SyncIntervalSeconds int + TrafficFlushSeconds int +} + // GetVersion returns the version string of the 3x-ui application. func GetVersion() string { return strings.TrimSpace(version) @@ -106,6 +121,14 @@ func GetSettingPath() string { return fmt.Sprintf("%s/%s.json", GetDBFolderPath(), GetName()) } +func GetSharedCachePath() string { + return filepath.Join(GetDBFolderPath(), "shared-cache.json") +} + +func GetTrafficPendingPath() string { + return filepath.Join(GetDBFolderPath(), "traffic-pending.json") +} + // GetLogFolder returns the path to the log folder based on environment variables or platform defaults. func GetLogFolder() string { logFolderPath := os.Getenv("XUI_LOG_FOLDER") @@ -125,6 +148,14 @@ var settingGroupAliases = map[string][]string{ "dbUser": {"databaseConnection", "other"}, "dbPassword": {"databaseConnection", "other"}, "dbName": {"databaseConnection", "other"}, + "nodeRole": {"other"}, + "nodeId": {"other"}, + "syncInterval": { + "other", + }, + "trafficFlushInterval": { + "other", + }, } func readGroupedString(settings map[string]any, key string) string { @@ -143,6 +174,37 @@ func readGroupedString(settings map[string]any, key string) string { return "" } +func readGroupedInt(settings map[string]any, key string, fallback int) int { + readInt := func(value any) (int, bool) { + switch v := value.(type) { + case float64: + return int(v), true + case int: + return v, true + case string: + i, err := strconv.Atoi(v) + if err == nil { + return i, true + } + } + return 0, false + } + + if groups, ok := settingGroupAliases[key]; ok { + for _, groupName := range groups { + if group, ok := settings[groupName].(map[string]any); ok { + if value, ok := readInt(group[key]); ok { + return value + } + } + } + } + if value, ok := readInt(settings[key]); ok { + return value + } + return fallback +} + func settingsLayoutMeta() map[string]any { return map[string]any{ "layout": "按模块-用途来归类", @@ -250,6 +312,51 @@ func GetDBConfigFromJSON() DBConfig { } } +func GetNodeConfigFromJSON() NodeConfig { + data, err := os.ReadFile(GetSettingPath()) + if err != nil { + return NodeConfig{Role: NodeRoleMaster, SyncIntervalSeconds: 30, TrafficFlushSeconds: 10} + } + + var settings map[string]any + if err := json.Unmarshal(data, &settings); err != nil { + return NodeConfig{Role: NodeRoleMaster, SyncIntervalSeconds: 30, TrafficFlushSeconds: 10} + } + + role := readGroupedString(settings, "nodeRole") + if role == "" { + role = string(NodeRoleMaster) + } + + return NodeConfig{ + Role: NodeRole(role), + NodeID: readGroupedString(settings, "nodeId"), + SyncIntervalSeconds: readGroupedInt(settings, "syncInterval", 30), + TrafficFlushSeconds: readGroupedInt(settings, "trafficFlushInterval", 10), + } +} + +func ValidateNodeConfig(nodeCfg NodeConfig, dbCfg DBConfig) error { + switch nodeCfg.Role { + case NodeRoleMaster, NodeRoleWorker: + default: + return fmt.Errorf("invalid nodeRole %q", nodeCfg.Role) + } + if nodeCfg.Role == NodeRoleWorker && nodeCfg.NodeID == "" { + return fmt.Errorf("worker mode requires nodeId") + } + if nodeCfg.Role == NodeRoleWorker && dbCfg.Type != "mariadb" { + return fmt.Errorf("worker mode requires mariadb") + } + if nodeCfg.SyncIntervalSeconds <= 0 { + return fmt.Errorf("syncInterval must be positive") + } + if nodeCfg.TrafficFlushSeconds <= 0 { + return fmt.Errorf("trafficFlushInterval must be positive") + } + return nil +} + // WriteSettingToJSON writes a single setting key to the JSON config file. // It reads the existing file, updates the value, and writes back. func WriteSettingToJSON(key, value string) error { diff --git a/config/config_test.go b/config/config_test.go index 2cc8ce8e..3f0cb147 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "encoding/json" "os" + "path/filepath" "strings" "testing" ) @@ -132,6 +133,72 @@ func TestGetLogFolderCustom(t *testing.T) { } } +func writeTestSettingsFile(t *testing.T, settings map[string]any) { + t.Helper() + data, err := json.MarshalIndent(settings, "", " ") + if err != nil { + t.Fatalf("MarshalIndent error: %v", err) + } + if err := os.WriteFile(GetSettingPath(), data, 0644); err != nil { + t.Fatalf("WriteFile error: %v", err) + } +} + +func TestGetNodeConfigFromJSONDefaults(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XUI_DB_FOLDER", tmpDir) + writeTestSettingsFile(t, map[string]any{}) + + cfg := GetNodeConfigFromJSON() + if cfg.Role != NodeRoleMaster { + t.Fatalf("expected default role %q, got %q", NodeRoleMaster, cfg.Role) + } + if cfg.NodeID != "" { + t.Fatalf("expected empty default node id, got %q", cfg.NodeID) + } + if cfg.SyncIntervalSeconds != 30 { + t.Fatalf("expected default sync interval 30, got %d", cfg.SyncIntervalSeconds) + } + if cfg.TrafficFlushSeconds != 10 { + t.Fatalf("expected default traffic flush interval 10, got %d", cfg.TrafficFlushSeconds) + } +} + +func TestValidateNodeConfigWorkerRequiresNodeID(t *testing.T) { + err := ValidateNodeConfig(NodeConfig{ + Role: NodeRoleWorker, + SyncIntervalSeconds: 30, + TrafficFlushSeconds: 10, + }, DBConfig{Type: "mariadb"}) + if err == nil { + t.Fatal("expected worker without node id to fail validation") + } +} + +func TestValidateNodeConfigWorkerRequiresMariaDB(t *testing.T) { + err := ValidateNodeConfig(NodeConfig{ + Role: NodeRoleWorker, + NodeID: "worker-1", + SyncIntervalSeconds: 30, + TrafficFlushSeconds: 10, + }, DBConfig{Type: "sqlite"}) + if err == nil { + t.Fatal("expected worker on sqlite to fail validation") + } +} + +func TestSharedRuntimeFilePaths(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XUI_DB_FOLDER", tmpDir) + + if got := GetSharedCachePath(); got != filepath.Join(tmpDir, "shared-cache.json") { + t.Fatalf("unexpected shared cache path: %s", got) + } + if got := GetTrafficPendingPath(); got != filepath.Join(tmpDir, "traffic-pending.json") { + t.Fatalf("unexpected traffic pending path: %s", got) + } +} + func TestGetDBConfigFromJSONSupportsModulePurposeLayout(t *testing.T) { tmpDir := t.TempDir() t.Setenv("XUI_DB_FOLDER", tmpDir) diff --git a/main.go b/main.go index 8c9f02bc..4e030470 100644 --- a/main.go +++ b/main.go @@ -29,6 +29,12 @@ import ( func runWebServer() { log.Printf("Starting %v %v", config.GetName(), config.GetVersion()) + dbCfg := config.GetDBConfigFromJSON() + nodeCfg := config.GetNodeConfigFromJSON() + if err := config.ValidateNodeConfig(nodeCfg, dbCfg); err != nil { + log.Fatalf("invalid node configuration: %v", err) + } + switch config.GetLogLevel() { case config.Debug: logger.InitLogger(logging.DEBUG) @@ -199,6 +205,11 @@ func showSetting(show bool) { fmt.Println("port:", port) fmt.Println("webDomain:", webDomain) fmt.Println("webBasePath:", webBasePath) + nodeCfg := config.GetNodeConfigFromJSON() + fmt.Println("nodeRole:", nodeCfg.Role) + fmt.Println("nodeId:", nodeCfg.NodeID) + fmt.Println("syncInterval:", nodeCfg.SyncIntervalSeconds) + fmt.Println("trafficFlushInterval:", nodeCfg.TrafficFlushSeconds) } } @@ -525,6 +536,10 @@ func main() { var dbPassword string var dbName string var showDbType bool + var nodeRoleFlag string + var nodeIDFlag string + var syncIntervalFlag int + var trafficFlushIntervalFlag int settingCmd.StringVar(&dbTypeFlag, "dbType", "", "Set database type (sqlite or mariadb)") settingCmd.StringVar(&dbHost, "dbHost", "", "Set MariaDB host") settingCmd.StringVar(&dbPort, "dbPort", "", "Set MariaDB port") @@ -532,6 +547,10 @@ func main() { settingCmd.StringVar(&dbPassword, "dbPassword", "", "Set MariaDB password") settingCmd.StringVar(&dbName, "dbName", "", "Set MariaDB database name") settingCmd.BoolVar(&showDbType, "showDbType", false, "Print current database type and exit") + settingCmd.StringVar(&nodeRoleFlag, "nodeRole", "", "Set node role (master or worker)") + settingCmd.StringVar(&nodeIDFlag, "nodeId", "", "Set node identifier") + settingCmd.IntVar(&syncIntervalFlag, "syncInterval", 0, "Set shared sync interval in seconds") + settingCmd.IntVar(&trafficFlushIntervalFlag, "trafficFlushInterval", 0, "Set traffic flush interval in seconds") migrateDbCmd := flag.NewFlagSet("migrate-db", flag.ExitOnError) var migrateDirection string @@ -582,6 +601,22 @@ func main() { fmt.Println(err) return } + nodeRoleSet := false + nodeIDSet := false + syncIntervalSet := false + trafficFlushIntervalSet := false + settingCmd.Visit(func(f *flag.Flag) { + switch f.Name { + case "nodeRole": + nodeRoleSet = true + case "nodeId": + nodeIDSet = true + case "syncInterval": + syncIntervalSet = true + case "trafficFlushInterval": + trafficFlushIntervalSet = true + } + }) if showDbType { fmt.Println(config.GetDBTypeFromJSON()) return @@ -648,6 +683,53 @@ func main() { fmt.Println("dbName set to:", dbName) } } + if nodeRoleSet || nodeIDSet || syncIntervalSet || trafficFlushIntervalSet { + candidate := config.GetNodeConfigFromJSON() + if nodeRoleSet { + candidate.Role = config.NodeRole(nodeRoleFlag) + } + if nodeIDSet { + candidate.NodeID = nodeIDFlag + } + if syncIntervalSet { + candidate.SyncIntervalSeconds = syncIntervalFlag + } + if trafficFlushIntervalSet { + candidate.TrafficFlushSeconds = trafficFlushIntervalFlag + } + if err := config.ValidateNodeConfig(candidate, config.GetDBConfigFromJSON()); err != nil { + fmt.Println("Invalid node settings:", err) + return + } + if nodeRoleSet { + if err := config.WriteSettingToJSON("nodeRole", nodeRoleFlag); err != nil { + fmt.Println("Failed to set nodeRole:", err) + } else { + fmt.Println("nodeRole set to:", nodeRoleFlag) + } + } + if nodeIDSet { + if err := config.WriteSettingToJSON("nodeId", nodeIDFlag); err != nil { + fmt.Println("Failed to set nodeId:", err) + } else { + fmt.Println("nodeId set to:", nodeIDFlag) + } + } + if syncIntervalSet { + if err := config.WriteSettingToJSON("syncInterval", fmt.Sprintf("%d", syncIntervalFlag)); err != nil { + fmt.Println("Failed to set syncInterval:", err) + } else { + fmt.Println("syncInterval set to:", syncIntervalFlag) + } + } + if trafficFlushIntervalSet { + if err := config.WriteSettingToJSON("trafficFlushInterval", fmt.Sprintf("%d", trafficFlushIntervalFlag)); err != nil { + fmt.Println("Failed to set trafficFlushInterval:", err) + } else { + fmt.Println("trafficFlushInterval set to:", trafficFlushIntervalFlag) + } + } + } case "cert": err := settingCmd.Parse(os.Args[2:]) if err != nil {