3x-ui/web/middleware/ipfilter.go

152 lines
3.6 KiB
Go
Raw Normal View History

package middleware
import (
"fmt"
"net"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/mhsanaei/3x-ui/v2/logger"
redisutil "github.com/mhsanaei/3x-ui/v2/util/redis"
)
// IPFilterConfig configures IP filtering
type IPFilterConfig struct {
WhitelistEnabled bool
BlacklistEnabled bool
GeoIPEnabled bool
BlockedCountries []string
SkipPaths []string // Paths to skip IP filtering
}
// shouldSkip checks if path should be skipped
func (config IPFilterConfig) shouldSkip(path string) bool {
for _, skipPath := range config.SkipPaths {
if len(path) >= len(skipPath) && path[:len(skipPath)] == skipPath {
return true
}
}
return false
}
// IPFilterMiddleware creates IP filtering middleware
func IPFilterMiddleware(config IPFilterConfig) gin.HandlerFunc {
return func(c *gin.Context) {
// Skip IP filtering for certain paths
if config.shouldSkip(c.Request.URL.Path) {
c.Next()
return
}
ip := c.ClientIP()
// Validate IP format
if !ValidateIP(ip) {
logger.Warningf("Invalid IP format: %s", ip)
c.Next()
return
}
// Check blacklist first
if config.BlacklistEnabled {
isBlocked, err := redisutil.SIsMember("ip:blacklist", ip)
if err == nil && isBlocked {
logger.Warningf("Blocked IP attempted access: %s", ip)
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"msg": "Access denied",
})
c.Abort()
return
}
}
// Check whitelist if enabled
if config.WhitelistEnabled {
isWhitelisted, err := redisutil.SIsMember("ip:whitelist", ip)
if err == nil && !isWhitelisted {
logger.Warningf("Non-whitelisted IP attempted access: %s", ip)
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"msg": "Access denied",
})
c.Abort()
return
}
}
// Check GeoIP blocking
if config.GeoIPEnabled && len(config.BlockedCountries) > 0 {
country, err := getCountryFromIP(ip)
if err == nil && country != "" {
for _, blockedCountry := range config.BlockedCountries {
if strings.EqualFold(country, blockedCountry) {
logger.Warningf("Blocked country attempted access: %s from %s", country, ip)
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"msg": "Access denied",
})
c.Abort()
return
}
}
}
}
c.Next()
}
}
// getCountryFromIP gets country code from IP (simplified version)
// In production, use MaxMind GeoIP2 database
func getCountryFromIP(ip string) (string, error) {
// Check cache first
cacheKey := "geoip:" + ip
country, err := redisutil.Get(cacheKey)
if err == nil && country != "" {
return country, nil
}
// For now, return empty (will be implemented with MaxMind)
// This is a placeholder
return "", nil
}
// AddToBlacklist adds IP to blacklist
func AddToBlacklist(ip string) error {
if !ValidateIP(ip) {
return fmt.Errorf("invalid IP address: %s", ip)
}
return redisutil.SAdd("ip:blacklist", ip)
}
// RemoveFromBlacklist removes IP from blacklist
func RemoveFromBlacklist(ip string) error {
return redisutil.SRem("ip:blacklist", ip)
}
// AddToWhitelist adds IP to whitelist
func AddToWhitelist(ip string) error {
if !ValidateIP(ip) {
return fmt.Errorf("invalid IP address: %s", ip)
}
return redisutil.SAdd("ip:whitelist", ip)
}
// RemoveFromWhitelist removes IP from whitelist
func RemoveFromWhitelist(ip string) error {
return redisutil.SRem("ip:whitelist", ip)
}
// IsIPBlocked checks if IP is blocked
func IsIPBlocked(ip string) (bool, error) {
return redisutil.SIsMember("ip:blacklist", ip)
}
// ValidateIP validates IP address format
func ValidateIP(ip string) bool {
parsed := net.ParseIP(ip)
return parsed != nil
}