From 3cfa554786f438433e26199b7781564bfc293e3e Mon Sep 17 00:00:00 2001 From: Sora39831 <540587985@qq.com> Date: Fri, 10 Apr 2026 15:15:42 +0800 Subject: [PATCH] feat: add cache-backed worker sync and heartbeat loops --- web/service/node_cache.go | 46 +++++++ web/service/node_sync.go | 148 +++++++++++++++++++++ web/service/node_sync_test.go | 239 ++++++++++++++++++++++++++++++++++ web/service/xray.go | 217 +++++++++++++++++------------- web/web.go | 15 +++ 5 files changed, 575 insertions(+), 90 deletions(-) create mode 100644 web/service/node_cache.go create mode 100644 web/service/node_sync.go create mode 100644 web/service/node_sync_test.go diff --git a/web/service/node_cache.go b/web/service/node_cache.go new file mode 100644 index 00000000..4975458d --- /dev/null +++ b/web/service/node_cache.go @@ -0,0 +1,46 @@ +package service + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + + "github.com/mhsanaei/3x-ui/v2/database/model" +) + +type SharedAccountsSnapshot struct { + Version int64 `json:"version"` + Inbounds []*model.Inbound `json:"inbounds"` +} + +func LoadSharedAccountsSnapshot(path string) (*SharedAccountsSnapshot, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + snapshot := &SharedAccountsSnapshot{} + if err := json.Unmarshal(data, snapshot); err != nil { + return nil, err + } + if snapshot.Inbounds == nil { + snapshot.Inbounds = []*model.Inbound{} + } + return snapshot, nil +} + +func SaveSharedAccountsSnapshot(path string, snapshot *SharedAccountsSnapshot) error { + if snapshot == nil { + return errors.New("shared snapshot is nil") + } + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + + data, err := json.MarshalIndent(snapshot, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0644) +} diff --git a/web/service/node_sync.go b/web/service/node_sync.go new file mode 100644 index 00000000..5343fa4e --- /dev/null +++ b/web/service/node_sync.go @@ -0,0 +1,148 @@ +package service + +import ( + "context" + "errors" + "os" + "time" + + "github.com/mhsanaei/3x-ui/v2/config" + "github.com/mhsanaei/3x-ui/v2/database" + "github.com/mhsanaei/3x-ui/v2/database/model" +) + +type NodeSyncService struct { + xrayService XrayService + cachePath string + lastSeenVersion int64 + loadVersion func() (int64, error) + loadSnapshot func() (*SharedAccountsSnapshot, error) + applySnapshot func(*SharedAccountsSnapshot) error +} + +func NewNodeSyncService() *NodeSyncService { + svc := &NodeSyncService{ + cachePath: config.GetSharedCachePath(), + } + svc.loadVersion = func() (int64, error) { + return database.GetSharedAccountsVersion(database.GetDB()) + } + svc.loadSnapshot = func() (*SharedAccountsSnapshot, error) { + inbounds, err := svc.xrayService.inboundService.GetAllInbounds() + if err != nil { + return nil, err + } + return &SharedAccountsSnapshot{Inbounds: inbounds}, nil + } + svc.applySnapshot = svc.xrayService.ApplySharedSnapshot + return svc +} + +func (s *NodeSyncService) updateNodeState(version int64, syncErr error, didSync bool) { + nodeCfg := config.GetNodeConfigFromJSON() + now := time.Now().Unix() + state := &model.NodeState{} + if nodeCfg.NodeID != "" { + _ = database.GetDB().First(state, "node_id = ?", nodeCfg.NodeID).Error + } + state.NodeID = nodeCfg.NodeID + state.NodeRole = string(nodeCfg.Role) + state.LastHeartbeatAt = now + state.LastSeenVersion = version + if didSync { + state.LastSyncAt = now + } + if syncErr != nil { + state.LastError = syncErr.Error() + } else { + state.LastError = "" + } + _ = database.UpsertNodeState(database.GetDB(), state) +} + +func (s *NodeSyncService) BootstrapFromCache() error { + snapshot, err := LoadSharedAccountsSnapshot(s.cachePath) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + if snapshot == nil { + return errors.New("shared snapshot is nil") + } + if err := s.applySnapshot(snapshot); err != nil { + return err + } + s.lastSeenVersion = snapshot.Version + return nil +} + +func (s *NodeSyncService) SyncOnce() (bool, error) { + version, err := s.loadVersion() + if err != nil { + s.updateNodeState(s.lastSeenVersion, err, false) + return false, err + } + if version == s.lastSeenVersion { + s.updateNodeState(version, nil, false) + return false, nil + } + + snapshot, err := s.loadSnapshot() + if err != nil { + s.updateNodeState(s.lastSeenVersion, err, false) + return false, err + } + if snapshot == nil { + err = errors.New("shared snapshot is nil") + s.updateNodeState(s.lastSeenVersion, err, false) + return false, err + } + + snapshot.Version = version + if err := SaveSharedAccountsSnapshot(s.cachePath, snapshot); err != nil { + s.updateNodeState(s.lastSeenVersion, err, false) + return false, err + } + if err := s.applySnapshot(snapshot); err != nil { + s.updateNodeState(s.lastSeenVersion, err, false) + return false, err + } + + s.lastSeenVersion = version + s.updateNodeState(version, nil, true) + return true, nil +} + +func (s *NodeSyncService) Run(ctx context.Context, interval time.Duration) { + _ = s.BootstrapFromCache() + _, _ = s.SyncOnce() + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + _, _ = s.SyncOnce() + } + } +} + +func (s *NodeSyncService) RunHeartbeatLoop(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + version, _ := database.GetSharedAccountsVersion(database.GetDB()) + s.updateNodeState(version, nil, false) + } + } +} diff --git a/web/service/node_sync_test.go b/web/service/node_sync_test.go new file mode 100644 index 00000000..33995d92 --- /dev/null +++ b/web/service/node_sync_test.go @@ -0,0 +1,239 @@ +package service + +import ( + "path/filepath" + "testing" + + "github.com/mhsanaei/3x-ui/v2/database" + "github.com/mhsanaei/3x-ui/v2/database/model" +) + +func writeNodeSyncSettings(t *testing.T, nodeID string) { + t.Helper() + writeNodeGuardSettings(t, map[string]any{ + "dbType": "mariadb", + "nodeRole": "worker", + "nodeId": nodeID, + }) +} + +func loadNodeState(t *testing.T, nodeID string) *model.NodeState { + t.Helper() + state := &model.NodeState{} + if err := database.GetDB().First(state, "node_id = ?", nodeID).Error; err != nil { + t.Fatalf("load node state error: %v", err) + } + return state +} + +func TestLoadAndSaveSharedAccountsSnapshot(t *testing.T) { + setupTestDB(t) + + cachePath := filepath.Join(t.TempDir(), "shared-cache.json") + snapshot := &SharedAccountsSnapshot{ + Version: 7, + Inbounds: []*model.Inbound{ + { + Id: 11, + Enable: true, + Port: 443, + Protocol: model.VLESS, + Settings: `{"clients":[{"id":"u-1","email":"alice@example.com"}]}`, + }, + }, + } + + if err := SaveSharedAccountsSnapshot(cachePath, snapshot); err != nil { + t.Fatalf("SaveSharedAccountsSnapshot error: %v", err) + } + + loaded, err := LoadSharedAccountsSnapshot(cachePath) + if err != nil { + t.Fatalf("LoadSharedAccountsSnapshot error: %v", err) + } + if loaded.Version != snapshot.Version { + t.Fatalf("expected version %d, got %d", snapshot.Version, loaded.Version) + } + if len(loaded.Inbounds) != 1 || loaded.Inbounds[0].Tag != snapshot.Inbounds[0].Tag { + t.Fatalf("expected one inbound to round-trip") + } +} + +func TestSyncOnceSkipsApplyWhenVersionUnchanged(t *testing.T) { + setupTestDB(t) + writeNodeSyncSettings(t, "worker-skip") + + cachePath := filepath.Join(t.TempDir(), "shared-cache.json") + applyCalled := false + loadSnapshotCalled := false + syncSvc := &NodeSyncService{ + cachePath: cachePath, + lastSeenVersion: 9, + loadVersion: func() (int64, error) { return 9, nil }, + loadSnapshot: func() (*SharedAccountsSnapshot, error) { loadSnapshotCalled = true; return nil, nil }, + applySnapshot: func(*SharedAccountsSnapshot) error { applyCalled = true; return nil }, + } + + didSync, err := syncSvc.SyncOnce() + if err != nil { + t.Fatalf("SyncOnce error: %v", err) + } + if didSync { + t.Fatal("expected unchanged version to skip sync") + } + if loadSnapshotCalled { + t.Fatal("loadSnapshot should not be called when version is unchanged") + } + if applyCalled { + t.Fatal("applySnapshot should not be called when version is unchanged") + } + + state := loadNodeState(t, "worker-skip") + if state.LastSeenVersion != 9 { + t.Fatalf("expected last seen version 9, got %d", state.LastSeenVersion) + } + if state.LastSyncAt != 0 { + t.Fatalf("expected LastSyncAt to remain 0, got %d", state.LastSyncAt) + } + if state.LastHeartbeatAt == 0 { + t.Fatal("expected heartbeat timestamp to be recorded") + } +} + +func TestSyncOnceRefreshesCacheAndAppliesSnapshot(t *testing.T) { + setupTestDB(t) + writeNodeSyncSettings(t, "worker-refresh") + + cachePath := filepath.Join(t.TempDir(), "shared-cache.json") + wantSnapshot := &SharedAccountsSnapshot{ + Version: 12, + Inbounds: []*model.Inbound{ + { + Id: 100, + Enable: true, + Port: 8443, + Protocol: model.VLESS, + Settings: `{"clients":[]}`, + StreamSettings: `{"network":"tcp","tlsSettings":{"settings":{"allowInsecure":true}}}`, + Tag: "in-100", + }, + }, + } + + applyCalls := 0 + syncSvc := &NodeSyncService{ + cachePath: cachePath, + lastSeenVersion: 11, + loadVersion: func() (int64, error) { return 12, nil }, + loadSnapshot: func() (*SharedAccountsSnapshot, error) { return wantSnapshot, nil }, + applySnapshot: func(snapshot *SharedAccountsSnapshot) error { applyCalls++; return nil }, + } + + didSync, err := syncSvc.SyncOnce() + if err != nil { + t.Fatalf("SyncOnce error: %v", err) + } + if !didSync { + t.Fatal("expected sync to run when version changes") + } + if applyCalls != 1 { + t.Fatalf("expected applySnapshot to be called once, got %d", applyCalls) + } + if syncSvc.lastSeenVersion != 12 { + t.Fatalf("expected lastSeenVersion to become 12, got %d", syncSvc.lastSeenVersion) + } + + cached, err := LoadSharedAccountsSnapshot(cachePath) + if err != nil { + t.Fatalf("LoadSharedAccountsSnapshot error: %v", err) + } + if cached.Version != 12 { + t.Fatalf("expected cached version 12, got %d", cached.Version) + } + + state := loadNodeState(t, "worker-refresh") + if state.LastSeenVersion != 12 { + t.Fatalf("expected last seen version 12, got %d", state.LastSeenVersion) + } + if state.LastSyncAt == 0 { + t.Fatal("expected LastSyncAt to be recorded after successful sync") + } + if state.LastError != "" { + t.Fatalf("expected empty LastError, got %q", state.LastError) + } +} + +func TestSyncOncePreservesLastSyncAtWhenVersionUnchanged(t *testing.T) { + setupTestDB(t) + writeNodeSyncSettings(t, "worker-preserve") + + if err := database.UpsertNodeState(database.GetDB(), &model.NodeState{ + NodeID: "worker-preserve", + NodeRole: "worker", + LastSyncAt: 12345, + LastHeartbeatAt: 12345, + LastSeenVersion: 8, + }); err != nil { + t.Fatalf("UpsertNodeState error: %v", err) + } + + syncSvc := &NodeSyncService{ + cachePath: filepath.Join(t.TempDir(), "shared-cache.json"), + lastSeenVersion: 8, + loadVersion: func() (int64, error) { return 8, nil }, + loadSnapshot: func() (*SharedAccountsSnapshot, error) { return nil, nil }, + applySnapshot: func(*SharedAccountsSnapshot) error { return nil }, + } + + didSync, err := syncSvc.SyncOnce() + if err != nil { + t.Fatalf("SyncOnce error: %v", err) + } + if didSync { + t.Fatal("expected unchanged version to skip sync") + } + + state := loadNodeState(t, "worker-preserve") + if state.LastSyncAt != 12345 { + t.Fatalf("expected LastSyncAt to remain 12345, got %d", state.LastSyncAt) + } +} + +func TestBootstrapFromCacheAppliesCachedSnapshot(t *testing.T) { + setupTestDB(t) + writeNodeSyncSettings(t, "worker-bootstrap") + + cachePath := filepath.Join(t.TempDir(), "shared-cache.json") + cached := &SharedAccountsSnapshot{ + Version: 77, + Inbounds: []*model.Inbound{ + { + Id: 77, + Enable: true, + Port: 10077, + Protocol: model.VLESS, + Settings: `{"clients":[]}`, + Tag: "cache-77", + }, + }, + } + if err := SaveSharedAccountsSnapshot(cachePath, cached); err != nil { + t.Fatalf("SaveSharedAccountsSnapshot error: %v", err) + } + + appliedVersion := int64(0) + syncSvc := &NodeSyncService{ + cachePath: cachePath, + applySnapshot: func(snapshot *SharedAccountsSnapshot) error { + appliedVersion = snapshot.Version + return nil + }, + } + + if err := syncSvc.BootstrapFromCache(); err != nil { + t.Fatalf("BootstrapFromCache error: %v", err) + } + if appliedVersion != 77 { + t.Fatalf("expected cached version 77 to be applied, got %d", appliedVersion) + } +} diff --git a/web/service/xray.go b/web/service/xray.go index b9d4a74a..68d12862 100644 --- a/web/service/xray.go +++ b/web/service/xray.go @@ -6,6 +6,7 @@ import ( "runtime" "sync" + "github.com/mhsanaei/3x-ui/v2/database/model" "github.com/mhsanaei/3x-ui/v2/logger" "github.com/mhsanaei/3x-ui/v2/xray" @@ -98,8 +99,87 @@ func RemoveIndex(s []any, index int) []any { return append(s[:index], s[index+1:]...) } -// GetXrayConfig retrieves and builds the Xray configuration from settings and inbounds. -func (s *XrayService) GetXrayConfig() (*xray.Config, error) { +func normalizeInboundForXray(inbound *model.Inbound) error { + // get settings clients + settings := map[string]any{} + json.Unmarshal([]byte(inbound.Settings), &settings) + clients, ok := settings["clients"].([]any) + if ok { + // check users active or not + clientStats := inbound.ClientStats + for _, clientTraffic := range clientStats { + indexDecrease := 0 + for index, client := range clients { + c := client.(map[string]any) + if c["email"] == clientTraffic.Email { + if !clientTraffic.Enable { + clients = RemoveIndex(clients, index-indexDecrease) + indexDecrease++ + logger.Infof("Remove Inbound User %s due to expiration or traffic limit", c["email"]) + } + } + } + } + + // clear client config for additional parameters + var finalClients []any + for _, client := range clients { + c := client.(map[string]any) + if c["enable"] != nil { + if enable, ok := c["enable"].(bool); ok && !enable { + continue + } + } + for key := range c { + if key != "email" && key != "id" && key != "password" && key != "flow" && key != "method" { + delete(c, key) + } + if c["flow"] == "xtls-rprx-vision-udp443" { + c["flow"] = "xtls-rprx-vision" + } + } + finalClients = append(finalClients, any(c)) + } + + settings["clients"] = finalClients + modifiedSettings, err := json.MarshalIndent(settings, "", " ") + if err != nil { + return err + } + + inbound.Settings = string(modifiedSettings) + } + + if len(inbound.StreamSettings) > 0 { + // Unmarshal stream JSON + var stream map[string]any + json.Unmarshal([]byte(inbound.StreamSettings), &stream) + + // Remove the "settings" field under "tlsSettings" and "realitySettings" + tlsSettings, ok1 := stream["tlsSettings"].(map[string]any) + realitySettings, ok2 := stream["realitySettings"].(map[string]any) + if ok1 || ok2 { + if ok1 { + delete(tlsSettings, "settings") + } else if ok2 { + delete(realitySettings, "settings") + } + } + + delete(stream, "externalProxy") + + newStream, err := json.MarshalIndent(stream, "", " ") + if err != nil { + return err + } + inbound.StreamSettings = string(newStream) + } + + return nil +} + +// BuildConfigFromInbounds builds Xray config from an explicit inbound list. +func (s *XrayService) BuildConfigFromInbounds(inbounds []*model.Inbound) (*xray.Config, error) { templateConfig, err := s.settingService.GetXrayConfigTemplate() if err != nil { return nil, err @@ -111,95 +191,33 @@ func (s *XrayService) GetXrayConfig() (*xray.Config, error) { return nil, err } + for _, inbound := range inbounds { + if inbound == nil { + continue + } + if !inbound.Enable { + continue + } + inboundCopy := *inbound + if err := normalizeInboundForXray(&inboundCopy); err != nil { + return nil, err + } + + inboundConfig := inboundCopy.GenXrayInboundConfig() + xrayConfig.InboundConfigs = append(xrayConfig.InboundConfigs, *inboundConfig) + } + return xrayConfig, nil +} + +// GetXrayConfig retrieves and builds the Xray configuration from settings and inbounds. +func (s *XrayService) GetXrayConfig() (*xray.Config, error) { s.inboundService.AddTraffic(nil, nil) inbounds, err := s.inboundService.GetAllInbounds() if err != nil { return nil, err } - for _, inbound := range inbounds { - if !inbound.Enable { - continue - } - // get settings clients - settings := map[string]any{} - json.Unmarshal([]byte(inbound.Settings), &settings) - clients, ok := settings["clients"].([]any) - if ok { - // check users active or not - clientStats := inbound.ClientStats - for _, clientTraffic := range clientStats { - indexDecrease := 0 - for index, client := range clients { - c := client.(map[string]any) - if c["email"] == clientTraffic.Email { - if !clientTraffic.Enable { - clients = RemoveIndex(clients, index-indexDecrease) - indexDecrease++ - logger.Infof("Remove Inbound User %s due to expiration or traffic limit", c["email"]) - } - } - } - } - - // clear client config for additional parameters - var final_clients []any - for _, client := range clients { - c := client.(map[string]any) - if c["enable"] != nil { - if enable, ok := c["enable"].(bool); ok && !enable { - continue - } - } - for key := range c { - if key != "email" && key != "id" && key != "password" && key != "flow" && key != "method" { - delete(c, key) - } - if c["flow"] == "xtls-rprx-vision-udp443" { - c["flow"] = "xtls-rprx-vision" - } - } - final_clients = append(final_clients, any(c)) - } - - settings["clients"] = final_clients - modifiedSettings, err := json.MarshalIndent(settings, "", " ") - if err != nil { - return nil, err - } - - inbound.Settings = string(modifiedSettings) - } - - if len(inbound.StreamSettings) > 0 { - // Unmarshal stream JSON - var stream map[string]any - json.Unmarshal([]byte(inbound.StreamSettings), &stream) - - // Remove the "settings" field under "tlsSettings" and "realitySettings" - tlsSettings, ok1 := stream["tlsSettings"].(map[string]any) - realitySettings, ok2 := stream["realitySettings"].(map[string]any) - if ok1 || ok2 { - if ok1 { - delete(tlsSettings, "settings") - } else if ok2 { - delete(realitySettings, "settings") - } - } - - delete(stream, "externalProxy") - - newStream, err := json.MarshalIndent(stream, "", " ") - if err != nil { - return nil, err - } - inbound.StreamSettings = string(newStream) - } - - inboundConfig := inbound.GenXrayInboundConfig() - xrayConfig.InboundConfigs = append(xrayConfig.InboundConfigs, *inboundConfig) - } - return xrayConfig, nil + return s.BuildConfigFromInbounds(inbounds) } // GetXrayTraffic fetches the current traffic statistics from the running Xray process. @@ -223,14 +241,22 @@ func (s *XrayService) GetXrayTraffic() ([]*xray.Traffic, []*xray.ClientTraffic, // RestartXray restarts the Xray process, optionally forcing a restart even if config unchanged. func (s *XrayService) RestartXray(isForce bool) error { + xrayConfig, err := s.GetXrayConfig() + if err != nil { + return err + } + return s.RestartXrayWithConfig(xrayConfig, isForce) +} + +// RestartXrayWithConfig restarts Xray with a pre-built config. +func (s *XrayService) RestartXrayWithConfig(xrayConfig *xray.Config, isForce bool) error { lock.Lock() defer lock.Unlock() logger.Debug("restart Xray, force:", isForce) isManuallyStopped.Store(false) - xrayConfig, err := s.GetXrayConfig() - if err != nil { - return err + if xrayConfig == nil { + return errors.New("xray config is nil") } if s.IsXrayRunning() { @@ -243,14 +269,25 @@ func (s *XrayService) RestartXray(isForce bool) error { p = xray.NewProcess(xrayConfig) result = "" - err = p.Start() - if err != nil { + if err := p.Start(); err != nil { return err } return nil } +// ApplySharedSnapshot builds and applies config from a shared snapshot. +func (s *XrayService) ApplySharedSnapshot(snapshot *SharedAccountsSnapshot) error { + if snapshot == nil { + return errors.New("shared snapshot is nil") + } + xrayConfig, err := s.BuildConfigFromInbounds(snapshot.Inbounds) + if err != nil { + return err + } + return s.RestartXrayWithConfig(xrayConfig, false) +} + // StopXray stops the running Xray process. func (s *XrayService) StopXray() error { lock.Lock() diff --git a/web/web.go b/web/web.go index 968841f3..1facb260 100644 --- a/web/web.go +++ b/web/web.go @@ -410,6 +410,20 @@ func (s *Server) startTask() { } } +func (s *Server) startNodeLoops() { + nodeCfg := config.GetNodeConfigFromJSON() + nodeSyncService := service.NewNodeSyncService() + interval := time.Duration(nodeCfg.SyncIntervalSeconds) * time.Second + + if nodeCfg.Role == config.NodeRoleWorker { + go nodeSyncService.Run(s.ctx, interval) + return + } + if nodeCfg.NodeID != "" { + go nodeSyncService.RunHeartbeatLoop(s.ctx, interval) + } +} + // Start initializes and starts the web server with configured settings, routes, and background jobs. func (s *Server) Start() (err error) { // This is an anonymous function, no function name @@ -479,6 +493,7 @@ func (s *Server) Start() (err error) { }() s.startTask() + s.startNodeLoops() isTgbotenabled, err := s.settingService.GetTgbotEnabled() if (err == nil) && (isTgbotenabled) {