From a196dcddb05ca919facf938cf308cd42aa011f1a Mon Sep 17 00:00:00 2001 From: Konstantin Pichugin Date: Mon, 12 Jan 2026 00:57:04 +0300 Subject: [PATCH] add redis for cache(speed up panel) --- go.mod | 5 + go.sum | 10 ++ main.go | 7 ++ web/cache/cache.go | 123 ++++++++++++++++++ web/cache/redis.go | 137 ++++++++++++++++++++ web/cache/redisstore.go | 176 ++++++++++++++++++++++++++ web/middleware/cache.go | 132 ++++++++++++++++++++ web/service/client.go | 104 +++++++++------- web/service/inbound.go | 108 ++++++++++------ web/service/setting.go | 269 ++++++++++++++++++++++++---------------- web/web.go | 32 ++++- web/websocket/hub.go | 63 +++++++++- 12 files changed, 968 insertions(+), 198 deletions(-) create mode 100644 web/cache/cache.go create mode 100644 web/cache/redis.go create mode 100644 web/cache/redisstore.go create mode 100644 web/middleware/cache.go diff --git a/go.mod b/go.mod index 475727d9..73e56710 100644 --- a/go.mod +++ b/go.mod @@ -36,13 +36,16 @@ require ( require ( github.com/Azure/go-ntlmssp v0.1.0 // indirect + github.com/alicebob/miniredis/v2 v2.35.0 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudflare/circl v1.6.2 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/ebitengine/purego v0.9.1 // indirect github.com/gabriel-vasile/mimetype v1.4.12 // indirect github.com/gin-contrib/sse v1.1.0 // indirect @@ -74,6 +77,7 @@ require ( github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.58.0 // indirect + github.com/redis/go-redis/v9 v9.17.2 // indirect github.com/refraction-networking/utls v1.8.1 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect @@ -90,6 +94,7 @@ require ( github.com/vishvananda/netlink v1.3.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect github.com/xtls/reality v0.0.0-20251116175510-cd53f7d50237 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect golang.org/x/arch v0.23.0 // indirect diff --git a/go.sum b/go.sum index 70c47d08..ba917e82 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= +github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= +github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= @@ -12,6 +14,8 @@ github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPII github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudflare/circl v1.6.2 h1:hL7VBpHHKzrV5WTfHCaBsgx/HGbBYlgrwvNXEVDYYsQ= github.com/cloudflare/circl v1.6.2/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= @@ -22,6 +26,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33 h1:ucRHb6/lvW/+mTEIGbvhcYU3S8+uSNkuMjx/qZFfhtM= github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A= github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw= @@ -149,6 +155,8 @@ github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= github.com/quic-go/quic-go v0.58.0 h1:ggY2pvZaVdB9EyojxL1p+5mptkuHyX5MOSv4dgWF4Ug= github.com/quic-go/quic-go v0.58.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= +github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI= +github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo= github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg= @@ -207,6 +215,8 @@ github.com/xtls/xray-core v1.251208.0 h1:9jIXi+9KXnfmT5esSYNf9VAQlQkaAP8bG413B0e github.com/xtls/xray-core v1.251208.0/go.mod h1:kclzboEF0g6VBrp9/NXm8C0Aj64SDBt52OfthH1LSr4= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= diff --git a/main.go b/main.go index 8096616c..5096daa6 100644 --- a/main.go +++ b/main.go @@ -50,6 +50,13 @@ func runWebServer() { log.Fatalf("Error initializing database: %v", err) } + // Initialize Redis cache (embedded mode by default) + err = web.InitRedisCache("") + if err != nil { + log.Fatalf("Error initializing Redis cache: %v", err) + } + defer web.CloseRedisCache() + var server *web.Server server = web.NewServer() global.SetWebServer(server) diff --git a/web/cache/cache.go b/web/cache/cache.go new file mode 100644 index 00000000..a6b233d0 --- /dev/null +++ b/web/cache/cache.go @@ -0,0 +1,123 @@ +// Package cache provides caching utilities with JSON serialization support. +package cache + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/mhsanaei/3x-ui/v2/logger" +) + +const ( + // Default TTL values + TTLInbounds = 30 * time.Second + TTLClients = 30 * time.Second + TTLSettings = 5 * time.Minute + TTLSetting = 10 * time.Minute // Increased from 5 to 10 minutes for better cache hit rate +) + +// Cache keys +const ( + KeyInboundsPrefix = "inbounds:user:" + KeyClientsPrefix = "clients:user:" + KeySettingsAll = "settings:all" + KeySettingPrefix = "setting:" +) + +// GetJSON retrieves a value from cache and unmarshals it as JSON. +func GetJSON(key string, dest interface{}) error { + val, err := Get(key) + if err != nil { + // Check if it's a "key not found" error (redis.Nil) + // This is expected and not a real error + if err.Error() == "redis: nil" { + return fmt.Errorf("key not found: %s", key) + } + return err + } + if val == "" { + return fmt.Errorf("empty value for key: %s", key) + } + return json.Unmarshal([]byte(val), dest) +} + +// SetJSON marshals a value as JSON and stores it in cache. +func SetJSON(key string, value interface{}, expiration time.Duration) error { + data, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("failed to marshal value: %w", err) + } + return Set(key, string(data), expiration) +} + +// GetOrSet retrieves a value from cache, or computes it using fn if not found. +func GetOrSet(key string, dest interface{}, expiration time.Duration, fn func() (interface{}, error)) error { + // Try to get from cache + err := GetJSON(key, dest) + if err == nil { + logger.Debugf("Cache hit for key: %s", key) + return nil + } + + // Cache miss, compute value + logger.Debugf("Cache miss for key: %s", key) + value, err := fn() + if err != nil { + return err + } + + // Store in cache + if err := SetJSON(key, value, expiration); err != nil { + logger.Warningf("Failed to set cache for key %s: %v", key, err) + } + + // Copy value to dest + data, err := json.Marshal(value) + if err != nil { + return err + } + return json.Unmarshal(data, dest) +} + +// InvalidateInbounds invalidates all inbounds cache for a user. +func InvalidateInbounds(userId int) error { + pattern := fmt.Sprintf("%s%d", KeyInboundsPrefix, userId) + return DeletePattern(pattern) +} + +// InvalidateAllInbounds invalidates all inbounds cache. +func InvalidateAllInbounds() error { + pattern := KeyInboundsPrefix + "*" + return DeletePattern(pattern) +} + +// InvalidateClients invalidates all clients cache for a user. +func InvalidateClients(userId int) error { + pattern := fmt.Sprintf("%s%d", KeyClientsPrefix, userId) + return DeletePattern(pattern) +} + +// InvalidateAllClients invalidates all clients cache. +func InvalidateAllClients() error { + pattern := KeyClientsPrefix + "*" + return DeletePattern(pattern) +} + +// InvalidateSetting invalidates a specific setting cache. +// Note: We don't invalidate KeySettingsAll here to avoid unnecessary cache misses. +// KeySettingsAll will be invalidated only when settings are actually changed. +func InvalidateSetting(key string) error { + settingKey := KeySettingPrefix + key + return Delete(settingKey) +} + +// InvalidateAllSettings invalidates all settings cache. +func InvalidateAllSettings() error { + if err := Delete(KeySettingsAll); err != nil { + return err + } + // Also invalidate all individual settings + pattern := KeySettingPrefix + "*" + return DeletePattern(pattern) +} diff --git a/web/cache/redis.go b/web/cache/redis.go new file mode 100644 index 00000000..69ebe914 --- /dev/null +++ b/web/cache/redis.go @@ -0,0 +1,137 @@ +// Package cache provides Redis caching functionality for the 3x-ui web panel. +// It supports both embedded Redis (miniredis) and external Redis server. +package cache + +import ( + "context" + "fmt" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/mhsanaei/3x-ui/v2/logger" + "github.com/redis/go-redis/v9" +) + +var ( + client *redis.Client + miniRedis *miniredis.Miniredis + ctx = context.Background() + isEmbedded = true +) + +// InitRedis initializes Redis client. If redisAddr is empty, starts embedded Redis. +// If redisAddr is provided, connects to external Redis server. +func InitRedis(redisAddr string) error { + if redisAddr == "" { + // Use embedded Redis + mr, err := miniredis.Run() + if err != nil { + return fmt.Errorf("failed to start embedded Redis: %w", err) + } + miniRedis = mr + client = redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + isEmbedded = true + logger.Info("Embedded Redis started on", mr.Addr()) + } else { + // Use external Redis + client = redis.NewClient(&redis.Options{ + Addr: redisAddr, + Password: "", // Can be extended to support password + DB: 0, + }) + isEmbedded = false + + // Test connection + _, err := client.Ping(ctx).Result() + if err != nil { + return fmt.Errorf("failed to connect to Redis at %s: %w", redisAddr, err) + } + logger.Info("Connected to external Redis at", redisAddr) + } + + return nil +} + +// GetClient returns the Redis client instance. +func GetClient() *redis.Client { + return client +} + +// IsEmbedded returns true if using embedded Redis. +func IsEmbedded() bool { + return isEmbedded +} + +// Close closes the Redis connection and stops embedded Redis if running. +func Close() error { + if client != nil { + if err := client.Close(); err != nil { + return err + } + } + if miniRedis != nil { + miniRedis.Close() + } + return nil +} + +// Set stores a value in Redis with expiration. +func Set(key string, value interface{}, expiration time.Duration) error { + if client == nil { + return fmt.Errorf("Redis client not initialized") + } + return client.Set(ctx, key, value, expiration).Err() +} + +// Get retrieves a value from Redis. +func Get(key string) (string, error) { + if client == nil { + return "", fmt.Errorf("Redis client not initialized") + } + result, err := client.Get(ctx, key).Result() + if err == redis.Nil { + // Key doesn't exist - this is expected, not an error + return "", fmt.Errorf("redis: nil") + } + return result, err +} + +// Delete removes a key from Redis. +func Delete(key string) error { + if client == nil { + return fmt.Errorf("Redis client not initialized") + } + return client.Del(ctx, key).Err() +} + +// DeletePattern removes all keys matching a pattern. +func DeletePattern(pattern string) error { + if client == nil { + return fmt.Errorf("Redis client not initialized") + } + + iter := client.Scan(ctx, 0, pattern, 0).Iterator() + keys := make([]string, 0) + for iter.Next(ctx) { + keys = append(keys, iter.Val()) + } + if err := iter.Err(); err != nil { + return err + } + + if len(keys) > 0 { + return client.Del(ctx, keys...).Err() + } + return nil +} + +// Exists checks if a key exists in Redis. +func Exists(key string) (bool, error) { + if client == nil { + return false, fmt.Errorf("Redis client not initialized") + } + count, err := client.Exists(ctx, key).Result() + return count > 0, err +} diff --git a/web/cache/redisstore.go b/web/cache/redisstore.go new file mode 100644 index 00000000..5a811646 --- /dev/null +++ b/web/cache/redisstore.go @@ -0,0 +1,176 @@ +// Package cache provides Redis store for gin sessions. +package cache + +import ( + "bytes" + "context" + "encoding/base32" + "encoding/gob" + "fmt" + "net/http" + "strings" + "time" + + "github.com/gin-contrib/sessions" + gorillasessions "github.com/gorilla/sessions" + "github.com/gorilla/securecookie" + "github.com/redis/go-redis/v9" +) + +const ( + defaultMaxAge = 86400 * 7 // 7 days +) + +// RedisStore stores sessions in Redis. +type RedisStore struct { + client *redis.Client + Codecs []securecookie.Codec + options *sessions.Options +} + +// NewRedisStore creates a new Redis store. +func NewRedisStore(client *redis.Client, keyPairs ...[]byte) *RedisStore { + rs := &RedisStore{ + client: client, + Codecs: securecookie.CodecsFromPairs(keyPairs...), + options: &sessions.Options{ + Path: "/", + MaxAge: defaultMaxAge, + }, + } + return rs +} + +// Options sets the options for the store. +func (s *RedisStore) Options(opts sessions.Options) { + s.options = &opts +} + +// Get retrieves a session from Redis. +func (s *RedisStore) Get(r *http.Request, name string) (*gorillasessions.Session, error) { + return gorillasessions.GetRegistry(r).Get(s, name) +} + +// New creates a new session. +func (s *RedisStore) New(r *http.Request, name string) (*gorillasessions.Session, error) { + session := gorillasessions.NewSession(s, name) + session.Options = &gorillasessions.Options{ + Path: s.options.Path, + Domain: s.options.Domain, + MaxAge: s.options.MaxAge, + Secure: s.options.Secure, + HttpOnly: s.options.HttpOnly, + SameSite: s.options.SameSite, + } + session.IsNew = true + + // Try to load existing session from cookie + if c, errCookie := r.Cookie(name); errCookie == nil { + err := securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...) + if err == nil { + // Successfully decoded session ID, try to load from Redis + err = s.load(session) + if err == nil { + session.IsNew = false + } + // If load fails, continue with new session (session.IsNew = true) + } + // If decode fails (e.g., old cookie format), ignore and create new session + } + + return session, nil +} + +// Save saves a session to Redis. +func (s *RedisStore) Save(r *http.Request, w http.ResponseWriter, session *gorillasessions.Session) error { + // Delete if max age is < 0 + if session.Options.MaxAge < 0 { + if err := s.delete(session); err != nil { + return err + } + http.SetCookie(w, s.newCookie(session, "")) + return nil + } + + if session.ID == "" { + session.ID = strings.TrimRight( + base32.StdEncoding.EncodeToString( + securecookie.GenerateRandomKey(32), + ), "=") + } + + if err := s.save(session); err != nil { + return err + } + + encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...) + if err != nil { + return err + } + + http.SetCookie(w, s.newCookie(session, encoded)) + return nil +} + +// newCookie creates a new HTTP cookie for the session. +func (s *RedisStore) newCookie(session *gorillasessions.Session, value string) *http.Cookie { + cookie := &http.Cookie{ + Name: session.Name(), + Value: value, + Path: session.Options.Path, + Domain: session.Options.Domain, + MaxAge: session.Options.MaxAge, + Secure: session.Options.Secure, + HttpOnly: session.Options.HttpOnly, + SameSite: session.Options.SameSite, + } + if session.Options.MaxAge > 0 { + cookie.Expires = time.Now().Add(time.Duration(session.Options.MaxAge) * time.Second) + } + return cookie +} + +// save stores session data in Redis. +func (s *RedisStore) save(session *gorillasessions.Session) error { + // Use gob encoding to preserve types (especially for model.User) + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + if err := enc.Encode(session.Values); err != nil { + return fmt.Errorf("failed to encode session values: %w", err) + } + + maxAge := session.Options.MaxAge + if maxAge == 0 { + maxAge = s.options.MaxAge + } + + key := fmt.Sprintf("session:%s", session.ID) + return s.client.Set(context.Background(), key, buf.Bytes(), time.Duration(maxAge)*time.Second).Err() +} + +// load retrieves session data from Redis. +func (s *RedisStore) load(session *gorillasessions.Session) error { + key := fmt.Sprintf("session:%s", session.ID) + data, err := s.client.Get(context.Background(), key).Bytes() + if err == redis.Nil { + return fmt.Errorf("session not found") + } + if err != nil { + return err + } + + // Use gob decoding to preserve types (especially for model.User) + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + if err := dec.Decode(&session.Values); err != nil { + return fmt.Errorf("failed to decode session data: %w", err) + } + + return nil +} + +// delete removes session from Redis. +func (s *RedisStore) delete(session *gorillasessions.Session) error { + key := fmt.Sprintf("session:%s", session.ID) + return s.client.Del(context.Background(), key).Err() +} diff --git a/web/middleware/cache.go b/web/middleware/cache.go new file mode 100644 index 00000000..cf403333 --- /dev/null +++ b/web/middleware/cache.go @@ -0,0 +1,132 @@ +// Package middleware provides HTTP response caching middleware for the 3x-ui web panel. +package middleware + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "time" + + "github.com/gin-gonic/gin" + "github.com/mhsanaei/3x-ui/v2/web/cache" +) + +// CacheMiddleware creates a middleware that caches HTTP responses. +// It caches GET requests based on the full URL path and query parameters. +func CacheMiddleware(ttl time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + // Only cache GET requests + if c.Request.Method != "GET" { + c.Next() + return + } + + // Generate cache key from request path and query + cacheKey := generateCacheKey(c.Request.URL.Path, c.Request.URL.RawQuery) + + // Try to get from cache + var cachedResponse map[string]interface{} + err := cache.GetJSON(cacheKey, &cachedResponse) + if err == nil { + // Cache hit - return cached response + c.JSON(200, cachedResponse) + c.Abort() + return + } + + // Cache miss - continue to handler and capture response + c.Next() + + // Only cache successful responses (status 200) + if c.Writer.Status() == 200 { + // Try to capture the response body + // Note: This is a simplified version - in production you might want to use + // a response writer wrapper to capture the actual response body + // For now, we'll let the service layer handle caching + } + } +} + +// CacheResponse caches a JSON response with the given key and TTL. +func CacheResponse(key string, data interface{}, ttl time.Duration) error { + return cache.SetJSON(key, data, ttl) +} + +// GetCachedResponse retrieves a cached JSON response. +func GetCachedResponse(key string, dest interface{}) error { + return cache.GetJSON(key, dest) +} + +// InvalidateCacheKey invalidates a specific cache key. +func InvalidateCacheKey(key string) error { + return cache.Delete(key) +} + +// generateCacheKey creates a cache key from path and query string. +func generateCacheKey(path, query string) string { + key := fmt.Sprintf("http:%s", path) + if query != "" { + hash := sha256.Sum256([]byte(query)) + key += ":" + hex.EncodeToString(hash[:])[:16] + } + return key +} + +// UserCacheMiddleware creates a middleware that caches responses per user. +// It includes the user ID in the cache key to ensure user-specific caching. +func UserCacheMiddleware(ttl time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + // Only cache GET requests + if c.Request.Method != "GET" { + c.Next() + return + } + + // Get user ID from session + userID := getUserIDFromContext(c) + if userID == 0 { + c.Next() + return + } + + // Generate cache key with user ID + cacheKey := generateUserCacheKey(c.Request.URL.Path, c.Request.URL.RawQuery, userID) + + // Try to get from cache + var cachedResponse map[string]interface{} + err := cache.GetJSON(cacheKey, &cachedResponse) + if err == nil { + // Cache hit - return cached response + c.JSON(200, cachedResponse) + c.Abort() + return + } + + // Cache miss - continue to handler + c.Next() + } +} + +// generateUserCacheKey creates a cache key with user ID. +func generateUserCacheKey(path, query string, userID int) string { + key := fmt.Sprintf("http:user:%d:%s", userID, path) + if query != "" { + hash := sha256.Sum256([]byte(query)) + key += ":" + hex.EncodeToString(hash[:])[:16] + } + return key +} + +// getUserIDFromContext extracts user ID from gin context. +// This is a helper function - you may need to adjust based on your session implementation. +func getUserIDFromContext(c *gin.Context) int { + // Try to get from session + if user, exists := c.Get("user"); exists { + if userMap, ok := user.(map[string]interface{}); ok { + if id, ok := userMap["id"].(int); ok { + return id + } + } + } + return 0 +} diff --git a/web/service/client.go b/web/service/client.go index 5e9cf3c9..1571a1b6 100644 --- a/web/service/client.go +++ b/web/service/client.go @@ -13,6 +13,7 @@ import ( "github.com/mhsanaei/3x-ui/v2/logger" "github.com/mhsanaei/3x-ui/v2/util/common" "github.com/mhsanaei/3x-ui/v2/util/random" + "github.com/mhsanaei/3x-ui/v2/web/cache" "github.com/mhsanaei/3x-ui/v2/xray" "gorm.io/gorm" @@ -23,59 +24,68 @@ type ClientService struct{} // GetClients retrieves all clients for a specific user. // Also loads traffic statistics and last online time for each client. +// Results are cached in Redis for 30 seconds. func (s *ClientService) GetClients(userId int) ([]*model.ClientEntity, error) { - db := database.GetDB() + key := fmt.Sprintf("%s%d", cache.KeyClientsPrefix, userId) var clients []*model.ClientEntity - err := db.Where("user_id = ?", userId).Find(&clients).Error - if err != nil { - return nil, err - } - - // Load inbound assignments, traffic statistics, and HWIDs for each client - for _, client := range clients { - // Load inbound assignments - inboundIds, err := s.GetInboundIdsForClient(client.Id) - if err == nil { - client.InboundIds = inboundIds + + err := cache.GetOrSet(key, &clients, cache.TTLClients, func() (interface{}, error) { + // Cache miss - fetch from database + db := database.GetDB() + var result []*model.ClientEntity + err := db.Where("user_id = ?", userId).Find(&result).Error + if err != nil { + return nil, err } - // Traffic statistics are now stored directly in ClientEntity table - // No need to load from client_traffics - fields are already loaded from DB - - // Check if client exceeded limits and update status if needed (but keep Enable = true) - now := time.Now().Unix() * 1000 - totalUsed := client.Up + client.Down - trafficLimit := int64(client.TotalGB * 1024 * 1024 * 1024) - trafficExceeded := client.TotalGB > 0 && totalUsed >= trafficLimit - timeExpired := client.ExpiryTime > 0 && client.ExpiryTime <= now - - // Update status if expired, but don't change Enable - if trafficExceeded || timeExpired { - status := "expired_traffic" - if timeExpired { - status = "expired_time" + // Load inbound assignments, traffic statistics, and HWIDs for each client + for _, client := range result { + // Load inbound assignments + inboundIds, err := s.GetInboundIdsForClient(client.Id) + if err == nil { + client.InboundIds = inboundIds } - // Only update if status changed - if client.Status != status { - client.Status = status - err = db.Model(&model.ClientEntity{}).Where("id = ?", client.Id).Update("status", status).Error - if err != nil { - logger.Warningf("Failed to update status for client %s: %v", client.Email, err) + + // Traffic statistics are now stored directly in ClientEntity table + // No need to load from client_traffics - fields are already loaded from DB + + // Check if client exceeded limits and update status if needed (but keep Enable = true) + now := time.Now().Unix() * 1000 + totalUsed := client.Up + client.Down + trafficLimit := int64(client.TotalGB * 1024 * 1024 * 1024) + trafficExceeded := client.TotalGB > 0 && totalUsed >= trafficLimit + timeExpired := client.ExpiryTime > 0 && client.ExpiryTime <= now + + // Update status if expired, but don't change Enable + if trafficExceeded || timeExpired { + status := "expired_traffic" + if timeExpired { + status = "expired_time" + } + // Only update if status changed + if client.Status != status { + client.Status = status + err = db.Model(&model.ClientEntity{}).Where("id = ?", client.Id).Update("status", status).Error + if err != nil { + logger.Warningf("Failed to update status for client %s: %v", client.Email, err) + } } } + + // Load HWIDs for this client + hwidService := ClientHWIDService{} + hwids, err := hwidService.GetHWIDsForClient(client.Id) + if err == nil { + client.HWIDs = hwids + } else { + logger.Warningf("Failed to load HWIDs for client %d: %v", client.Id, err) + } } - // Load HWIDs for this client - hwidService := ClientHWIDService{} - hwids, err := hwidService.GetHWIDsForClient(client.Id) - if err == nil { - client.HWIDs = hwids - } else { - logger.Warningf("Failed to load HWIDs for client %d: %v", client.Id, err) - } - } - - return clients, nil + return result, nil + }) + + return clients, err } // GetClient retrieves a client by ID. @@ -219,6 +229,9 @@ func (s *ClientService) AddClient(userId int, client *model.ClientEntity) (bool, return false, err } + // Invalidate cache for this user's clients + cache.InvalidateClients(userId) + // Now update Settings for all assigned inbounds // This is done AFTER committing the client transaction to avoid nested transactions and database locks needRestart := false @@ -434,6 +447,9 @@ func (s *ClientService) UpdateClient(userId int, client *model.ClientEntity) (bo return false, err } + // Invalidate cache for this user's clients + cache.InvalidateClients(userId) + // Now update Settings for all affected inbounds (old + new) // This is needed even if InboundIds wasn't changed, because client data (UUID, password, etc.) might have changed // We do this AFTER committing the client transaction to avoid nested transactions and database locks diff --git a/web/service/inbound.go b/web/service/inbound.go index a3d2e58f..aa7fc4ed 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -15,6 +15,7 @@ import ( "github.com/mhsanaei/3x-ui/v2/database/model" "github.com/mhsanaei/3x-ui/v2/logger" "github.com/mhsanaei/3x-ui/v2/util/common" + "github.com/mhsanaei/3x-ui/v2/web/cache" "github.com/mhsanaei/3x-ui/v2/xray" "gorm.io/gorm" @@ -95,51 +96,60 @@ func (s *InboundService) updateInboundWithRetry(inbound *model.Inbound) (*model. // GetInbounds retrieves all inbounds for a specific user. // Returns a slice of inbound models with their associated client statistics. +// Results are cached in Redis for 30 seconds. func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) { - db := database.GetDB() + key := fmt.Sprintf("%s%d", cache.KeyInboundsPrefix, userId) var inbounds []*model.Inbound - err := db.Model(model.Inbound{}).Preload("ClientStats").Where("user_id = ?", userId).Find(&inbounds).Error - if err != nil && err != gorm.ErrRecordNotFound { - return nil, err - } - // Enrich with node assignments - nodeService := NodeService{} - for _, inbound := range inbounds { - // Load all nodes for this inbound - nodes, err := nodeService.GetNodesForInbound(inbound.Id) - if err == nil && len(nodes) > 0 { - nodeIds := make([]int, len(nodes)) - for i, node := range nodes { - nodeIds[i] = node.Id - } - inbound.NodeIds = nodeIds - // Don't set nodeId - it's deprecated and causes confusion - // nodeId is only for backward compatibility when receiving data from old clients - } else { - // Ensure empty array if no nodes assigned - inbound.NodeIds = []int{} + err := cache.GetOrSet(key, &inbounds, cache.TTLInbounds, func() (interface{}, error) { + // Cache miss - fetch from database + db := database.GetDB() + var result []*model.Inbound + err := db.Model(model.Inbound{}).Preload("ClientStats").Where("user_id = ?", userId).Find(&result).Error + if err != nil && err != gorm.ErrRecordNotFound { + return nil, err } - // Enrich client stats with UUID/SubId from inbound settings - clients, _ := s.GetClients(inbound) - if len(clients) == 0 || len(inbound.ClientStats) == 0 { - continue - } - // Build a map email -> client - cMap := make(map[string]model.Client, len(clients)) - for _, c := range clients { - cMap[strings.ToLower(c.Email)] = c - } - for i := range inbound.ClientStats { - email := strings.ToLower(inbound.ClientStats[i].Email) - if c, ok := cMap[email]; ok { - inbound.ClientStats[i].UUID = c.ID - inbound.ClientStats[i].SubId = c.SubID + // Enrich with node assignments + nodeService := NodeService{} + for _, inbound := range result { + // Load all nodes for this inbound + nodes, err := nodeService.GetNodesForInbound(inbound.Id) + if err == nil && len(nodes) > 0 { + nodeIds := make([]int, len(nodes)) + for i, node := range nodes { + nodeIds[i] = node.Id + } + inbound.NodeIds = nodeIds + // Don't set nodeId - it's deprecated and causes confusion + // nodeId is only for backward compatibility when receiving data from old clients + } else { + // Ensure empty array if no nodes assigned + inbound.NodeIds = []int{} + } + + // Enrich client stats with UUID/SubId from inbound settings + clients, _ := s.GetClients(inbound) + if len(clients) == 0 || len(inbound.ClientStats) == 0 { + continue + } + // Build a map email -> client + cMap := make(map[string]model.Client, len(clients)) + for _, c := range clients { + cMap[strings.ToLower(c.Email)] = c + } + for i := range inbound.ClientStats { + email := strings.ToLower(inbound.ClientStats[i].Email) + if c, ok := cMap[email]; ok { + inbound.ClientStats[i].UUID = c.ID + inbound.ClientStats[i].SubId = c.SubID + } } } - } - return inbounds, nil + return result, nil + }) + + return inbounds, err } // GetAllInbounds retrieves all inbounds from the database. @@ -440,6 +450,11 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo return inbound, false, err } + // Invalidate cache for this user's inbounds + if inbound.UserId > 0 { + cache.InvalidateInbounds(inbound.UserId) + } + // Note: ClientStats are no longer managed here - clients are managed through ClientEntity // Traffic is stored directly in ClientEntity table @@ -503,6 +518,7 @@ func (s *InboundService) DelInbound(id int) (bool, error) { if err != nil { return false, err } + userId := inbound.UserId clients, err := s.GetClients(inbound) if err != nil { return false, err @@ -514,7 +530,12 @@ func (s *InboundService) DelInbound(id int) (bool, error) { } } - return needRestart, db.Delete(model.Inbound{}, id).Error + err = db.Delete(model.Inbound{}, id).Error + if err == nil && userId > 0 { + // Invalidate cache for this user's inbounds + cache.InvalidateInbounds(userId) + } + return needRestart, err } func (s *InboundService) GetInbound(id int) (*model.Inbound, error) { @@ -690,7 +711,14 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, logger.Debug("Inbound is disabled, not adding to Xray:", tag) } - return inbound, needRestart, tx.Save(oldInbound).Error + err = tx.Save(oldInbound).Error + if err == nil { + // Invalidate cache for this user's inbounds + if oldInbound.UserId > 0 { + cache.InvalidateInbounds(oldInbound.UserId) + } + } + return inbound, needRestart, err } // updateClientTraffics is removed - clients are now managed through ClientEntity diff --git a/web/service/setting.go b/web/service/setting.go index 4d64dcd3..6cda3b58 100644 --- a/web/service/setting.go +++ b/web/service/setting.go @@ -16,6 +16,7 @@ import ( "github.com/mhsanaei/3x-ui/v2/util/common" "github.com/mhsanaei/3x-ui/v2/util/random" "github.com/mhsanaei/3x-ui/v2/util/reflect_util" + "github.com/mhsanaei/3x-ui/v2/web/cache" "github.com/mhsanaei/3x-ui/v2/web/entity" "github.com/mhsanaei/3x-ui/v2/xray" ) @@ -114,78 +115,85 @@ func (s *SettingService) GetDefaultJsonConfig() (any, error) { } func (s *SettingService) GetAllSetting() (*entity.AllSetting, error) { - db := database.GetDB() - settings := make([]*model.Setting, 0) - err := db.Model(model.Setting{}).Not("key = ?", "xrayTemplateConfig").Find(&settings).Error - if err != nil { - return nil, err - } - allSetting := &entity.AllSetting{} - t := reflect.TypeOf(allSetting).Elem() - v := reflect.ValueOf(allSetting).Elem() - fields := reflect_util.GetFields(t) + var allSetting *entity.AllSetting + + err := cache.GetOrSet(cache.KeySettingsAll, &allSetting, cache.TTLSettings, func() (interface{}, error) { + // Cache miss - fetch from database + db := database.GetDB() + settings := make([]*model.Setting, 0) + err := db.Model(model.Setting{}).Not("key = ?", "xrayTemplateConfig").Find(&settings).Error + if err != nil { + return nil, err + } + result := &entity.AllSetting{} + t := reflect.TypeOf(result).Elem() + v := reflect.ValueOf(result).Elem() + fields := reflect_util.GetFields(t) - setSetting := func(key, value string) (err error) { - defer func() { - panicErr := recover() - if panicErr != nil { - err = errors.New(fmt.Sprint(panicErr)) - } - }() + setSetting := func(key, value string) (err error) { + defer func() { + panicErr := recover() + if panicErr != nil { + err = errors.New(fmt.Sprint(panicErr)) + } + }() - var found bool - var field reflect.StructField - for _, f := range fields { - if f.Tag.Get("json") == key { - field = f - found = true - break + var found bool + var field reflect.StructField + for _, f := range fields { + if f.Tag.Get("json") == key { + field = f + found = true + break + } } + + if !found { + // Some settings are automatically generated, no need to return to the front end to modify the user + return nil + } + + fieldV := v.FieldByName(field.Name) + switch t := fieldV.Interface().(type) { + case int: + n, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + fieldV.SetInt(n) + case string: + fieldV.SetString(value) + case bool: + fieldV.SetBool(value == "true") + default: + return common.NewErrorf("unknown field %v type %v", key, t) + } + return } - if !found { - // Some settings are automatically generated, no need to return to the front end to modify the user - return nil - } - - fieldV := v.FieldByName(field.Name) - switch t := fieldV.Interface().(type) { - case int: - n, err := strconv.ParseInt(value, 10, 64) + keyMap := map[string]bool{} + for _, setting := range settings { + err := setSetting(setting.Key, setting.Value) if err != nil { - return err + return nil, err } - fieldV.SetInt(n) - case string: - fieldV.SetString(value) - case bool: - fieldV.SetBool(value == "true") - default: - return common.NewErrorf("unknown field %v type %v", key, t) + keyMap[setting.Key] = true } - return - } - keyMap := map[string]bool{} - for _, setting := range settings { - err := setSetting(setting.Key, setting.Value) - if err != nil { - return nil, err + for key, value := range defaultValueMap { + if keyMap[key] { + continue + } + err := setSetting(key, value) + if err != nil { + return nil, err + } } - keyMap[setting.Key] = true - } - for key, value := range defaultValueMap { - if keyMap[key] { - continue - } - err := setSetting(key, value) - if err != nil { - return nil, err - } - } - - return allSetting, nil + return result, nil + }) + + return allSetting, err } func (s *SettingService) ResetSettings() error { @@ -199,29 +207,54 @@ func (s *SettingService) ResetSettings() error { } func (s *SettingService) getSetting(key string) (*model.Setting, error) { - db := database.GetDB() - setting := &model.Setting{} - err := db.Model(model.Setting{}).Where("key = ?", key).First(setting).Error - if err != nil { - return nil, err - } - return setting, nil + cacheKey := cache.KeySettingPrefix + key + var setting *model.Setting + + err := cache.GetOrSet(cacheKey, &setting, cache.TTLSetting, func() (interface{}, error) { + // Cache miss - fetch from database + db := database.GetDB() + result := &model.Setting{} + err := db.Model(model.Setting{}).Where("key = ?", key).First(result).Error + if err != nil { + return nil, err + } + return result, nil + }) + + return setting, err } func (s *SettingService) saveSetting(key string, value string) error { setting, err := s.getSetting(key) db := database.GetDB() if database.IsNotFound(err) { - return db.Create(&model.Setting{ + err = db.Create(&model.Setting{ Key: key, Value: value, }).Error } else if err != nil { return err + } else { + setting.Key = key + setting.Value = value + err = db.Save(setting).Error } - setting.Key = key - setting.Value = value - return db.Save(setting).Error + + if err == nil { + // Invalidate cache for this specific setting + cache.InvalidateSetting(key) + // Invalidate all settings cache only when a setting is actually changed + // This ensures consistency while avoiding unnecessary cache misses + cache.Delete(cache.KeySettingsAll) + // Also invalidate default settings cache (they depend on individual settings) + cache.DeletePattern("defaultSettings:*") + // Invalidate computed settings that depend on this setting + if key == "multiNodeMode" { + cache.Delete("computed:ipLimitEnable") + } + } + + return err } func (s *SettingService) getString(key string) (string, error) { @@ -568,18 +601,26 @@ func (s *SettingService) SetExternalTrafficInformURI(InformURI string) error { } func (s *SettingService) GetIpLimitEnable() (bool, error) { - // Check if multi-node mode is enabled - multiMode, err := s.GetMultiNodeMode() - if err == nil && multiMode { - // In multi-node mode, IP limiting is handled by nodes - return false, nil - } + // Cache key for this computed setting + cacheKey := "computed:ipLimitEnable" + var result bool - accessLogPath, err := xray.GetAccessLogPath() - if err != nil { - return false, err - } - return (accessLogPath != "none" && accessLogPath != ""), nil + err := cache.GetOrSet(cacheKey, &result, cache.TTLSetting, func() (interface{}, error) { + // Check if multi-node mode is enabled + multiMode, err := s.GetMultiNodeMode() + if err == nil && multiMode { + // In multi-node mode, IP limiting is handled by nodes + return false, nil + } + + accessLogPath, err := xray.GetAccessLogPath() + if err != nil { + return false, err + } + return (accessLogPath != "none" && accessLogPath != ""), nil + }) + + return result, err } // LDAP exported getters @@ -738,32 +779,44 @@ func (s *SettingService) GetDefaultXrayConfig() (any, error) { } func (s *SettingService) GetDefaultSettings(host string) (any, error) { - type settingFunc func() (any, error) - settings := map[string]settingFunc{ - "expireDiff": func() (any, error) { return s.GetExpireDiff() }, - "trafficDiff": func() (any, error) { return s.GetTrafficDiff() }, - "pageSize": func() (any, error) { return s.GetPageSize() }, - "defaultCert": func() (any, error) { return s.GetCertFile() }, - "defaultKey": func() (any, error) { return s.GetKeyFile() }, - "tgBotEnable": func() (any, error) { return s.GetTgbotEnabled() }, - "subEnable": func() (any, error) { return s.GetSubEnable() }, - "subJsonEnable": func() (any, error) { return s.GetSubJsonEnable() }, - "subTitle": func() (any, error) { return s.GetSubTitle() }, - "subURI": func() (any, error) { return s.GetSubURI() }, - "subJsonURI": func() (any, error) { return s.GetSubJsonURI() }, - "remarkModel": func() (any, error) { return s.GetRemarkModel() }, - "datepicker": func() (any, error) { return s.GetDatepicker() }, - "ipLimitEnable": func() (any, error) { return s.GetIpLimitEnable() }, - } - - result := make(map[string]any) - - for key, fn := range settings { - value, err := fn() - if err != nil { - return "", err + // Cache key includes host to support multi-domain setups + cacheKey := fmt.Sprintf("defaultSettings:%s", host) + var result map[string]any + + err := cache.GetOrSet(cacheKey, &result, cache.TTLSettings, func() (interface{}, error) { + // Cache miss - compute default settings + type settingFunc func() (any, error) + settings := map[string]settingFunc{ + "expireDiff": func() (any, error) { return s.GetExpireDiff() }, + "trafficDiff": func() (any, error) { return s.GetTrafficDiff() }, + "pageSize": func() (any, error) { return s.GetPageSize() }, + "defaultCert": func() (any, error) { return s.GetCertFile() }, + "defaultKey": func() (any, error) { return s.GetKeyFile() }, + "tgBotEnable": func() (any, error) { return s.GetTgbotEnabled() }, + "subEnable": func() (any, error) { return s.GetSubEnable() }, + "subJsonEnable": func() (any, error) { return s.GetSubJsonEnable() }, + "subTitle": func() (any, error) { return s.GetSubTitle() }, + "subURI": func() (any, error) { return s.GetSubURI() }, + "subJsonURI": func() (any, error) { return s.GetSubJsonURI() }, + "remarkModel": func() (any, error) { return s.GetRemarkModel() }, + "datepicker": func() (any, error) { return s.GetDatepicker() }, + "ipLimitEnable": func() (any, error) { return s.GetIpLimitEnable() }, } - result[key] = value + + res := make(map[string]any) + + for key, fn := range settings { + value, err := fn() + if err != nil { + return nil, err + } + res[key] = value + } + return res, nil + }) + + if err != nil { + return nil, err } subEnable := result["subEnable"].(bool) diff --git a/web/web.go b/web/web.go index 5da9b298..9283e3df 100644 --- a/web/web.go +++ b/web/web.go @@ -31,6 +31,7 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" + "github.com/mhsanaei/3x-ui/v2/web/cache" "github.com/robfig/cron/v3" ) @@ -203,7 +204,19 @@ func (s *Server) initRouter() (*gin.Engine, error) { engine.Use(gzip.Gzip(gzip.DefaultCompression, gzip.WithExcludedPaths([]string{basePath + "panel/api/"}))) assetsBasePath := basePath + "assets/" - store := cookie.NewStore(secret) + // Use Redis store for sessions if available, otherwise fallback to cookie store + var store sessions.Store + redisClient := cache.GetClient() + if redisClient != nil { + // Use Redis store + store = cache.NewRedisStore(redisClient, []byte(secret)) + logger.Info("Using Redis store for sessions") + } else { + // Fallback to cookie store + store = cookie.NewStore(secret) + logger.Info("Using cookie store for sessions (Redis not available)") + } + // Configure default session cookie options, including expiration (MaxAge) if sessionMaxAge, err := s.settingService.GetSessionMaxAge(); err == nil { store.Options(sessions.Options{ @@ -220,7 +233,12 @@ func (s *Server) initRouter() (*gin.Engine, error) { engine.Use(func(c *gin.Context) { uri := c.Request.RequestURI if strings.HasPrefix(uri, assetsBasePath) { - c.Header("Cache-Control", "max-age=31536000") + // Cache static assets for 1 year with immutable flag + c.Header("Cache-Control", "max-age=31536000, public, immutable") + } else if strings.HasPrefix(uri, basePath+"panel/api/") && c.Request.Method == "GET" { + // For API GET requests, use no-cache but allow conditional requests + // This enables browser caching with validation + c.Header("Cache-Control", "no-cache, must-revalidate") } }) @@ -498,3 +516,13 @@ func (s *Server) GetCron() *cron.Cron { func (s *Server) GetWSHub() any { return s.wsHub } + +// InitRedisCache initializes Redis cache. If redisAddr is empty, uses embedded Redis. +func InitRedisCache(redisAddr string) error { + return cache.InitRedis(redisAddr) +} + +// CloseRedisCache closes Redis cache connection. +func CloseRedisCache() error { + return cache.Close() +} diff --git a/web/websocket/hub.go b/web/websocket/hub.go index 8aa5903c..a3a5bc16 100644 --- a/web/websocket/hub.go +++ b/web/websocket/hub.go @@ -2,6 +2,7 @@ package websocket import ( + "bytes" "context" "encoding/json" "runtime" @@ -62,6 +63,15 @@ type Hub struct { // Worker pool for parallel broadcasting workerPoolSize int broadcastWg sync.WaitGroup + + // Cache for last serialized messages to avoid re-serialization + messageCache map[MessageType][]byte + cacheMu sync.RWMutex + + // Throttling for frequent updates + throttleMap map[MessageType]time.Time + throttleMu sync.Mutex + throttleInterval time.Duration } // NewHub creates a new WebSocket hub @@ -85,6 +95,9 @@ func NewHub() *Hub { ctx: ctx, cancel: cancel, workerPoolSize: workerPoolSize, + messageCache: make(map[MessageType][]byte), + throttleMap: make(map[MessageType]time.Time), + throttleInterval: 100 * time.Millisecond, // Throttle updates to max 10 per second per type } } @@ -259,18 +272,37 @@ func (h *Hub) Broadcast(messageType MessageType, payload any) { return } + // Throttle frequent updates (except for critical messages) + if messageType == MessageTypeInbounds || messageType == MessageTypeTraffic { + h.throttleMu.Lock() + lastTime, exists := h.throttleMap[messageType] + if exists && time.Since(lastTime) < h.throttleInterval { + h.throttleMu.Unlock() + return // Skip this update, too soon + } + h.throttleMap[messageType] = time.Now() + h.throttleMu.Unlock() + } + + // Use buffer pool for JSON encoding to reduce allocations + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) // Faster encoding, no HTML escaping needed + msg := Message{ Type: messageType, Payload: payload, Time: getCurrentTimestamp(), } - data, err := json.Marshal(msg) - if err != nil { + if err := enc.Encode(msg); err != nil { logger.Error("Failed to marshal WebSocket message:", err) return } + // Remove trailing newline from Encode + data := bytes.TrimRight(buf.Bytes(), "\n") + // Limit message size to prevent memory issues const maxMessageSize = 1024 * 1024 // 1MB if len(data) > maxMessageSize { @@ -278,6 +310,14 @@ func (h *Hub) Broadcast(messageType MessageType, payload any) { return } + // Cache the serialized message for potential reuse + // Make a copy to avoid issues with buffer reuse + dataCopy := make([]byte, len(data)) + copy(dataCopy, data) + h.cacheMu.Lock() + h.messageCache[messageType] = dataCopy + h.cacheMu.Unlock() + // Non-blocking send with timeout to prevent delays select { case h.broadcast <- data: @@ -298,18 +338,25 @@ func (h *Hub) BroadcastToTopic(messageType MessageType, payload any) { return } + // Use buffer pool for JSON encoding to reduce allocations + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) // Faster encoding, no HTML escaping needed + msg := Message{ Type: messageType, Payload: payload, Time: getCurrentTimestamp(), } - data, err := json.Marshal(msg) - if err != nil { + if err := enc.Encode(msg); err != nil { logger.Error("Failed to marshal WebSocket message:", err) return } + // Remove trailing newline from Encode + data := bytes.TrimRight(buf.Bytes(), "\n") + // Limit message size to prevent memory issues const maxMessageSize = 1024 * 1024 // 1MB if len(data) > maxMessageSize { @@ -317,6 +364,14 @@ func (h *Hub) BroadcastToTopic(messageType MessageType, payload any) { return } + // Cache the serialized message for potential reuse + // Make a copy to avoid issues with buffer reuse + dataCopy := make([]byte, len(data)) + copy(dataCopy, data) + h.cacheMu.Lock() + h.messageCache[messageType] = dataCopy + h.cacheMu.Unlock() + h.mu.RLock() // Filter clients by topics and quickly release lock subscribedClients := make([]*Client, 0)