mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-01-13 17:24:05 +00:00
624 lines
17 KiB
Go
624 lines
17 KiB
Go
// Package service provides Node management service for multi-node architecture.
|
|
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/mhsanaei/3x-ui/v2/database"
|
|
"github.com/mhsanaei/3x-ui/v2/database/model"
|
|
"github.com/mhsanaei/3x-ui/v2/logger"
|
|
"github.com/mhsanaei/3x-ui/v2/xray"
|
|
)
|
|
|
|
// NodeService groups the node-management operations used in multi-node mode.
// The type has no fields, so a zero value is ready to use.
type NodeService struct{}
|
|
|
|
// GetAllNodes retrieves all nodes from the database.
|
|
func (s *NodeService) GetAllNodes() ([]*model.Node, error) {
|
|
db := database.GetDB()
|
|
var nodes []*model.Node
|
|
err := db.Find(&nodes).Error
|
|
return nodes, err
|
|
}
|
|
|
|
// GetNode retrieves a node by ID.
|
|
func (s *NodeService) GetNode(id int) (*model.Node, error) {
|
|
db := database.GetDB()
|
|
var node model.Node
|
|
err := db.First(&node, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
// AddNode creates a new node.
|
|
func (s *NodeService) AddNode(node *model.Node) error {
|
|
db := database.GetDB()
|
|
return db.Create(node).Error
|
|
}
|
|
|
|
// UpdateNode updates an existing node.
|
|
// Only updates fields that are provided (non-empty for strings, non-zero for integers).
|
|
func (s *NodeService) UpdateNode(node *model.Node) error {
|
|
db := database.GetDB()
|
|
|
|
// Get existing node to preserve fields that are not being updated
|
|
existingNode, err := s.GetNode(node.Id)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get existing node: %w", err)
|
|
}
|
|
|
|
// Update only provided fields
|
|
updates := make(map[string]interface{})
|
|
|
|
if node.Name != "" {
|
|
updates["name"] = node.Name
|
|
}
|
|
|
|
if node.Address != "" {
|
|
updates["address"] = node.Address
|
|
}
|
|
|
|
if node.ApiKey != "" {
|
|
updates["api_key"] = node.ApiKey
|
|
}
|
|
|
|
// Update status and last_check if provided (these are usually set by health checks, not user edits)
|
|
if node.Status != "" && node.Status != existingNode.Status {
|
|
updates["status"] = node.Status
|
|
}
|
|
|
|
if node.LastCheck > 0 && node.LastCheck != existingNode.LastCheck {
|
|
updates["last_check"] = node.LastCheck
|
|
}
|
|
|
|
// If no fields to update, return early
|
|
if len(updates) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Update only the specified fields
|
|
return db.Model(existingNode).Updates(updates).Error
|
|
}
|
|
|
|
// DeleteNode deletes a node by ID.
|
|
// This will cascade delete all InboundNodeMapping entries for this node.
|
|
func (s *NodeService) DeleteNode(id int) error {
|
|
db := database.GetDB()
|
|
|
|
// Delete all node mappings for this node (cascade delete)
|
|
err := db.Where("node_id = ?", id).Delete(&model.InboundNodeMapping{}).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Delete the node itself
|
|
return db.Delete(&model.Node{}, id).Error
|
|
}
|
|
|
|
// CheckNodeHealth checks if a node is online and updates its status.
|
|
func (s *NodeService) CheckNodeHealth(node *model.Node) error {
|
|
status, err := s.CheckNodeStatus(node)
|
|
if err != nil {
|
|
node.Status = "error"
|
|
node.LastCheck = time.Now().Unix()
|
|
s.UpdateNode(node)
|
|
return err
|
|
}
|
|
node.Status = status
|
|
node.LastCheck = time.Now().Unix()
|
|
return s.UpdateNode(node)
|
|
}
|
|
|
|
// CheckNodeStatus performs a health check on a given node.
|
|
func (s *NodeService) CheckNodeStatus(node *model.Node) (string, error) {
|
|
client := &http.Client{
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/health", node.Address)
|
|
resp, err := client.Get(url)
|
|
if err != nil {
|
|
return "offline", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode == http.StatusOK {
|
|
return "online", nil
|
|
}
|
|
return "error", fmt.Errorf("node returned status code %d", resp.StatusCode)
|
|
}
|
|
|
|
// CheckAllNodesHealth checks health of all nodes.
|
|
func (s *NodeService) CheckAllNodesHealth() {
|
|
nodes, err := s.GetAllNodes()
|
|
if err != nil {
|
|
logger.Errorf("Failed to get nodes for health check: %v", err)
|
|
return
|
|
}
|
|
|
|
for _, node := range nodes {
|
|
go s.CheckNodeHealth(node)
|
|
}
|
|
}
|
|
|
|
// GetNodeForInbound returns the node assigned to an inbound, or nil if not assigned.
|
|
// Deprecated: Use GetNodesForInbound for multi-node support.
|
|
func (s *NodeService) GetNodeForInbound(inboundId int) (*model.Node, error) {
|
|
db := database.GetDB()
|
|
var mapping model.InboundNodeMapping
|
|
err := db.Where("inbound_id = ?", inboundId).First(&mapping).Error
|
|
if err != nil {
|
|
return nil, err // Not found is OK, means inbound is not assigned to any node
|
|
}
|
|
|
|
return s.GetNode(mapping.NodeId)
|
|
}
|
|
|
|
// GetNodesForInbound returns all nodes assigned to an inbound.
|
|
func (s *NodeService) GetNodesForInbound(inboundId int) ([]*model.Node, error) {
|
|
db := database.GetDB()
|
|
var mappings []model.InboundNodeMapping
|
|
err := db.Where("inbound_id = ?", inboundId).Find(&mappings).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
nodes := make([]*model.Node, 0, len(mappings))
|
|
for _, mapping := range mappings {
|
|
node, err := s.GetNode(mapping.NodeId)
|
|
if err == nil && node != nil {
|
|
nodes = append(nodes, node)
|
|
}
|
|
}
|
|
return nodes, nil
|
|
}
|
|
|
|
// GetInboundsForNode returns all inbounds assigned to a node.
|
|
func (s *NodeService) GetInboundsForNode(nodeId int) ([]*model.Inbound, error) {
|
|
db := database.GetDB()
|
|
var mappings []model.InboundNodeMapping
|
|
err := db.Where("node_id = ?", nodeId).Find(&mappings).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
inbounds := make([]*model.Inbound, 0, len(mappings))
|
|
for _, mapping := range mappings {
|
|
var inbound model.Inbound
|
|
err := db.First(&inbound, mapping.InboundId).Error
|
|
if err == nil {
|
|
inbounds = append(inbounds, &inbound)
|
|
}
|
|
}
|
|
return inbounds, nil
|
|
}
|
|
|
|
// NodeStatsResponse represents the response from node stats API.
|
|
type NodeStatsResponse struct {
|
|
Traffic []*NodeTraffic `json:"traffic"`
|
|
ClientTraffic []*NodeClientTraffic `json:"clientTraffic"`
|
|
OnlineClients []string `json:"onlineClients"`
|
|
}
|
|
|
|
// NodeTraffic is one traffic entry reported by a node, identified by Tag.
type NodeTraffic struct {
	// IsInbound and IsOutbound are direction flags copied verbatim from the node's report.
	IsInbound  bool `json:"isInbound"`
	IsOutbound bool `json:"isOutbound"`
	// Tag is the identifier the counters belong to.
	Tag string `json:"tag"`
	// Up and Down are the traffic counters reported by the node.
	Up   int64 `json:"up"`
	Down int64 `json:"down"`
}
|
|
|
|
// NodeClientTraffic is one per-client traffic entry reported by a node.
type NodeClientTraffic struct {
	// Email identifies the client the counters belong to.
	Email string `json:"email"`
	// Up and Down are the traffic counters reported by the node.
	Up   int64 `json:"up"`
	Down int64 `json:"down"`
}
|
|
|
|
// GetNodeStats retrieves traffic and online clients statistics from a node.
|
|
func (s *NodeService) GetNodeStats(node *model.Node, reset bool) (*NodeStatsResponse, error) {
|
|
client := &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/api/v1/stats", node.Address)
|
|
if reset {
|
|
url += "?reset=true"
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Authorization", "Bearer "+node.ApiKey)
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to request node stats: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("node returned status code %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var stats NodeStatsResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&stats); err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
return &stats, nil
|
|
}
|
|
|
|
// CollectNodeStats collects statistics from all nodes and aggregates them into the database.
|
|
// This should be called periodically (e.g., via cron job).
|
|
func (s *NodeService) CollectNodeStats() error {
|
|
// Check if multi-node mode is enabled
|
|
settingService := SettingService{}
|
|
multiMode, err := settingService.GetMultiNodeMode()
|
|
if err != nil || !multiMode {
|
|
return nil // Skip if multi-node mode is not enabled
|
|
}
|
|
|
|
nodes, err := s.GetAllNodes()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get nodes: %w", err)
|
|
}
|
|
|
|
if len(nodes) == 0 {
|
|
return nil // No nodes to collect stats from
|
|
}
|
|
|
|
// Filter nodes: only collect stats from nodes that have assigned inbounds
|
|
nodesWithInbounds := make([]*model.Node, 0)
|
|
for _, node := range nodes {
|
|
inbounds, err := s.GetInboundsForNode(node.Id)
|
|
if err == nil && len(inbounds) > 0 {
|
|
// Only include nodes that have at least one assigned inbound
|
|
nodesWithInbounds = append(nodesWithInbounds, node)
|
|
}
|
|
}
|
|
|
|
if len(nodesWithInbounds) == 0 {
|
|
return nil // No nodes with assigned inbounds
|
|
}
|
|
|
|
// Import inbound service to aggregate traffic
|
|
inboundService := &InboundService{}
|
|
|
|
// Collect stats from nodes with assigned inbounds concurrently
|
|
type nodeStatsResult struct {
|
|
node *model.Node
|
|
stats *NodeStatsResponse
|
|
err error
|
|
}
|
|
|
|
results := make(chan nodeStatsResult, len(nodesWithInbounds))
|
|
for _, node := range nodesWithInbounds {
|
|
go func(n *model.Node) {
|
|
stats, err := s.GetNodeStats(n, false) // Don't reset counters on collection
|
|
results <- nodeStatsResult{node: n, stats: stats, err: err}
|
|
}(node)
|
|
}
|
|
|
|
// Aggregate all traffic
|
|
allTraffics := make([]*xray.Traffic, 0)
|
|
allClientTraffics := make([]*xray.ClientTraffic, 0)
|
|
onlineClientsMap := make(map[string]bool)
|
|
|
|
for i := 0; i < len(nodesWithInbounds); i++ {
|
|
result := <-results
|
|
if result.err != nil {
|
|
// Check if error is expected (XRAY not running, 404 for old nodes, etc.)
|
|
errMsg := result.err.Error()
|
|
if strings.Contains(errMsg, "XRAY is not running") ||
|
|
strings.Contains(errMsg, "status code 404") ||
|
|
strings.Contains(errMsg, "status code 500") {
|
|
// These are expected errors, log as debug only
|
|
logger.Debugf("Skipping stats collection from node %s (ID: %d): %v", result.node.Name, result.node.Id, result.err)
|
|
} else {
|
|
// Unexpected errors should be logged as warning
|
|
logger.Warningf("Failed to get stats from node %s (ID: %d): %v", result.node.Name, result.node.Id, result.err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
if result.stats == nil {
|
|
continue
|
|
}
|
|
|
|
// Convert node traffic to xray.Traffic
|
|
for _, nt := range result.stats.Traffic {
|
|
allTraffics = append(allTraffics, &xray.Traffic{
|
|
IsInbound: nt.IsInbound,
|
|
IsOutbound: nt.IsOutbound,
|
|
Tag: nt.Tag,
|
|
Up: nt.Up,
|
|
Down: nt.Down,
|
|
})
|
|
}
|
|
|
|
// Convert node client traffic to xray.ClientTraffic
|
|
for _, nct := range result.stats.ClientTraffic {
|
|
allClientTraffics = append(allClientTraffics, &xray.ClientTraffic{
|
|
Email: nct.Email,
|
|
Up: nct.Up,
|
|
Down: nct.Down,
|
|
})
|
|
}
|
|
|
|
// Collect online clients
|
|
for _, email := range result.stats.OnlineClients {
|
|
onlineClientsMap[email] = true
|
|
}
|
|
}
|
|
|
|
// Aggregate traffic into database
|
|
if len(allTraffics) > 0 || len(allClientTraffics) > 0 {
|
|
_, needRestart := inboundService.AddTraffic(allTraffics, allClientTraffics)
|
|
if needRestart {
|
|
logger.Info("Traffic aggregation triggered client renewal/disabling, restart may be needed")
|
|
}
|
|
}
|
|
|
|
logger.Debugf("Collected stats from nodes: %d traffics, %d client traffics, %d online clients",
|
|
len(allTraffics), len(allClientTraffics), len(onlineClientsMap))
|
|
|
|
return nil
|
|
}
|
|
|
|
// AssignInboundToNode assigns an inbound to a node.
|
|
func (s *NodeService) AssignInboundToNode(inboundId, nodeId int) error {
|
|
db := database.GetDB()
|
|
mapping := &model.InboundNodeMapping{
|
|
InboundId: inboundId,
|
|
NodeId: nodeId,
|
|
}
|
|
return db.Save(mapping).Error
|
|
}
|
|
|
|
// AssignInboundToNodes assigns an inbound to multiple nodes.
|
|
func (s *NodeService) AssignInboundToNodes(inboundId int, nodeIds []int) error {
|
|
db := database.GetDB()
|
|
// First, remove all existing assignments
|
|
if err := db.Where("inbound_id = ?", inboundId).Delete(&model.InboundNodeMapping{}).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// Then, create new assignments
|
|
for _, nodeId := range nodeIds {
|
|
if nodeId > 0 {
|
|
mapping := &model.InboundNodeMapping{
|
|
InboundId: inboundId,
|
|
NodeId: nodeId,
|
|
}
|
|
if err := db.Create(mapping).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UnassignInboundFromNode removes the assignment of an inbound from its node.
|
|
func (s *NodeService) UnassignInboundFromNode(inboundId int) error {
|
|
db := database.GetDB()
|
|
return db.Where("inbound_id = ?", inboundId).Delete(&model.InboundNodeMapping{}).Error
|
|
}
|
|
|
|
// ApplyConfigToNode sends XRAY configuration to a node.
|
|
func (s *NodeService) ApplyConfigToNode(node *model.Node, xrayConfig []byte) error {
|
|
client := &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/api/v1/apply-config", node.Address)
|
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(xrayConfig))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", node.ApiKey))
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("node returned status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ReloadNode reloads XRAY on a specific node.
|
|
func (s *NodeService) ReloadNode(node *model.Node) error {
|
|
client := &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/api/v1/reload", node.Address)
|
|
req, err := http.NewRequest("POST", url, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", node.ApiKey))
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("node returned status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ForceReloadNode forcefully reloads XRAY on a specific node (even if hung).
|
|
func (s *NodeService) ForceReloadNode(node *model.Node) error {
|
|
client := &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/api/v1/force-reload", node.Address)
|
|
req, err := http.NewRequest("POST", url, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", node.ApiKey))
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("node returned status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ReloadAllNodes reloads XRAY on all nodes.
|
|
func (s *NodeService) ReloadAllNodes() error {
|
|
nodes, err := s.GetAllNodes()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get nodes: %w", err)
|
|
}
|
|
|
|
type reloadResult struct {
|
|
node *model.Node
|
|
err error
|
|
}
|
|
|
|
results := make(chan reloadResult, len(nodes))
|
|
for _, node := range nodes {
|
|
go func(n *model.Node) {
|
|
err := s.ForceReloadNode(n) // Use force reload to handle hung nodes
|
|
results <- reloadResult{node: n, err: err}
|
|
}(node)
|
|
}
|
|
|
|
var errors []string
|
|
for i := 0; i < len(nodes); i++ {
|
|
result := <-results
|
|
if result.err != nil {
|
|
errors = append(errors, fmt.Sprintf("node %d (%s): %v", result.node.Id, result.node.Name, result.err))
|
|
}
|
|
}
|
|
|
|
if len(errors) > 0 {
|
|
return fmt.Errorf("failed to reload some nodes: %s", strings.Join(errors, "; "))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidateApiKey validates the API key by making a test request to the node.
|
|
func (s *NodeService) ValidateApiKey(node *model.Node) error {
|
|
client := &http.Client{
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
// First, check if node is reachable via health endpoint
|
|
healthURL := fmt.Sprintf("%s/health", node.Address)
|
|
healthResp, err := client.Get(healthURL)
|
|
if err != nil {
|
|
logger.Errorf("Failed to connect to node %s at %s: %v", node.Address, healthURL, err)
|
|
return fmt.Errorf("failed to connect to node: %v", err)
|
|
}
|
|
healthResp.Body.Close()
|
|
|
|
if healthResp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("node health check failed with status %d", healthResp.StatusCode)
|
|
}
|
|
|
|
// Try to get node status - this will validate the API key
|
|
url := fmt.Sprintf("%s/api/v1/status", node.Address)
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create request: %v", err)
|
|
}
|
|
|
|
authHeader := fmt.Sprintf("Bearer %s", node.ApiKey)
|
|
req.Header.Set("Authorization", authHeader)
|
|
|
|
logger.Debugf("Validating API key for node %s at %s (key: %s)", node.Name, url, node.ApiKey)
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
logger.Errorf("Failed to connect to node %s: %v", node.Address, err)
|
|
return fmt.Errorf("failed to connect to node: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
|
|
if resp.StatusCode == http.StatusUnauthorized {
|
|
logger.Warningf("Invalid API key for node %s (sent: %s): %s", node.Address, authHeader, string(body))
|
|
return fmt.Errorf("invalid API key")
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
logger.Errorf("Node %s returned status %d: %s", node.Address, resp.StatusCode, string(body))
|
|
return fmt.Errorf("node returned status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
logger.Debugf("API key validated successfully for node %s", node.Name)
|
|
return nil
|
|
}
|
|
|
|
// GetNodeStatus retrieves the status of a node.
|
|
func (s *NodeService) GetNodeStatus(node *model.Node) (map[string]interface{}, error) {
|
|
client := &http.Client{
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/api/v1/status", node.Address)
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", node.ApiKey))
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("node returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
var status map[string]interface{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return status, nil
|
|
}
|