mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-07 05:34:17 +00:00
feat: add node config and startup validation
This commit is contained in:
parent
735e1d8ef0
commit
36826706ec
3 changed files with 256 additions and 0 deletions
107
config/config.go
107
config/config.go
|
|
@ -10,6 +10,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
|
@ -31,6 +32,20 @@ const (
|
|||
Error LogLevel = "error"
|
||||
)
|
||||
|
||||
type NodeRole string
|
||||
|
||||
const (
|
||||
NodeRoleMaster NodeRole = "master"
|
||||
NodeRoleWorker NodeRole = "worker"
|
||||
)
|
||||
|
||||
type NodeConfig struct {
|
||||
Role NodeRole
|
||||
NodeID string
|
||||
SyncIntervalSeconds int
|
||||
TrafficFlushSeconds int
|
||||
}
|
||||
|
||||
// GetVersion returns the version string of the 3x-ui application.
|
||||
func GetVersion() string {
|
||||
return strings.TrimSpace(version)
|
||||
|
|
@ -106,6 +121,14 @@ func GetSettingPath() string {
|
|||
return fmt.Sprintf("%s/%s.json", GetDBFolderPath(), GetName())
|
||||
}
|
||||
|
||||
func GetSharedCachePath() string {
|
||||
return filepath.Join(GetDBFolderPath(), "shared-cache.json")
|
||||
}
|
||||
|
||||
func GetTrafficPendingPath() string {
|
||||
return filepath.Join(GetDBFolderPath(), "traffic-pending.json")
|
||||
}
|
||||
|
||||
// GetLogFolder returns the path to the log folder based on environment variables or platform defaults.
|
||||
func GetLogFolder() string {
|
||||
logFolderPath := os.Getenv("XUI_LOG_FOLDER")
|
||||
|
|
@ -125,6 +148,14 @@ var settingGroupAliases = map[string][]string{
|
|||
"dbUser": {"databaseConnection", "other"},
|
||||
"dbPassword": {"databaseConnection", "other"},
|
||||
"dbName": {"databaseConnection", "other"},
|
||||
"nodeRole": {"other"},
|
||||
"nodeId": {"other"},
|
||||
"syncInterval": {
|
||||
"other",
|
||||
},
|
||||
"trafficFlushInterval": {
|
||||
"other",
|
||||
},
|
||||
}
|
||||
|
||||
func readGroupedString(settings map[string]any, key string) string {
|
||||
|
|
@ -143,6 +174,37 @@ func readGroupedString(settings map[string]any, key string) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func readGroupedInt(settings map[string]any, key string, fallback int) int {
|
||||
readInt := func(value any) (int, bool) {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return int(v), true
|
||||
case int:
|
||||
return v, true
|
||||
case string:
|
||||
i, err := strconv.Atoi(v)
|
||||
if err == nil {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if groups, ok := settingGroupAliases[key]; ok {
|
||||
for _, groupName := range groups {
|
||||
if group, ok := settings[groupName].(map[string]any); ok {
|
||||
if value, ok := readInt(group[key]); ok {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := readInt(settings[key]); ok {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func settingsLayoutMeta() map[string]any {
|
||||
return map[string]any{
|
||||
"layout": "按模块-用途来归类",
|
||||
|
|
@ -250,6 +312,51 @@ func GetDBConfigFromJSON() DBConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetNodeConfigFromJSON() NodeConfig {
|
||||
data, err := os.ReadFile(GetSettingPath())
|
||||
if err != nil {
|
||||
return NodeConfig{Role: NodeRoleMaster, SyncIntervalSeconds: 30, TrafficFlushSeconds: 10}
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
return NodeConfig{Role: NodeRoleMaster, SyncIntervalSeconds: 30, TrafficFlushSeconds: 10}
|
||||
}
|
||||
|
||||
role := readGroupedString(settings, "nodeRole")
|
||||
if role == "" {
|
||||
role = string(NodeRoleMaster)
|
||||
}
|
||||
|
||||
return NodeConfig{
|
||||
Role: NodeRole(role),
|
||||
NodeID: readGroupedString(settings, "nodeId"),
|
||||
SyncIntervalSeconds: readGroupedInt(settings, "syncInterval", 30),
|
||||
TrafficFlushSeconds: readGroupedInt(settings, "trafficFlushInterval", 10),
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateNodeConfig(nodeCfg NodeConfig, dbCfg DBConfig) error {
|
||||
switch nodeCfg.Role {
|
||||
case NodeRoleMaster, NodeRoleWorker:
|
||||
default:
|
||||
return fmt.Errorf("invalid nodeRole %q", nodeCfg.Role)
|
||||
}
|
||||
if nodeCfg.Role == NodeRoleWorker && nodeCfg.NodeID == "" {
|
||||
return fmt.Errorf("worker mode requires nodeId")
|
||||
}
|
||||
if nodeCfg.Role == NodeRoleWorker && dbCfg.Type != "mariadb" {
|
||||
return fmt.Errorf("worker mode requires mariadb")
|
||||
}
|
||||
if nodeCfg.SyncIntervalSeconds <= 0 {
|
||||
return fmt.Errorf("syncInterval must be positive")
|
||||
}
|
||||
if nodeCfg.TrafficFlushSeconds <= 0 {
|
||||
return fmt.Errorf("trafficFlushInterval must be positive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteSettingToJSON writes a single setting key to the JSON config file.
|
||||
// It reads the existing file, updates the value, and writes back.
|
||||
func WriteSettingToJSON(key, value string) error {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package config
|
|||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
|
@ -132,6 +133,72 @@ func TestGetLogFolderCustom(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func writeTestSettingsFile(t *testing.T, settings map[string]any) {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalIndent error: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(GetSettingPath(), data, 0644); err != nil {
|
||||
t.Fatalf("WriteFile error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNodeConfigFromJSONDefaults(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("XUI_DB_FOLDER", tmpDir)
|
||||
writeTestSettingsFile(t, map[string]any{})
|
||||
|
||||
cfg := GetNodeConfigFromJSON()
|
||||
if cfg.Role != NodeRoleMaster {
|
||||
t.Fatalf("expected default role %q, got %q", NodeRoleMaster, cfg.Role)
|
||||
}
|
||||
if cfg.NodeID != "" {
|
||||
t.Fatalf("expected empty default node id, got %q", cfg.NodeID)
|
||||
}
|
||||
if cfg.SyncIntervalSeconds != 30 {
|
||||
t.Fatalf("expected default sync interval 30, got %d", cfg.SyncIntervalSeconds)
|
||||
}
|
||||
if cfg.TrafficFlushSeconds != 10 {
|
||||
t.Fatalf("expected default traffic flush interval 10, got %d", cfg.TrafficFlushSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNodeConfigWorkerRequiresNodeID(t *testing.T) {
|
||||
err := ValidateNodeConfig(NodeConfig{
|
||||
Role: NodeRoleWorker,
|
||||
SyncIntervalSeconds: 30,
|
||||
TrafficFlushSeconds: 10,
|
||||
}, DBConfig{Type: "mariadb"})
|
||||
if err == nil {
|
||||
t.Fatal("expected worker without node id to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNodeConfigWorkerRequiresMariaDB(t *testing.T) {
|
||||
err := ValidateNodeConfig(NodeConfig{
|
||||
Role: NodeRoleWorker,
|
||||
NodeID: "worker-1",
|
||||
SyncIntervalSeconds: 30,
|
||||
TrafficFlushSeconds: 10,
|
||||
}, DBConfig{Type: "sqlite"})
|
||||
if err == nil {
|
||||
t.Fatal("expected worker on sqlite to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedRuntimeFilePaths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("XUI_DB_FOLDER", tmpDir)
|
||||
|
||||
if got := GetSharedCachePath(); got != filepath.Join(tmpDir, "shared-cache.json") {
|
||||
t.Fatalf("unexpected shared cache path: %s", got)
|
||||
}
|
||||
if got := GetTrafficPendingPath(); got != filepath.Join(tmpDir, "traffic-pending.json") {
|
||||
t.Fatalf("unexpected traffic pending path: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDBConfigFromJSONSupportsModulePurposeLayout(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("XUI_DB_FOLDER", tmpDir)
|
||||
|
|
|
|||
82
main.go
82
main.go
|
|
@ -29,6 +29,12 @@ import (
|
|||
func runWebServer() {
|
||||
log.Printf("Starting %v %v", config.GetName(), config.GetVersion())
|
||||
|
||||
dbCfg := config.GetDBConfigFromJSON()
|
||||
nodeCfg := config.GetNodeConfigFromJSON()
|
||||
if err := config.ValidateNodeConfig(nodeCfg, dbCfg); err != nil {
|
||||
log.Fatalf("invalid node configuration: %v", err)
|
||||
}
|
||||
|
||||
switch config.GetLogLevel() {
|
||||
case config.Debug:
|
||||
logger.InitLogger(logging.DEBUG)
|
||||
|
|
@ -199,6 +205,11 @@ func showSetting(show bool) {
|
|||
fmt.Println("port:", port)
|
||||
fmt.Println("webDomain:", webDomain)
|
||||
fmt.Println("webBasePath:", webBasePath)
|
||||
nodeCfg := config.GetNodeConfigFromJSON()
|
||||
fmt.Println("nodeRole:", nodeCfg.Role)
|
||||
fmt.Println("nodeId:", nodeCfg.NodeID)
|
||||
fmt.Println("syncInterval:", nodeCfg.SyncIntervalSeconds)
|
||||
fmt.Println("trafficFlushInterval:", nodeCfg.TrafficFlushSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -525,6 +536,10 @@ func main() {
|
|||
var dbPassword string
|
||||
var dbName string
|
||||
var showDbType bool
|
||||
var nodeRoleFlag string
|
||||
var nodeIDFlag string
|
||||
var syncIntervalFlag int
|
||||
var trafficFlushIntervalFlag int
|
||||
settingCmd.StringVar(&dbTypeFlag, "dbType", "", "Set database type (sqlite or mariadb)")
|
||||
settingCmd.StringVar(&dbHost, "dbHost", "", "Set MariaDB host")
|
||||
settingCmd.StringVar(&dbPort, "dbPort", "", "Set MariaDB port")
|
||||
|
|
@ -532,6 +547,10 @@ func main() {
|
|||
settingCmd.StringVar(&dbPassword, "dbPassword", "", "Set MariaDB password")
|
||||
settingCmd.StringVar(&dbName, "dbName", "", "Set MariaDB database name")
|
||||
settingCmd.BoolVar(&showDbType, "showDbType", false, "Print current database type and exit")
|
||||
settingCmd.StringVar(&nodeRoleFlag, "nodeRole", "", "Set node role (master or worker)")
|
||||
settingCmd.StringVar(&nodeIDFlag, "nodeId", "", "Set node identifier")
|
||||
settingCmd.IntVar(&syncIntervalFlag, "syncInterval", 0, "Set shared sync interval in seconds")
|
||||
settingCmd.IntVar(&trafficFlushIntervalFlag, "trafficFlushInterval", 0, "Set traffic flush interval in seconds")
|
||||
|
||||
migrateDbCmd := flag.NewFlagSet("migrate-db", flag.ExitOnError)
|
||||
var migrateDirection string
|
||||
|
|
@ -582,6 +601,22 @@ func main() {
|
|||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
nodeRoleSet := false
|
||||
nodeIDSet := false
|
||||
syncIntervalSet := false
|
||||
trafficFlushIntervalSet := false
|
||||
settingCmd.Visit(func(f *flag.Flag) {
|
||||
switch f.Name {
|
||||
case "nodeRole":
|
||||
nodeRoleSet = true
|
||||
case "nodeId":
|
||||
nodeIDSet = true
|
||||
case "syncInterval":
|
||||
syncIntervalSet = true
|
||||
case "trafficFlushInterval":
|
||||
trafficFlushIntervalSet = true
|
||||
}
|
||||
})
|
||||
if showDbType {
|
||||
fmt.Println(config.GetDBTypeFromJSON())
|
||||
return
|
||||
|
|
@ -648,6 +683,53 @@ func main() {
|
|||
fmt.Println("dbName set to:", dbName)
|
||||
}
|
||||
}
|
||||
if nodeRoleSet || nodeIDSet || syncIntervalSet || trafficFlushIntervalSet {
|
||||
candidate := config.GetNodeConfigFromJSON()
|
||||
if nodeRoleSet {
|
||||
candidate.Role = config.NodeRole(nodeRoleFlag)
|
||||
}
|
||||
if nodeIDSet {
|
||||
candidate.NodeID = nodeIDFlag
|
||||
}
|
||||
if syncIntervalSet {
|
||||
candidate.SyncIntervalSeconds = syncIntervalFlag
|
||||
}
|
||||
if trafficFlushIntervalSet {
|
||||
candidate.TrafficFlushSeconds = trafficFlushIntervalFlag
|
||||
}
|
||||
if err := config.ValidateNodeConfig(candidate, config.GetDBConfigFromJSON()); err != nil {
|
||||
fmt.Println("Invalid node settings:", err)
|
||||
return
|
||||
}
|
||||
if nodeRoleSet {
|
||||
if err := config.WriteSettingToJSON("nodeRole", nodeRoleFlag); err != nil {
|
||||
fmt.Println("Failed to set nodeRole:", err)
|
||||
} else {
|
||||
fmt.Println("nodeRole set to:", nodeRoleFlag)
|
||||
}
|
||||
}
|
||||
if nodeIDSet {
|
||||
if err := config.WriteSettingToJSON("nodeId", nodeIDFlag); err != nil {
|
||||
fmt.Println("Failed to set nodeId:", err)
|
||||
} else {
|
||||
fmt.Println("nodeId set to:", nodeIDFlag)
|
||||
}
|
||||
}
|
||||
if syncIntervalSet {
|
||||
if err := config.WriteSettingToJSON("syncInterval", fmt.Sprintf("%d", syncIntervalFlag)); err != nil {
|
||||
fmt.Println("Failed to set syncInterval:", err)
|
||||
} else {
|
||||
fmt.Println("syncInterval set to:", syncIntervalFlag)
|
||||
}
|
||||
}
|
||||
if trafficFlushIntervalSet {
|
||||
if err := config.WriteSettingToJSON("trafficFlushInterval", fmt.Sprintf("%d", trafficFlushIntervalFlag)); err != nil {
|
||||
fmt.Println("Failed to set trafficFlushInterval:", err)
|
||||
} else {
|
||||
fmt.Println("trafficFlushInterval set to:", trafficFlushIntervalFlag)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "cert":
|
||||
err := settingCmd.Parse(os.Args[2:])
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Reference in a new issue