feat: add node config and startup validation

This commit is contained in:
Sora39831 2026-04-10 11:13:52 +08:00
parent 735e1d8ef0
commit 36826706ec
3 changed files with 256 additions and 0 deletions

View file

@ -10,6 +10,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
) )
@ -31,6 +32,20 @@ const (
Error LogLevel = "error" Error LogLevel = "error"
) )
type NodeRole string
const (
NodeRoleMaster NodeRole = "master"
NodeRoleWorker NodeRole = "worker"
)
type NodeConfig struct {
Role NodeRole
NodeID string
SyncIntervalSeconds int
TrafficFlushSeconds int
}
// GetVersion returns the version string of the 3x-ui application. // GetVersion returns the version string of the 3x-ui application.
func GetVersion() string { func GetVersion() string {
return strings.TrimSpace(version) return strings.TrimSpace(version)
@ -106,6 +121,14 @@ func GetSettingPath() string {
return fmt.Sprintf("%s/%s.json", GetDBFolderPath(), GetName()) return fmt.Sprintf("%s/%s.json", GetDBFolderPath(), GetName())
} }
func GetSharedCachePath() string {
return filepath.Join(GetDBFolderPath(), "shared-cache.json")
}
func GetTrafficPendingPath() string {
return filepath.Join(GetDBFolderPath(), "traffic-pending.json")
}
// GetLogFolder returns the path to the log folder based on environment variables or platform defaults. // GetLogFolder returns the path to the log folder based on environment variables or platform defaults.
func GetLogFolder() string { func GetLogFolder() string {
logFolderPath := os.Getenv("XUI_LOG_FOLDER") logFolderPath := os.Getenv("XUI_LOG_FOLDER")
@ -125,6 +148,14 @@ var settingGroupAliases = map[string][]string{
"dbUser": {"databaseConnection", "other"}, "dbUser": {"databaseConnection", "other"},
"dbPassword": {"databaseConnection", "other"}, "dbPassword": {"databaseConnection", "other"},
"dbName": {"databaseConnection", "other"}, "dbName": {"databaseConnection", "other"},
"nodeRole": {"other"},
"nodeId": {"other"},
"syncInterval": {
"other",
},
"trafficFlushInterval": {
"other",
},
} }
func readGroupedString(settings map[string]any, key string) string { func readGroupedString(settings map[string]any, key string) string {
@ -143,6 +174,37 @@ func readGroupedString(settings map[string]any, key string) string {
return "" return ""
} }
func readGroupedInt(settings map[string]any, key string, fallback int) int {
readInt := func(value any) (int, bool) {
switch v := value.(type) {
case float64:
return int(v), true
case int:
return v, true
case string:
i, err := strconv.Atoi(v)
if err == nil {
return i, true
}
}
return 0, false
}
if groups, ok := settingGroupAliases[key]; ok {
for _, groupName := range groups {
if group, ok := settings[groupName].(map[string]any); ok {
if value, ok := readInt(group[key]); ok {
return value
}
}
}
}
if value, ok := readInt(settings[key]); ok {
return value
}
return fallback
}
func settingsLayoutMeta() map[string]any { func settingsLayoutMeta() map[string]any {
return map[string]any{ return map[string]any{
"layout": "按模块-用途来归类", "layout": "按模块-用途来归类",
@ -250,6 +312,51 @@ func GetDBConfigFromJSON() DBConfig {
} }
} }
func GetNodeConfigFromJSON() NodeConfig {
data, err := os.ReadFile(GetSettingPath())
if err != nil {
return NodeConfig{Role: NodeRoleMaster, SyncIntervalSeconds: 30, TrafficFlushSeconds: 10}
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
return NodeConfig{Role: NodeRoleMaster, SyncIntervalSeconds: 30, TrafficFlushSeconds: 10}
}
role := readGroupedString(settings, "nodeRole")
if role == "" {
role = string(NodeRoleMaster)
}
return NodeConfig{
Role: NodeRole(role),
NodeID: readGroupedString(settings, "nodeId"),
SyncIntervalSeconds: readGroupedInt(settings, "syncInterval", 30),
TrafficFlushSeconds: readGroupedInt(settings, "trafficFlushInterval", 10),
}
}
func ValidateNodeConfig(nodeCfg NodeConfig, dbCfg DBConfig) error {
switch nodeCfg.Role {
case NodeRoleMaster, NodeRoleWorker:
default:
return fmt.Errorf("invalid nodeRole %q", nodeCfg.Role)
}
if nodeCfg.Role == NodeRoleWorker && nodeCfg.NodeID == "" {
return fmt.Errorf("worker mode requires nodeId")
}
if nodeCfg.Role == NodeRoleWorker && dbCfg.Type != "mariadb" {
return fmt.Errorf("worker mode requires mariadb")
}
if nodeCfg.SyncIntervalSeconds <= 0 {
return fmt.Errorf("syncInterval must be positive")
}
if nodeCfg.TrafficFlushSeconds <= 0 {
return fmt.Errorf("trafficFlushInterval must be positive")
}
return nil
}
// WriteSettingToJSON writes a single setting key to the JSON config file. // WriteSettingToJSON writes a single setting key to the JSON config file.
// It reads the existing file, updates the value, and writes back. // It reads the existing file, updates the value, and writes back.
func WriteSettingToJSON(key, value string) error { func WriteSettingToJSON(key, value string) error {

View file

@ -3,6 +3,7 @@ package config
import ( import (
"encoding/json" "encoding/json"
"os" "os"
"path/filepath"
"strings" "strings"
"testing" "testing"
) )
@ -132,6 +133,72 @@ func TestGetLogFolderCustom(t *testing.T) {
} }
} }
func writeTestSettingsFile(t *testing.T, settings map[string]any) {
t.Helper()
data, err := json.MarshalIndent(settings, "", " ")
if err != nil {
t.Fatalf("MarshalIndent error: %v", err)
}
if err := os.WriteFile(GetSettingPath(), data, 0644); err != nil {
t.Fatalf("WriteFile error: %v", err)
}
}
func TestGetNodeConfigFromJSONDefaults(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XUI_DB_FOLDER", tmpDir)
writeTestSettingsFile(t, map[string]any{})
cfg := GetNodeConfigFromJSON()
if cfg.Role != NodeRoleMaster {
t.Fatalf("expected default role %q, got %q", NodeRoleMaster, cfg.Role)
}
if cfg.NodeID != "" {
t.Fatalf("expected empty default node id, got %q", cfg.NodeID)
}
if cfg.SyncIntervalSeconds != 30 {
t.Fatalf("expected default sync interval 30, got %d", cfg.SyncIntervalSeconds)
}
if cfg.TrafficFlushSeconds != 10 {
t.Fatalf("expected default traffic flush interval 10, got %d", cfg.TrafficFlushSeconds)
}
}
func TestValidateNodeConfigWorkerRequiresNodeID(t *testing.T) {
err := ValidateNodeConfig(NodeConfig{
Role: NodeRoleWorker,
SyncIntervalSeconds: 30,
TrafficFlushSeconds: 10,
}, DBConfig{Type: "mariadb"})
if err == nil {
t.Fatal("expected worker without node id to fail validation")
}
}
func TestValidateNodeConfigWorkerRequiresMariaDB(t *testing.T) {
err := ValidateNodeConfig(NodeConfig{
Role: NodeRoleWorker,
NodeID: "worker-1",
SyncIntervalSeconds: 30,
TrafficFlushSeconds: 10,
}, DBConfig{Type: "sqlite"})
if err == nil {
t.Fatal("expected worker on sqlite to fail validation")
}
}
func TestSharedRuntimeFilePaths(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XUI_DB_FOLDER", tmpDir)
if got := GetSharedCachePath(); got != filepath.Join(tmpDir, "shared-cache.json") {
t.Fatalf("unexpected shared cache path: %s", got)
}
if got := GetTrafficPendingPath(); got != filepath.Join(tmpDir, "traffic-pending.json") {
t.Fatalf("unexpected traffic pending path: %s", got)
}
}
func TestGetDBConfigFromJSONSupportsModulePurposeLayout(t *testing.T) { func TestGetDBConfigFromJSONSupportsModulePurposeLayout(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
t.Setenv("XUI_DB_FOLDER", tmpDir) t.Setenv("XUI_DB_FOLDER", tmpDir)

82
main.go
View file

@ -29,6 +29,12 @@ import (
func runWebServer() { func runWebServer() {
log.Printf("Starting %v %v", config.GetName(), config.GetVersion()) log.Printf("Starting %v %v", config.GetName(), config.GetVersion())
dbCfg := config.GetDBConfigFromJSON()
nodeCfg := config.GetNodeConfigFromJSON()
if err := config.ValidateNodeConfig(nodeCfg, dbCfg); err != nil {
log.Fatalf("invalid node configuration: %v", err)
}
switch config.GetLogLevel() { switch config.GetLogLevel() {
case config.Debug: case config.Debug:
logger.InitLogger(logging.DEBUG) logger.InitLogger(logging.DEBUG)
@ -199,6 +205,11 @@ func showSetting(show bool) {
fmt.Println("port:", port) fmt.Println("port:", port)
fmt.Println("webDomain:", webDomain) fmt.Println("webDomain:", webDomain)
fmt.Println("webBasePath:", webBasePath) fmt.Println("webBasePath:", webBasePath)
nodeCfg := config.GetNodeConfigFromJSON()
fmt.Println("nodeRole:", nodeCfg.Role)
fmt.Println("nodeId:", nodeCfg.NodeID)
fmt.Println("syncInterval:", nodeCfg.SyncIntervalSeconds)
fmt.Println("trafficFlushInterval:", nodeCfg.TrafficFlushSeconds)
} }
} }
@ -525,6 +536,10 @@ func main() {
var dbPassword string var dbPassword string
var dbName string var dbName string
var showDbType bool var showDbType bool
var nodeRoleFlag string
var nodeIDFlag string
var syncIntervalFlag int
var trafficFlushIntervalFlag int
settingCmd.StringVar(&dbTypeFlag, "dbType", "", "Set database type (sqlite or mariadb)") settingCmd.StringVar(&dbTypeFlag, "dbType", "", "Set database type (sqlite or mariadb)")
settingCmd.StringVar(&dbHost, "dbHost", "", "Set MariaDB host") settingCmd.StringVar(&dbHost, "dbHost", "", "Set MariaDB host")
settingCmd.StringVar(&dbPort, "dbPort", "", "Set MariaDB port") settingCmd.StringVar(&dbPort, "dbPort", "", "Set MariaDB port")
@ -532,6 +547,10 @@ func main() {
settingCmd.StringVar(&dbPassword, "dbPassword", "", "Set MariaDB password") settingCmd.StringVar(&dbPassword, "dbPassword", "", "Set MariaDB password")
settingCmd.StringVar(&dbName, "dbName", "", "Set MariaDB database name") settingCmd.StringVar(&dbName, "dbName", "", "Set MariaDB database name")
settingCmd.BoolVar(&showDbType, "showDbType", false, "Print current database type and exit") settingCmd.BoolVar(&showDbType, "showDbType", false, "Print current database type and exit")
settingCmd.StringVar(&nodeRoleFlag, "nodeRole", "", "Set node role (master or worker)")
settingCmd.StringVar(&nodeIDFlag, "nodeId", "", "Set node identifier")
settingCmd.IntVar(&syncIntervalFlag, "syncInterval", 0, "Set shared sync interval in seconds")
settingCmd.IntVar(&trafficFlushIntervalFlag, "trafficFlushInterval", 0, "Set traffic flush interval in seconds")
migrateDbCmd := flag.NewFlagSet("migrate-db", flag.ExitOnError) migrateDbCmd := flag.NewFlagSet("migrate-db", flag.ExitOnError)
var migrateDirection string var migrateDirection string
@ -582,6 +601,22 @@ func main() {
fmt.Println(err) fmt.Println(err)
return return
} }
nodeRoleSet := false
nodeIDSet := false
syncIntervalSet := false
trafficFlushIntervalSet := false
settingCmd.Visit(func(f *flag.Flag) {
switch f.Name {
case "nodeRole":
nodeRoleSet = true
case "nodeId":
nodeIDSet = true
case "syncInterval":
syncIntervalSet = true
case "trafficFlushInterval":
trafficFlushIntervalSet = true
}
})
if showDbType { if showDbType {
fmt.Println(config.GetDBTypeFromJSON()) fmt.Println(config.GetDBTypeFromJSON())
return return
@ -648,6 +683,53 @@ func main() {
fmt.Println("dbName set to:", dbName) fmt.Println("dbName set to:", dbName)
} }
} }
if nodeRoleSet || nodeIDSet || syncIntervalSet || trafficFlushIntervalSet {
candidate := config.GetNodeConfigFromJSON()
if nodeRoleSet {
candidate.Role = config.NodeRole(nodeRoleFlag)
}
if nodeIDSet {
candidate.NodeID = nodeIDFlag
}
if syncIntervalSet {
candidate.SyncIntervalSeconds = syncIntervalFlag
}
if trafficFlushIntervalSet {
candidate.TrafficFlushSeconds = trafficFlushIntervalFlag
}
if err := config.ValidateNodeConfig(candidate, config.GetDBConfigFromJSON()); err != nil {
fmt.Println("Invalid node settings:", err)
return
}
if nodeRoleSet {
if err := config.WriteSettingToJSON("nodeRole", nodeRoleFlag); err != nil {
fmt.Println("Failed to set nodeRole:", err)
} else {
fmt.Println("nodeRole set to:", nodeRoleFlag)
}
}
if nodeIDSet {
if err := config.WriteSettingToJSON("nodeId", nodeIDFlag); err != nil {
fmt.Println("Failed to set nodeId:", err)
} else {
fmt.Println("nodeId set to:", nodeIDFlag)
}
}
if syncIntervalSet {
if err := config.WriteSettingToJSON("syncInterval", fmt.Sprintf("%d", syncIntervalFlag)); err != nil {
fmt.Println("Failed to set syncInterval:", err)
} else {
fmt.Println("syncInterval set to:", syncIntervalFlag)
}
}
if trafficFlushIntervalSet {
if err := config.WriteSettingToJSON("trafficFlushInterval", fmt.Sprintf("%d", trafficFlushIntervalFlag)); err != nil {
fmt.Println("Failed to set trafficFlushInterval:", err)
} else {
fmt.Println("trafficFlushInterval set to:", trafficFlushIntervalFlag)
}
}
}
case "cert": case "cert":
err := settingCmd.Parse(os.Args[2:]) err := settingCmd.Parse(os.Args[2:])
if err != nil { if err != nil {