feat: add node config and startup validation

This commit is contained in:
Sora39831 2026-04-10 11:13:52 +08:00
parent 735e1d8ef0
commit 36826706ec
3 changed files with 256 additions and 0 deletions

View file

@ -10,6 +10,7 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
)
@ -31,6 +32,20 @@ const (
Error LogLevel = "error"
)
type NodeRole string
const (
NodeRoleMaster NodeRole = "master"
NodeRoleWorker NodeRole = "worker"
)
type NodeConfig struct {
Role NodeRole
NodeID string
SyncIntervalSeconds int
TrafficFlushSeconds int
}
// GetVersion returns the version string of the 3x-ui application.
func GetVersion() string {
return strings.TrimSpace(version)
@ -106,6 +121,14 @@ func GetSettingPath() string {
return fmt.Sprintf("%s/%s.json", GetDBFolderPath(), GetName())
}
func GetSharedCachePath() string {
return filepath.Join(GetDBFolderPath(), "shared-cache.json")
}
func GetTrafficPendingPath() string {
return filepath.Join(GetDBFolderPath(), "traffic-pending.json")
}
// GetLogFolder returns the path to the log folder based on environment variables or platform defaults.
func GetLogFolder() string {
logFolderPath := os.Getenv("XUI_LOG_FOLDER")
@ -125,6 +148,14 @@ var settingGroupAliases = map[string][]string{
"dbUser": {"databaseConnection", "other"},
"dbPassword": {"databaseConnection", "other"},
"dbName": {"databaseConnection", "other"},
"nodeRole": {"other"},
"nodeId": {"other"},
"syncInterval": {
"other",
},
"trafficFlushInterval": {
"other",
},
}
func readGroupedString(settings map[string]any, key string) string {
@ -143,6 +174,37 @@ func readGroupedString(settings map[string]any, key string) string {
return ""
}
func readGroupedInt(settings map[string]any, key string, fallback int) int {
readInt := func(value any) (int, bool) {
switch v := value.(type) {
case float64:
return int(v), true
case int:
return v, true
case string:
i, err := strconv.Atoi(v)
if err == nil {
return i, true
}
}
return 0, false
}
if groups, ok := settingGroupAliases[key]; ok {
for _, groupName := range groups {
if group, ok := settings[groupName].(map[string]any); ok {
if value, ok := readInt(group[key]); ok {
return value
}
}
}
}
if value, ok := readInt(settings[key]); ok {
return value
}
return fallback
}
func settingsLayoutMeta() map[string]any {
return map[string]any{
"layout": "按模块-用途来归类",
@ -250,6 +312,51 @@ func GetDBConfigFromJSON() DBConfig {
}
}
func GetNodeConfigFromJSON() NodeConfig {
data, err := os.ReadFile(GetSettingPath())
if err != nil {
return NodeConfig{Role: NodeRoleMaster, SyncIntervalSeconds: 30, TrafficFlushSeconds: 10}
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
return NodeConfig{Role: NodeRoleMaster, SyncIntervalSeconds: 30, TrafficFlushSeconds: 10}
}
role := readGroupedString(settings, "nodeRole")
if role == "" {
role = string(NodeRoleMaster)
}
return NodeConfig{
Role: NodeRole(role),
NodeID: readGroupedString(settings, "nodeId"),
SyncIntervalSeconds: readGroupedInt(settings, "syncInterval", 30),
TrafficFlushSeconds: readGroupedInt(settings, "trafficFlushInterval", 10),
}
}
func ValidateNodeConfig(nodeCfg NodeConfig, dbCfg DBConfig) error {
switch nodeCfg.Role {
case NodeRoleMaster, NodeRoleWorker:
default:
return fmt.Errorf("invalid nodeRole %q", nodeCfg.Role)
}
if nodeCfg.Role == NodeRoleWorker && nodeCfg.NodeID == "" {
return fmt.Errorf("worker mode requires nodeId")
}
if nodeCfg.Role == NodeRoleWorker && dbCfg.Type != "mariadb" {
return fmt.Errorf("worker mode requires mariadb")
}
if nodeCfg.SyncIntervalSeconds <= 0 {
return fmt.Errorf("syncInterval must be positive")
}
if nodeCfg.TrafficFlushSeconds <= 0 {
return fmt.Errorf("trafficFlushInterval must be positive")
}
return nil
}
// WriteSettingToJSON writes a single setting key to the JSON config file.
// It reads the existing file, updates the value, and writes back.
func WriteSettingToJSON(key, value string) error {

View file

@ -3,6 +3,7 @@ package config
import (
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
)
@ -132,6 +133,72 @@ func TestGetLogFolderCustom(t *testing.T) {
}
}
func writeTestSettingsFile(t *testing.T, settings map[string]any) {
t.Helper()
data, err := json.MarshalIndent(settings, "", " ")
if err != nil {
t.Fatalf("MarshalIndent error: %v", err)
}
if err := os.WriteFile(GetSettingPath(), data, 0644); err != nil {
t.Fatalf("WriteFile error: %v", err)
}
}
func TestGetNodeConfigFromJSONDefaults(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XUI_DB_FOLDER", tmpDir)
writeTestSettingsFile(t, map[string]any{})
cfg := GetNodeConfigFromJSON()
if cfg.Role != NodeRoleMaster {
t.Fatalf("expected default role %q, got %q", NodeRoleMaster, cfg.Role)
}
if cfg.NodeID != "" {
t.Fatalf("expected empty default node id, got %q", cfg.NodeID)
}
if cfg.SyncIntervalSeconds != 30 {
t.Fatalf("expected default sync interval 30, got %d", cfg.SyncIntervalSeconds)
}
if cfg.TrafficFlushSeconds != 10 {
t.Fatalf("expected default traffic flush interval 10, got %d", cfg.TrafficFlushSeconds)
}
}
func TestValidateNodeConfigWorkerRequiresNodeID(t *testing.T) {
err := ValidateNodeConfig(NodeConfig{
Role: NodeRoleWorker,
SyncIntervalSeconds: 30,
TrafficFlushSeconds: 10,
}, DBConfig{Type: "mariadb"})
if err == nil {
t.Fatal("expected worker without node id to fail validation")
}
}
func TestValidateNodeConfigWorkerRequiresMariaDB(t *testing.T) {
err := ValidateNodeConfig(NodeConfig{
Role: NodeRoleWorker,
NodeID: "worker-1",
SyncIntervalSeconds: 30,
TrafficFlushSeconds: 10,
}, DBConfig{Type: "sqlite"})
if err == nil {
t.Fatal("expected worker on sqlite to fail validation")
}
}
func TestSharedRuntimeFilePaths(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XUI_DB_FOLDER", tmpDir)
if got := GetSharedCachePath(); got != filepath.Join(tmpDir, "shared-cache.json") {
t.Fatalf("unexpected shared cache path: %s", got)
}
if got := GetTrafficPendingPath(); got != filepath.Join(tmpDir, "traffic-pending.json") {
t.Fatalf("unexpected traffic pending path: %s", got)
}
}
func TestGetDBConfigFromJSONSupportsModulePurposeLayout(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XUI_DB_FOLDER", tmpDir)

82
main.go
View file

@ -29,6 +29,12 @@ import (
func runWebServer() {
log.Printf("Starting %v %v", config.GetName(), config.GetVersion())
dbCfg := config.GetDBConfigFromJSON()
nodeCfg := config.GetNodeConfigFromJSON()
if err := config.ValidateNodeConfig(nodeCfg, dbCfg); err != nil {
log.Fatalf("invalid node configuration: %v", err)
}
switch config.GetLogLevel() {
case config.Debug:
logger.InitLogger(logging.DEBUG)
@ -199,6 +205,11 @@ func showSetting(show bool) {
fmt.Println("port:", port)
fmt.Println("webDomain:", webDomain)
fmt.Println("webBasePath:", webBasePath)
nodeCfg := config.GetNodeConfigFromJSON()
fmt.Println("nodeRole:", nodeCfg.Role)
fmt.Println("nodeId:", nodeCfg.NodeID)
fmt.Println("syncInterval:", nodeCfg.SyncIntervalSeconds)
fmt.Println("trafficFlushInterval:", nodeCfg.TrafficFlushSeconds)
}
}
@ -525,6 +536,10 @@ func main() {
var dbPassword string
var dbName string
var showDbType bool
var nodeRoleFlag string
var nodeIDFlag string
var syncIntervalFlag int
var trafficFlushIntervalFlag int
settingCmd.StringVar(&dbTypeFlag, "dbType", "", "Set database type (sqlite or mariadb)")
settingCmd.StringVar(&dbHost, "dbHost", "", "Set MariaDB host")
settingCmd.StringVar(&dbPort, "dbPort", "", "Set MariaDB port")
@ -532,6 +547,10 @@ func main() {
settingCmd.StringVar(&dbPassword, "dbPassword", "", "Set MariaDB password")
settingCmd.StringVar(&dbName, "dbName", "", "Set MariaDB database name")
settingCmd.BoolVar(&showDbType, "showDbType", false, "Print current database type and exit")
settingCmd.StringVar(&nodeRoleFlag, "nodeRole", "", "Set node role (master or worker)")
settingCmd.StringVar(&nodeIDFlag, "nodeId", "", "Set node identifier")
settingCmd.IntVar(&syncIntervalFlag, "syncInterval", 0, "Set shared sync interval in seconds")
settingCmd.IntVar(&trafficFlushIntervalFlag, "trafficFlushInterval", 0, "Set traffic flush interval in seconds")
migrateDbCmd := flag.NewFlagSet("migrate-db", flag.ExitOnError)
var migrateDirection string
@ -582,6 +601,22 @@ func main() {
fmt.Println(err)
return
}
nodeRoleSet := false
nodeIDSet := false
syncIntervalSet := false
trafficFlushIntervalSet := false
settingCmd.Visit(func(f *flag.Flag) {
switch f.Name {
case "nodeRole":
nodeRoleSet = true
case "nodeId":
nodeIDSet = true
case "syncInterval":
syncIntervalSet = true
case "trafficFlushInterval":
trafficFlushIntervalSet = true
}
})
if showDbType {
fmt.Println(config.GetDBTypeFromJSON())
return
@ -648,6 +683,53 @@ func main() {
fmt.Println("dbName set to:", dbName)
}
}
if nodeRoleSet || nodeIDSet || syncIntervalSet || trafficFlushIntervalSet {
candidate := config.GetNodeConfigFromJSON()
if nodeRoleSet {
candidate.Role = config.NodeRole(nodeRoleFlag)
}
if nodeIDSet {
candidate.NodeID = nodeIDFlag
}
if syncIntervalSet {
candidate.SyncIntervalSeconds = syncIntervalFlag
}
if trafficFlushIntervalSet {
candidate.TrafficFlushSeconds = trafficFlushIntervalFlag
}
if err := config.ValidateNodeConfig(candidate, config.GetDBConfigFromJSON()); err != nil {
fmt.Println("Invalid node settings:", err)
return
}
if nodeRoleSet {
if err := config.WriteSettingToJSON("nodeRole", nodeRoleFlag); err != nil {
fmt.Println("Failed to set nodeRole:", err)
} else {
fmt.Println("nodeRole set to:", nodeRoleFlag)
}
}
if nodeIDSet {
if err := config.WriteSettingToJSON("nodeId", nodeIDFlag); err != nil {
fmt.Println("Failed to set nodeId:", err)
} else {
fmt.Println("nodeId set to:", nodeIDFlag)
}
}
if syncIntervalSet {
if err := config.WriteSettingToJSON("syncInterval", fmt.Sprintf("%d", syncIntervalFlag)); err != nil {
fmt.Println("Failed to set syncInterval:", err)
} else {
fmt.Println("syncInterval set to:", syncIntervalFlag)
}
}
if trafficFlushIntervalSet {
if err := config.WriteSettingToJSON("trafficFlushInterval", fmt.Sprintf("%d", trafficFlushIntervalFlag)); err != nil {
fmt.Println("Failed to set trafficFlushInterval:", err)
} else {
fmt.Println("trafficFlushInterval set to:", trafficFlushIntervalFlag)
}
}
}
case "cert":
err := settingCmd.Parse(os.Args[2:])
if err != nil {