3x-ui/web/middleware/middleware_test.go
Sora39831 09f84782b0 test: add unit tests for critical modules and fix flaky redirect middleware
- Add tests for config, database, model, util/common, util/crypto,
  util/random, web/middleware, web/service, and xray packages
- Fix redirect middleware by using a slice instead of a map to guarantee
  deterministic longest-prefix-first matching order
2026-04-03 08:44:51 +08:00

320 lines
7.8 KiB
Go

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
)
// init switches gin to gin.TestMode when the package is initialised, i.e.
// before any test in this file runs.
func init() {
gin.SetMode(gin.TestMode)
}
// TestRedirectMiddleware_XUIToPanel checks that, with RedirectMiddleware("/")
// installed, GET /xui/settings is answered with a 301 whose Location header
// is /panel/settings.
func TestRedirectMiddleware_XUIToPanel(t *testing.T) {
	panel := func(c *gin.Context) {
		c.String(http.StatusOK, "panel")
	}

	engine := gin.New()
	engine.Use(RedirectMiddleware("/"))
	engine.GET("/panel/*path", panel)

	// Error ignored: the method and URL are fixed literals.
	request, _ := http.NewRequest(http.MethodGet, "/xui/settings", nil)
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)

	if status := rec.Code; status != http.StatusMovedPermanently {
		t.Errorf("expected 301, got %d", status)
	}
	if target := rec.Header().Get("Location"); target != "/panel/settings" {
		t.Errorf("expected redirect to /panel/settings, got %q", target)
	}
}
// TestRedirectMiddleware_XUIAPIToPanelAPI checks that, with
// RedirectMiddleware("/") installed, GET /xui/API/inbounds is answered with a
// 301 whose Location header is /panel/api/inbounds.
func TestRedirectMiddleware_XUIAPIToPanelAPI(t *testing.T) {
	api := func(c *gin.Context) {
		c.String(http.StatusOK, "api")
	}

	engine := gin.New()
	engine.Use(RedirectMiddleware("/"))
	engine.GET("/panel/api/*path", api)

	// Error ignored: the method and URL are fixed literals.
	request, _ := http.NewRequest(http.MethodGet, "/xui/API/inbounds", nil)
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)

	if status := rec.Code; status != http.StatusMovedPermanently {
		t.Errorf("expected 301, got %d", status)
	}
	if target := rec.Header().Get("Location"); target != "/panel/api/inbounds" {
		t.Errorf("expected redirect to /panel/api/inbounds, got %q", target)
	}
}
// TestRedirectMiddleware_PanelAPICase checks that, with
// RedirectMiddleware("/") installed, GET /panel/API/list (upper-case "API")
// is answered with a 301 whose Location header is /panel/api/list.
func TestRedirectMiddleware_PanelAPICase(t *testing.T) {
	api := func(c *gin.Context) {
		c.String(http.StatusOK, "api")
	}

	engine := gin.New()
	engine.Use(RedirectMiddleware("/"))
	engine.GET("/panel/api/*path", api)

	// Error ignored: the method and URL are fixed literals.
	request, _ := http.NewRequest(http.MethodGet, "/panel/API/list", nil)
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)

	if status := rec.Code; status != http.StatusMovedPermanently {
		t.Errorf("expected 301, got %d", status)
	}
	if target := rec.Header().Get("Location"); target != "/panel/api/list" {
		t.Errorf("expected redirect to /panel/api/list, got %q", target)
	}
}
// TestRedirectMiddleware_NoRedirect checks that, with RedirectMiddleware("/")
// installed, GET /panel/settings reaches the registered handler and yields
// 200.
func TestRedirectMiddleware_NoRedirect(t *testing.T) {
	settings := func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	}

	engine := gin.New()
	engine.Use(RedirectMiddleware("/"))
	engine.GET("/panel/settings", settings)

	// Error ignored: the method and URL are fixed literals.
	request, _ := http.NewRequest(http.MethodGet, "/panel/settings", nil)
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)

	if status := rec.Code; status != http.StatusOK {
		t.Errorf("expected 200, got %d", status)
	}
}
// TestRedirectMiddleware_WithBasePath checks that, with
// RedirectMiddleware("/base/") installed, GET /base/xui/settings is answered
// with a 301 whose Location header is /base/panel/settings.
func TestRedirectMiddleware_WithBasePath(t *testing.T) {
	panel := func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	}

	engine := gin.New()
	engine.Use(RedirectMiddleware("/base/"))
	engine.GET("/base/panel/*path", panel)

	// Error ignored: the method and URL are fixed literals.
	request, _ := http.NewRequest(http.MethodGet, "/base/xui/settings", nil)
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)

	if status := rec.Code; status != http.StatusMovedPermanently {
		t.Errorf("expected 301, got %d", status)
	}
	if target := rec.Header().Get("Location"); target != "/base/panel/settings" {
		t.Errorf("expected redirect to /base/panel/settings, got %q", target)
	}
}
// TestDomainValidatorMiddleware_MatchingDomain checks that, with
// DomainValidatorMiddleware("example.com") installed, a request whose Host is
// "example.com" reaches the handler and yields 200.
func TestDomainValidatorMiddleware_MatchingDomain(t *testing.T) {
	ok := func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	}

	engine := gin.New()
	engine.Use(DomainValidatorMiddleware("example.com"))
	engine.GET("/test", ok)

	// Error ignored: the method and URL are fixed literals.
	request, _ := http.NewRequest(http.MethodGet, "/test", nil)
	request.Host = "example.com"
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)

	if status := rec.Code; status != http.StatusOK {
		t.Errorf("expected 200, got %d", status)
	}
}
// TestDomainValidatorMiddleware_MatchingDomainWithPort checks that, with
// DomainValidatorMiddleware("example.com") installed, a request whose Host is
// "example.com:8443" still reaches the handler and yields 200.
func TestDomainValidatorMiddleware_MatchingDomainWithPort(t *testing.T) {
	ok := func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	}

	engine := gin.New()
	engine.Use(DomainValidatorMiddleware("example.com"))
	engine.GET("/test", ok)

	// Error ignored: the method and URL are fixed literals.
	request, _ := http.NewRequest(http.MethodGet, "/test", nil)
	request.Host = "example.com:8443"
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)

	if status := rec.Code; status != http.StatusOK {
		t.Errorf("expected 200 for matching domain with port, got %d", status)
	}
}
// TestDomainValidatorMiddleware_NonMatchingDomain checks that, with
// DomainValidatorMiddleware("example.com") installed, a request whose Host is
// "evil.com" is answered with 403.
func TestDomainValidatorMiddleware_NonMatchingDomain(t *testing.T) {
	ok := func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	}

	engine := gin.New()
	engine.Use(DomainValidatorMiddleware("example.com"))
	engine.GET("/test", ok)

	// Error ignored: the method and URL are fixed literals.
	request, _ := http.NewRequest(http.MethodGet, "/test", nil)
	request.Host = "evil.com"
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)

	if status := rec.Code; status != http.StatusForbidden {
		t.Errorf("expected 403, got %d", status)
	}
}
// TestDomainValidatorMiddleware_Subdomain checks that, with
// DomainValidatorMiddleware("example.com") installed, a request whose Host is
// the subdomain "sub.example.com" is answered with 403.
func TestDomainValidatorMiddleware_Subdomain(t *testing.T) {
	ok := func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	}

	engine := gin.New()
	engine.Use(DomainValidatorMiddleware("example.com"))
	engine.GET("/test", ok)

	// Error ignored: the method and URL are fixed literals.
	request, _ := http.NewRequest(http.MethodGet, "/test", nil)
	request.Host = "sub.example.com"
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)

	if status := rec.Code; status != http.StatusForbidden {
		t.Errorf("expected 403 for subdomain, got %d", status)
	}
}
// TestRateLimitMiddleware_FirstRequest checks that, with
// RateLimitMiddleware(5, time.Minute) installed, a single request from
// 192.168.1.1:12345 reaches the handler and yields 200.
func TestRateLimitMiddleware_FirstRequest(t *testing.T) {
	ok := func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	}

	engine := gin.New()
	engine.Use(RateLimitMiddleware(5, time.Minute))
	engine.GET("/test", ok)

	// Error ignored: the method and URL are fixed literals.
	request, _ := http.NewRequest(http.MethodGet, "/test", nil)
	request.RemoteAddr = "192.168.1.1:12345"
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, request)

	if status := rec.Code; status != http.StatusOK {
		t.Errorf("expected 200 for first request, got %d", status)
	}
}
// TestRateLimitMiddleware_WithinLimit checks that, with
// RateLimitMiddleware(3, time.Minute) installed, three consecutive requests
// from the same remote address all yield 200.
func TestRateLimitMiddleware_WithinLimit(t *testing.T) {
	ok := func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	}

	engine := gin.New()
	engine.Use(RateLimitMiddleware(3, time.Minute))
	engine.GET("/test", ok)

	for attempt := 1; attempt <= 3; attempt++ {
		// Error ignored: the method and URL are fixed literals.
		request, _ := http.NewRequest(http.MethodGet, "/test", nil)
		request.RemoteAddr = "192.168.1.2:12345"
		rec := httptest.NewRecorder()
		engine.ServeHTTP(rec, request)

		if rec.Code == http.StatusOK {
			continue
		}
		t.Errorf("request %d: expected 200, got %d", attempt, rec.Code)
	}
}
// TestRateLimitMiddleware_ExceedsLimit checks that, with
// RateLimitMiddleware(2, time.Minute) installed, the first two requests from
// one remote address yield 200 and the third yields 429.
func TestRateLimitMiddleware_ExceedsLimit(t *testing.T) {
	r := gin.New()
	r.Use(RateLimitMiddleware(2, time.Minute))
	r.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	// send issues one GET /test from a fixed remote address and returns the
	// recorded status code.
	send := func() int {
		t.Helper()
		req, err := http.NewRequest(http.MethodGet, "/test", nil)
		if err != nil {
			t.Fatalf("building request: %v", err)
		}
		req.RemoteAddr = "192.168.1.3:12345"
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		return w.Code
	}

	// First 2 should pass. Assert it: without this check a middleware that
	// rejects every request would also satisfy the 429 expectation below.
	for i := range 2 {
		if code := send(); code != http.StatusOK {
			t.Fatalf("request %d: expected 200, got %d", i+1, code)
		}
	}

	// 3rd should be rate limited
	if code := send(); code != http.StatusTooManyRequests {
		t.Errorf("expected 429, got %d", code)
	}
}
// TestRateLimitMiddleware_XRealIP checks that, with
// RateLimitMiddleware(2, time.Minute) installed, requests carrying the header
// X-Real-IP: 10.0.0.1 yield 200 twice and 429 on the third attempt.
func TestRateLimitMiddleware_XRealIP(t *testing.T) {
	r := gin.New()
	r.Use(RateLimitMiddleware(2, time.Minute))
	r.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	// send issues one GET /test with the X-Real-IP header set and returns the
	// recorded status code.
	send := func() int {
		t.Helper()
		req, err := http.NewRequest(http.MethodGet, "/test", nil)
		if err != nil {
			t.Fatalf("building request: %v", err)
		}
		req.Header.Set("X-Real-IP", "10.0.0.1")
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		return w.Code
	}

	// The requests within the limit must succeed; otherwise the 429 below
	// would not show that the limit was what triggered it.
	for i := range 2 {
		if code := send(); code != http.StatusOK {
			t.Fatalf("request %d: expected 200 with X-Real-IP, got %d", i+1, code)
		}
	}

	if code := send(); code != http.StatusTooManyRequests {
		t.Errorf("expected 429 with X-Real-IP, got %d", code)
	}
}
// TestRateLimitMiddleware_XForwardedFor checks that, with
// RateLimitMiddleware(2, time.Minute) installed, requests carrying the header
// X-Forwarded-For: 10.0.0.2, 10.0.0.3 yield 200 twice and 429 on the third
// attempt.
func TestRateLimitMiddleware_XForwardedFor(t *testing.T) {
	r := gin.New()
	r.Use(RateLimitMiddleware(2, time.Minute))
	r.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	// send issues one GET /test with the X-Forwarded-For header set and
	// returns the recorded status code.
	send := func() int {
		t.Helper()
		req, err := http.NewRequest(http.MethodGet, "/test", nil)
		if err != nil {
			t.Fatalf("building request: %v", err)
		}
		req.Header.Set("X-Forwarded-For", "10.0.0.2, 10.0.0.3")
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		return w.Code
	}

	// The requests within the limit must succeed; otherwise the 429 below
	// would not show that the limit was what triggered it.
	for i := range 2 {
		if code := send(); code != http.StatusOK {
			t.Fatalf("request %d: expected 200 with X-Forwarded-For, got %d", i+1, code)
		}
	}

	if code := send(); code != http.StatusTooManyRequests {
		t.Errorf("expected 429 with X-Forwarded-For, got %d", code)
	}
}
// TestRateLimitMiddleware_DifferentIPsIndependent checks that, with
// RateLimitMiddleware(1, time.Minute) installed, a second request with
// X-Real-IP: 10.0.0.10 yields 429 while a request with X-Real-IP: 10.0.0.20
// still yields 200.
func TestRateLimitMiddleware_DifferentIPsIndependent(t *testing.T) {
	r := gin.New()
	r.Use(RateLimitMiddleware(1, time.Minute))
	r.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	// send issues one GET /test with X-Real-IP set to ip and returns the
	// recorded status code.
	send := func(ip string) int {
		t.Helper()
		req, err := http.NewRequest(http.MethodGet, "/test", nil)
		if err != nil {
			t.Fatalf("building request: %v", err)
		}
		req.Header.Set("X-Real-IP", ip)
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		return w.Code
	}

	// Exhaust limit for IP 1. The first request must be served, otherwise the
	// 429 that follows would not be evidence of the limit being reached.
	if code := send("10.0.0.10"); code != http.StatusOK {
		t.Fatalf("IP 1 first request should be 200, got %d", code)
	}
	if code := send("10.0.0.10"); code != http.StatusTooManyRequests {
		t.Errorf("IP 1 second request should be 429, got %d", code)
	}

	// IP 2 should still be allowed
	if code := send("10.0.0.20"); code != http.StatusOK {
		t.Errorf("different IP should get 200, got %d", code)
	}
}