mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2025-12-23 06:42:41 +00:00
96 lines
2.4 KiB
Go
96 lines
2.4 KiB
Go
|
|
package middleware
|
||
|
|
|
||
|
|
import (
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/mhsanaei/3x-ui/v2/logger"
	redisutil "github.com/mhsanaei/3x-ui/v2/util/redis"
)
|
||
|
|
|
||
|
|
// RateLimitConfig configures rate limiting
|
||
|
|
type RateLimitConfig struct {
|
||
|
|
RequestsPerMinute int
|
||
|
|
BurstSize int
|
||
|
|
KeyFunc func(c *gin.Context) string
|
||
|
|
SkipPaths []string // Paths to skip rate limiting
|
||
|
|
}
|
||
|
|
|
||
|
|
// DefaultRateLimitConfig returns default rate limit config
|
||
|
|
func DefaultRateLimitConfig() RateLimitConfig {
|
||
|
|
return RateLimitConfig{
|
||
|
|
RequestsPerMinute: 60,
|
||
|
|
BurstSize: 10,
|
||
|
|
KeyFunc: func(c *gin.Context) string {
|
||
|
|
return c.ClientIP()
|
||
|
|
},
|
||
|
|
SkipPaths: []string{"/assets/", "/favicon.ico"},
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// shouldSkip checks if path should be skipped
|
||
|
|
func (config RateLimitConfig) shouldSkip(path string) bool {
|
||
|
|
for _, skipPath := range config.SkipPaths {
|
||
|
|
if len(path) >= len(skipPath) && path[:len(skipPath)] == skipPath {
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
// RateLimitMiddleware creates rate limiting middleware
|
||
|
|
func RateLimitMiddleware(config RateLimitConfig) gin.HandlerFunc {
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
// Skip rate limiting for certain paths
|
||
|
|
if config.shouldSkip(c.Request.URL.Path) {
|
||
|
|
c.Next()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
key := config.KeyFunc(c)
|
||
|
|
rateLimitKey := "ratelimit:" + key + ":" + c.Request.URL.Path
|
||
|
|
|
||
|
|
// Get current count
|
||
|
|
countStr, err := redisutil.Get(rateLimitKey)
|
||
|
|
var count int
|
||
|
|
if err != nil {
|
||
|
|
// Key doesn't exist, start with 0
|
||
|
|
count = 0
|
||
|
|
} else {
|
||
|
|
count, _ = strconv.Atoi(countStr)
|
||
|
|
}
|
||
|
|
|
||
|
|
if count >= config.RequestsPerMinute {
|
||
|
|
logger.Warningf("Rate limit exceeded for %s on %s (count: %d)", key, c.Request.URL.Path, count)
|
||
|
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
||
|
|
"success": false,
|
||
|
|
"msg": "Rate limit exceeded. Please try again later.",
|
||
|
|
})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Increment counter
|
||
|
|
newCount, err := redisutil.Incr(rateLimitKey)
|
||
|
|
if err != nil {
|
||
|
|
logger.Warning("Rate limit increment failed:", err)
|
||
|
|
c.Next()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Set expiration on first request
|
||
|
|
if newCount == 1 {
|
||
|
|
redisutil.Expire(rateLimitKey, time.Minute)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Set rate limit headers
|
||
|
|
c.Header("X-RateLimit-Limit", strconv.Itoa(config.RequestsPerMinute))
|
||
|
|
c.Header("X-RateLimit-Remaining", strconv.Itoa(config.RequestsPerMinute-int(newCount)))
|
||
|
|
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10))
|
||
|
|
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|