diff --git a/go.mod b/go.mod index 07a84a58..0a967ebf 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/mhsanaei/3x-ui/v2 -go 1.26.2 +go 1.26.3 require ( github.com/gin-contrib/gzip v1.2.6 @@ -19,7 +19,7 @@ require ( github.com/robfig/cron/v3 v3.0.1 github.com/shirou/gopsutil/v4 v4.26.4 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/valyala/fasthttp v1.70.0 + github.com/valyala/fasthttp v1.71.0 github.com/xlzd/gotp v0.1.0 github.com/xtls/xray-core v1.260327.0 go.uber.org/atomic v1.11.0 @@ -95,7 +95,7 @@ require ( golang.org/x/tools v0.44.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260504160031-60b97b32f348 // indirect google.golang.org/protobuf v1.36.11 // indirect gvisor.dev/gvisor v0.0.0-20260122175437-89a5d21be8f0 // indirect lukechampine.com/blake3 v1.4.1 // indirect diff --git a/go.sum b/go.sum index 11e890f7..5ed1f3c2 100644 --- a/go.sum +++ b/go.sum @@ -185,8 +185,8 @@ github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.70.0 h1:LAhMGcWk13QZWm85+eg8ZBNbrq5mnkWFGbHMUJHIdXA= -github.com/valyala/fasthttp v1.70.0/go.mod h1:oDZEHHkJ/Buyklg6uURmYs19442zFSnCIfX3j1FY3pE= +github.com/valyala/fasthttp v1.71.0 h1:tepR7H+Guh9VUqxxcPggYi8R3lGUu2Rsdh+z7/FCY3k= +github.com/valyala/fasthttp v1.71.0/go.mod h1:z1sDUvOShhXq/C9mwH/fSm1Vb71tUJwmQdgkBrBNwnA= github.com/valyala/fastjson v1.6.10 h1:/yjJg8jaVQdYR3arGxPE2X5z89xrlhS0eGXdv+ADTh4= github.com/valyala/fastjson v1.6.10/go.mod h1:e6FubmQouUNP73jtMLmcbxS6ydWIpOfhz34TSfO3JaE= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= @@ -256,8 +256,8 @@ golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+Z golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 h1:tEkOQcXgF6dH1G+MVKZrfpYvozGrzb91k6ha7jireSM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260504160031-60b97b32f348 h1:pfIbyB44sWzHiCpRqIen67ZQnVXSfIxWrqUMk1qwODE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260504160031-60b97b32f348/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.81.0 h1:W3G9N3KQf3BU+YuCtGKJk0CmxQNbAISICD/9AORxLIw= google.golang.org/grpc v1.81.0/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= diff --git a/install.sh b/install.sh index 4a0ce55d..9eb1011b 100644 --- a/install.sh +++ b/install.sh @@ -678,13 +678,25 @@ config_after_install() { for ip_address in "${URL_lists[@]}"; do local response=$(curl -s -w "\n%{http_code}" --max-time 3 "${ip_address}" 2> /dev/null) local http_code=$(echo "$response" | tail -n1) - local ip_result=$(echo "$response" | head -n-1 | tr -d '[:space:]') - if [[ "${http_code}" == "200" && -n "${ip_result}" ]]; then + local ip_result=$(echo "$response" | head -n-1 | tr -d '[:space:]"') + if [[ "${http_code}" == "200" && "${ip_result}" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ ]]; then server_ip="${ip_result}" break fi done + if [[ -z "$server_ip" ]]; then + echo -e "${yellow}Could not auto-detect server IP from any provider.${plain}" + while [[ -z "$server_ip" ]]; do + read -rp "Please enter your server's public IPv4 address: " server_ip + server_ip="${server_ip// /}" + if [[ ! "$server_ip" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo -e "${red}Invalid IPv4 address. Please try again.${plain}" + server_ip="" + fi + done + fi + if [[ ${#existing_webBasePath} -lt 4 ]]; then if [[ "$existing_hasDefaultCredential" == "true" ]]; then local config_webBasePath=$(gen_random_string 18) diff --git a/update.sh b/update.sh index 8336518b..1c9f1ae9 100755 --- a/update.sh +++ b/update.sh @@ -711,13 +711,25 @@ config_after_update() { for ip_address in "${URL_lists[@]}"; do local response=$(curl -s -w "\n%{http_code}" --max-time 3 "${ip_address}" 2> /dev/null) local http_code=$(echo "$response" | tail -n1) - local ip_result=$(echo "$response" | head -n-1 | tr -d '[:space:]') - if [[ "${http_code}" == "200" && -n "${ip_result}" ]]; then + local ip_result=$(echo "$response" | head -n-1 | tr -d '[:space:]"') + if [[ "${http_code}" == "200" && "${ip_result}" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ ]]; then server_ip="${ip_result}" break fi done + if [[ -z "$server_ip" ]]; then + echo -e "${yellow}Could not auto-detect server IP from any provider.${plain}" + while [[ -z "$server_ip" ]]; do + read -rp "Please enter your server's public IPv4 address: " server_ip + server_ip="${server_ip// /}" + if [[ ! "$server_ip" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo -e "${red}Invalid IPv4 address. Please try again.${plain}" + server_ip="" + fi + done + fi + # Handle missing/short webBasePath if [[ ${#existing_webBasePath} -lt 4 ]]; then echo -e "${yellow}WebBasePath is missing or too short. Generating a new one...${plain}" @@ -737,12 +749,6 @@ config_after_update() { echo -e "${yellow}Let's Encrypt now supports both domains and IP addresses!${plain}" echo "" - if [[ -z "${server_ip}" ]]; then - echo -e "${red}Failed to detect server IP${plain}" - echo -e "${yellow}Please configure SSL manually using: x-ui${plain}" - return - fi - # Prompt and setup SSL (domain or IP) prompt_and_setup_ssl "${existing_port}" "${existing_webBasePath}" "${server_ip}" diff --git a/web/assets/js/axios-init.js b/web/assets/js/axios-init.js index f0b0f4be..c44c0647 100644 --- a/web/assets/js/axios-init.js +++ b/web/assets/js/axios-init.js @@ -3,6 +3,12 @@ axios.defaults.headers.common['X-Requested-With'] = 'XMLHttpRequest'; axios.interceptors.request.use( (config) => { + config.headers = config.headers || {}; + const csrfToken = document.querySelector('meta[name="csrf-token"]')?.getAttribute('content'); + const method = (config.method || 'get').toUpperCase(); + if (csrfToken && !['GET', 'HEAD', 'OPTIONS', 'TRACE'].includes(method)) { + config.headers['X-CSRF-Token'] = csrfToken; + } if (config.data instanceof FormData) { config.headers['Content-Type'] = 'multipart/form-data'; } else { diff --git a/web/assets/js/model/reality_targets.js b/web/assets/js/model/reality_targets.js index 34777b5b..732e9f65 100644 --- a/web/assets/js/model/reality_targets.js +++ b/web/assets/js/model/reality_targets.js @@ -6,7 +6,6 @@ const REALITY_TARGETS = [ { target: 'www.nvidia.com:443', sni: 'www.nvidia.com' }, { target: 'www.amd.com:443', sni: 'www.amd.com' }, { target: 'www.intel.com:443', sni: 'www.intel.com' }, - { target: 'www.tesla.com:443', sni: 'www.tesla.com' }, { target: 'www.sony.com:443', sni: 'www.sony.com' } ]; diff --git a/web/assets/js/websocket.js b/web/assets/js/websocket.js index 64dd2769..f10507d5 100644 --- a/web/assets/js/websocket.js +++ b/web/assets/js/websocket.js @@ -104,15 +104,25 @@ class WebSocketClient { } this.ws = socket; + // Every handler must check `this.ws !== socket` first. A previous socket + // can still fire events (especially `close`) after we've moved on to a + // new one — e.g. connect() called while the old socket is in CLOSING + // state. Without the guard, a stale close would null out the freshly + // opened socket and silently break send(). socket.addEventListener('open', () => { + if (this.ws !== socket) return; this.isConnected = true; this.reconnectAttempts = 0; this.#emit('connected'); }); - socket.addEventListener('message', (event) => this.#onMessage(event)); + socket.addEventListener('message', (event) => { + if (this.ws !== socket) return; + this.#onMessage(event); + }); socket.addEventListener('error', (event) => { + if (this.ws !== socket) return; // Browsers fire 'error' before 'close' on failure. We surface it for // consumers (so polling fallbacks can engage) but don't log every blip // — bad networks would flood the console otherwise. @@ -120,6 +130,7 @@ class WebSocketClient { }); socket.addEventListener('close', () => { + if (this.ws !== socket) return; this.isConnected = false; this.ws = null; this.#emit('disconnected'); @@ -196,6 +207,10 @@ class WebSocketClient { this.reconnectTimer = setTimeout(() => { this.reconnectTimer = null; + // clearTimeout doesn't cancel a callback that has already fired but + // whose macrotask hasn't run yet — re-check shouldReconnect here so + // disconnect() called in that window can't be overridden. + if (!this.shouldReconnect) return; this.#openSocket(); }, delay); } diff --git a/web/controller/api.go b/web/controller/api.go index 57d2e4cb..e99a26d2 100644 --- a/web/controller/api.go +++ b/web/controller/api.go @@ -3,6 +3,7 @@ package controller import ( "net/http" + "github.com/mhsanaei/3x-ui/v2/web/middleware" "github.com/mhsanaei/3x-ui/v2/web/service" "github.com/mhsanaei/3x-ui/v2/web/session" @@ -39,6 +40,7 @@ func (a *APIController) initRouter(g *gin.RouterGroup, customGeo *service.Custom // Main API group api := g.Group("/panel/api") api.Use(a.checkAPIAuth) + api.Use(middleware.CSRFMiddleware()) // Inbounds API inbounds := api.Group("/inbounds") diff --git a/web/controller/index.go b/web/controller/index.go index 14791543..d3c58da8 100644 --- a/web/controller/index.go +++ b/web/controller/index.go @@ -1,12 +1,12 @@ package controller import ( - "fmt" "net/http" "text/template" "time" "github.com/mhsanaei/3x-ui/v2/logger" + "github.com/mhsanaei/3x-ui/v2/web/middleware" "github.com/mhsanaei/3x-ui/v2/web/service" "github.com/mhsanaei/3x-ui/v2/web/session" @@ -41,8 +41,8 @@ func (a *IndexController) initRouter(g *gin.RouterGroup) { g.GET("/", a.index) g.GET("/logout", a.logout) - g.POST("/login", a.login) - g.POST("/getTwoFactorEnable", a.getTwoFactorEnable) + g.POST("/login", middleware.CSRFMiddleware(), a.login) + g.POST("/getTwoFactorEnable", middleware.CSRFMiddleware(), a.getTwoFactorEnable) } // index handles the root route, redirecting logged-in users to the panel or showing the login page. @@ -71,28 +71,51 @@ func (a *IndexController) login(c *gin.Context) { return } - user, checkErr := a.userService.CheckUser(form.Username, form.Password, form.TwoFactorCode) - timeStr := time.Now().Format("2006-01-02 15:04:05") + remoteIP := getRemoteIp(c) safeUser := template.HTMLEscapeString(form.Username) - safePass := template.HTMLEscapeString(form.Password) - - if user == nil { - logger.Warningf("wrong username: \"%s\", password: \"%s\", IP: \"%s\"", safeUser, safePass, getRemoteIp(c)) - - notifyPass := safePass - - if checkErr != nil && checkErr.Error() == "invalid 2fa code" { - translatedError := a.tgbot.I18nBot("tgbot.messages.2faFailed") - notifyPass = fmt.Sprintf("*** (%s)", translatedError) - } - - a.tgbot.UserLoginNotify(safeUser, notifyPass, getRemoteIp(c), timeStr, 0) + timeStr := time.Now().Format("2006-01-02 15:04:05") + if blockedUntil, ok := defaultLoginLimiter.allow(remoteIP, form.Username); !ok { + reason := "too many failed attempts" + logger.Warningf("failed login: username=%q, IP=%q, reason=%q, blocked_until=%s", safeUser, remoteIP, reason, blockedUntil.Format(time.RFC3339)) + a.tgbot.UserLoginNotify(service.LoginAttempt{ + Username: safeUser, + IP: remoteIP, + Time: timeStr, + Status: service.LoginFail, + Reason: reason, + }) pureJsonMsg(c, http.StatusOK, false, I18nWeb(c, "pages.login.toasts.wrongUsernameOrPassword")) return } - logger.Infof("%s logged in successfully, Ip Address: %s\n", safeUser, getRemoteIp(c)) - a.tgbot.UserLoginNotify(safeUser, ``, getRemoteIp(c), timeStr, 1) + user, checkErr := a.userService.CheckUser(form.Username, form.Password, form.TwoFactorCode) + + if user == nil { + reason := loginFailureReason(checkErr) + if blockedUntil, blocked := defaultLoginLimiter.registerFailure(remoteIP, form.Username); blocked { + logger.Warningf("failed login: username=%q, IP=%q, reason=%q, blocked_until=%s", safeUser, remoteIP, reason, blockedUntil.Format(time.RFC3339)) + } else { + logger.Warningf("failed login: username=%q, IP=%q, reason=%q", safeUser, remoteIP, reason) + } + a.tgbot.UserLoginNotify(service.LoginAttempt{ + Username: safeUser, + IP: remoteIP, + Time: timeStr, + Status: service.LoginFail, + Reason: reason, + }) + pureJsonMsg(c, http.StatusOK, false, I18nWeb(c, "pages.login.toasts.wrongUsernameOrPassword")) + return + } + + defaultLoginLimiter.registerSuccess(remoteIP, form.Username) + logger.Infof("%s logged in successfully, Ip Address: %s\n", safeUser, remoteIP) + a.tgbot.UserLoginNotify(service.LoginAttempt{ + Username: safeUser, + IP: remoteIP, + Time: timeStr, + Status: service.LoginSuccess, + }) if err := session.SetLoginUser(c, user); err != nil { logger.Warning("Unable to save session:", err) @@ -103,6 +126,13 @@ func (a *IndexController) login(c *gin.Context) { jsonMsg(c, I18nWeb(c, "pages.login.toasts.successLogin"), nil) } +func loginFailureReason(err error) string { + if err != nil && err.Error() == "invalid 2fa code" { + return "invalid 2FA code" + } + return "invalid credentials" +} + // logout handles user logout by clearing the session and redirecting to the login page. func (a *IndexController) logout(c *gin.Context) { user := session.GetLoginUser(c) diff --git a/web/controller/login_limiter.go b/web/controller/login_limiter.go new file mode 100644 index 00000000..1694db99 --- /dev/null +++ b/web/controller/login_limiter.go @@ -0,0 +1,99 @@ +package controller + +import ( + "strings" + "sync" + "time" +) + +const ( + loginLimitMaxFailures = 5 + loginLimitWindow = 5 * time.Minute + loginLimitCooldown = 15 * time.Minute +) + +var defaultLoginLimiter = newLoginLimiter(loginLimitMaxFailures, loginLimitWindow, loginLimitCooldown) + +type loginLimiter struct { + mu sync.Mutex + now func() time.Time + maxFailures int + window time.Duration + cooldown time.Duration + attempts map[string]*loginLimitRecord +} + +type loginLimitRecord struct { + failures []time.Time + blockedUntil time.Time +} + +func newLoginLimiter(maxFailures int, window, cooldown time.Duration) *loginLimiter { + return &loginLimiter{ + now: time.Now, + maxFailures: maxFailures, + window: window, + cooldown: cooldown, + attempts: make(map[string]*loginLimitRecord), + } +} + +func (l *loginLimiter) allow(ip, username string) (time.Time, bool) { + l.mu.Lock() + defer l.mu.Unlock() + + key := loginLimitKey(ip, username) + record := l.attempts[key] + if record == nil { + return time.Time{}, true + } + now := l.now() + if now.Before(record.blockedUntil) { + return record.blockedUntil, false + } + record.blockedUntil = time.Time{} + record.failures = pruneLoginFailures(record.failures, now.Add(-l.window)) + if len(record.failures) == 0 { + delete(l.attempts, key) + } + return time.Time{}, true +} + +func (l *loginLimiter) registerFailure(ip, username string) (time.Time, bool) { + l.mu.Lock() + defer l.mu.Unlock() + + key := loginLimitKey(ip, username) + record := l.attempts[key] + if record == nil { + record = &loginLimitRecord{} + l.attempts[key] = record + } + now := l.now() + record.failures = pruneLoginFailures(record.failures, now.Add(-l.window)) + record.failures = append(record.failures, now) + if len(record.failures) >= l.maxFailures { + record.failures = nil + record.blockedUntil = now.Add(l.cooldown) + return record.blockedUntil, true + } + return time.Time{}, false +} + +func (l *loginLimiter) registerSuccess(ip, username string) { + l.mu.Lock() + defer l.mu.Unlock() + delete(l.attempts, loginLimitKey(ip, username)) +} + +func loginLimitKey(ip, username string) string { + return strings.TrimSpace(ip) + "\x00" + strings.ToLower(strings.TrimSpace(username)) +} + +func pruneLoginFailures(failures []time.Time, cutoff time.Time) []time.Time { + keepFrom := 0 + for keepFrom < len(failures) && failures[keepFrom].Before(cutoff) { + keepFrom++ + } + return failures[keepFrom:] +} diff --git a/web/controller/login_limiter_test.go b/web/controller/login_limiter_test.go new file mode 100644 index 00000000..c6f4ac6e --- /dev/null +++ b/web/controller/login_limiter_test.go @@ -0,0 +1,74 @@ +package controller + +import ( + "testing" + "time" +) + +func TestLoginLimiterBlocksAfterConfiguredFailures(t *testing.T) { + now := time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC) + limiter := newLoginLimiter(5, 5*time.Minute, 15*time.Minute) + limiter.now = func() time.Time { return now } + + for i := range 4 { + if _, blocked := limiter.registerFailure("192.0.2.10", "Admin"); blocked { + t.Fatalf("failure %d should not block yet", i+1) + } + if _, ok := limiter.allow("192.0.2.10", "admin"); !ok { + t.Fatalf("failure %d should still allow login attempts", i+1) + } + } + + blockedUntil, blocked := limiter.registerFailure("192.0.2.10", "ADMIN") + if !blocked { + t.Fatal("fifth failure should start cooldown") + } + if want := now.Add(15 * time.Minute); !blockedUntil.Equal(want) { + t.Fatalf("blocked until %s, want %s", blockedUntil, want) + } + if _, ok := limiter.allow("192.0.2.10", "admin"); ok { + t.Fatal("login should be blocked during cooldown") + } + + now = blockedUntil + if _, ok := limiter.allow("192.0.2.10", "admin"); !ok { + t.Fatal("login should be allowed after cooldown") + } +} + +func TestLoginLimiterPrunesOldFailuresAndResetsOnSuccess(t *testing.T) { + now := time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC) + limiter := newLoginLimiter(5, 5*time.Minute, 15*time.Minute) + limiter.now = func() time.Time { return now } + + for range 4 { + limiter.registerFailure("192.0.2.10", "admin") + } + now = now.Add(6 * time.Minute) + if _, blocked := limiter.registerFailure("192.0.2.10", "admin"); blocked { + t.Fatal("old failures should be pruned outside the rolling window") + } + + limiter.registerSuccess("192.0.2.10", "admin") + for i := range 4 { + if _, blocked := limiter.registerFailure("192.0.2.10", "admin"); blocked { + t.Fatalf("success should reset previous failures; failure %d blocked", i+1) + } + } +} + +func TestLoginLimiterSeparatesIPAndUsername(t *testing.T) { + now := time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC) + limiter := newLoginLimiter(5, 5*time.Minute, 15*time.Minute) + limiter.now = func() time.Time { return now } + + for range 5 { + limiter.registerFailure("192.0.2.10", "admin") + } + if _, ok := limiter.allow("192.0.2.11", "admin"); !ok { + t.Fatal("different IP should not be blocked") + } + if _, ok := limiter.allow("192.0.2.10", "other-admin"); !ok { + t.Fatal("different username should not be blocked") + } +} diff --git a/web/controller/util.go b/web/controller/util.go index e1d53ba6..070d2c70 100644 --- a/web/controller/util.go +++ b/web/controller/util.go @@ -10,6 +10,7 @@ import ( "github.com/mhsanaei/3x-ui/v2/config" "github.com/mhsanaei/3x-ui/v2/logger" "github.com/mhsanaei/3x-ui/v2/web/entity" + "github.com/mhsanaei/3x-ui/v2/web/session" "github.com/gin-gonic/gin" ) @@ -121,6 +122,12 @@ func html(c *gin.Context, name string, title string, data gin.H) { data = gin.H{} } data["title"] = title + csrfToken, err := session.EnsureCSRFToken(c) + if err != nil { + logger.Warning("Unable to create CSRF token:", err) + } else { + data["csrf_token"] = csrfToken + } host := c.GetHeader("X-Forwarded-Host") if host == "" { host = c.GetHeader("X-Real-IP") diff --git a/web/controller/websocket.go b/web/controller/websocket.go index 2e9fbca0..f2b4ee0e 100644 --- a/web/controller/websocket.go +++ b/web/controller/websocket.go @@ -5,25 +5,15 @@ import ( "net/http" "net/url" "strings" - "time" - "github.com/google/uuid" "github.com/mhsanaei/3x-ui/v2/logger" - "github.com/mhsanaei/3x-ui/v2/util/common" + "github.com/mhsanaei/3x-ui/v2/web/service" "github.com/mhsanaei/3x-ui/v2/web/session" - "github.com/mhsanaei/3x-ui/v2/web/websocket" "github.com/gin-gonic/gin" ws "github.com/gorilla/websocket" ) -const ( - writeWait = 10 * time.Second - pongWait = 60 * time.Second - pingPeriod = (pongWait * 9) / 10 - clientReadLimit = 512 -) - var upgrader = ws.Upgrader{ ReadBufferSize: 32768, WriteBufferSize: 32768, @@ -57,18 +47,21 @@ func checkSameOrigin(r *http.Request) bool { return strings.EqualFold(u.Hostname(), host) } -// WebSocketController handles WebSocket connections for real-time updates. +// WebSocketController handles the HTTP→WebSocket upgrade for real-time updates. +// All per-connection lifecycle (pumps, hub registration) lives in +// service.WebSocketService — this controller is HTTP-layer only. type WebSocketController struct { BaseController - hub *websocket.Hub + service *service.WebSocketService } -// NewWebSocketController creates a new WebSocket controller. -func NewWebSocketController(hub *websocket.Hub) *WebSocketController { - return &WebSocketController{hub: hub} +// NewWebSocketController creates a controller wired to the given service. +func NewWebSocketController(svc *service.WebSocketService) *WebSocketController { + return &WebSocketController{service: svc} } -// HandleWebSocket upgrades the HTTP connection and starts the read/write pumps. +// HandleWebSocket authenticates the request, upgrades the HTTP connection, and +// hands ownership of the connection off to the service. func (w *WebSocketController) HandleWebSocket(c *gin.Context) { if !session.IsLogin(c) { logger.Warningf("Unauthorized WebSocket connection attempt from %s", getRemoteIp(c)) @@ -82,71 +75,5 @@ func (w *WebSocketController) HandleWebSocket(c *gin.Context) { return } - client := websocket.NewClient(uuid.New().String()) - w.hub.Register(client) - logger.Debugf("WebSocket client %s registered from %s", client.ID, getRemoteIp(c)) - - go w.writePump(client, conn) - go w.readPump(client, conn) -} - -// readPump consumes inbound frames so the gorilla deadline/pong machinery keeps -// running. Clients send no commands today; frames are discarded. -func (w *WebSocketController) readPump(client *websocket.Client, conn *ws.Conn) { - defer func() { - if r := common.Recover("WebSocket readPump panic"); r != nil { - logger.Error("WebSocket readPump panic recovered:", r) - } - w.hub.Unregister(client) - conn.Close() - }() - - conn.SetReadLimit(clientReadLimit) - conn.SetReadDeadline(time.Now().Add(pongWait)) - conn.SetPongHandler(func(string) error { - return conn.SetReadDeadline(time.Now().Add(pongWait)) - }) - - for { - if _, _, err := conn.ReadMessage(); err != nil { - if ws.IsUnexpectedCloseError(err, ws.CloseGoingAway, ws.CloseAbnormalClosure) { - logger.Debugf("WebSocket read error for client %s: %v", client.ID, err) - } - return - } - } -} - -// writePump pushes hub messages to the connection and emits keepalive pings. -func (w *WebSocketController) writePump(client *websocket.Client, conn *ws.Conn) { - ticker := time.NewTicker(pingPeriod) - defer func() { - if r := common.Recover("WebSocket writePump panic"); r != nil { - logger.Error("WebSocket writePump panic recovered:", r) - } - ticker.Stop() - conn.Close() - }() - - for { - select { - case msg, ok := <-client.Send: - conn.SetWriteDeadline(time.Now().Add(writeWait)) - if !ok { - conn.WriteMessage(ws.CloseMessage, []byte{}) - return - } - if err := conn.WriteMessage(ws.TextMessage, msg); err != nil { - logger.Debugf("WebSocket write error for client %s: %v", client.ID, err) - return - } - - case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err := conn.WriteMessage(ws.PingMessage, nil); err != nil { - logger.Debugf("WebSocket ping error for client %s: %v", client.ID, err) - return - } - } - } + w.service.HandleConnection(conn, getRemoteIp(c)) } diff --git a/web/controller/xui.go b/web/controller/xui.go index 51502900..afbbeb71 100644 --- a/web/controller/xui.go +++ b/web/controller/xui.go @@ -1,6 +1,8 @@ package controller import ( + "github.com/mhsanaei/3x-ui/v2/web/middleware" + "github.com/gin-gonic/gin" ) @@ -23,6 +25,7 @@ func NewXUIController(g *gin.RouterGroup) *XUIController { func (a *XUIController) initRouter(g *gin.RouterGroup) { g = g.Group("/panel") g.Use(a.checkLogin) + g.Use(middleware.CSRFMiddleware()) g.GET("/", a.index) g.GET("/inbounds", a.inbounds) diff --git a/web/html/common/page.html b/web/html/common/page.html index 47b2b654..13f5d64c 100644 --- a/web/html/common/page.html +++ b/web/html/common/page.html @@ -7,6 +7,7 @@ + {{ if .csrf_token }}{{ end }}