diff --git a/database/model/model.go b/database/model/model.go index b3764e5e..2d44d104 100644 --- a/database/model/model.go +++ b/database/model/model.go @@ -21,12 +21,8 @@ const ( Shadowsocks Protocol = "shadowsocks" Mixed Protocol = "mixed" WireGuard Protocol = "wireguard" - // UI stores Hysteria v1 and v2 both as "hysteria" and uses - // settings.version to discriminate. Imports from outside the panel - // can carry the literal "hysteria2" string, so IsHysteria below - // accepts both. - Hysteria Protocol = "hysteria" - Hysteria2 Protocol = "hysteria2" + Hysteria Protocol = "hysteria" + Hysteria2 Protocol = "hysteria2" ) // IsHysteria returns true for both "hysteria" and "hysteria2". @@ -38,9 +34,10 @@ func IsHysteria(p Protocol) bool { // User represents a user account in the 3x-ui panel. type User struct { - Id int `json:"id" gorm:"primaryKey;autoIncrement"` - Username string `json:"username"` - Password string `json:"password"` + Id int `json:"id" gorm:"primaryKey;autoIncrement"` + Username string `json:"username"` + Password string `json:"password"` + LoginEpoch int64 `json:"-" gorm:"default:0"` } // Inbound represents an Xray inbound configuration with traffic statistics and settings. @@ -66,12 +63,7 @@ type Inbound struct { StreamSettings string `json:"streamSettings" form:"streamSettings"` Tag string `json:"tag" form:"tag" gorm:"unique"` Sniffing string `json:"sniffing" form:"sniffing"` - - // NodeID points at the remote panel (Node) where this inbound's xray - // actually runs. NULL means the inbound runs on the local xray (the - // pre-multi-node behaviour). Existing rows migrate to NULL with no - // backfill. - NodeID *int `json:"nodeId,omitempty" form:"nodeId" gorm:"index"` + NodeID *int `json:"nodeId,omitempty" form:"nodeId" gorm:"index"` } // OutboundTraffics tracks traffic statistics for Xray outbound connections. diff --git a/web/service/user.go b/web/service/user.go index e5544d5a..28970e7b 100644 --- a/web/service/user.go +++ b/web/service/user.go @@ -122,7 +122,11 @@ func (s *UserService) UpdateUser(id int, username string, password string) error return db.Model(model.User{}). Where("id = ?", id). - Updates(map[string]any{"username": username, "password": hashedPassword}). + Updates(map[string]any{ + "username": username, + "password": hashedPassword, + "login_epoch": gorm.Expr("login_epoch + 1"), + }). Error } @@ -150,5 +154,6 @@ func (s *UserService) UpdateFirstUser(username string, password string) error { } user.Username = username user.Password = hashedPassword + user.LoginEpoch++ return db.Save(user).Error } diff --git a/web/session/session.go b/web/session/session.go index b7340922..cb3c6b09 100644 --- a/web/session/session.go +++ b/web/session/session.go @@ -15,6 +15,7 @@ import ( const ( loginUserKey = "LOGIN_USER" + loginEpochKey = "LOGIN_EPOCH" apiAuthUserKey = "api_auth_user" sessionCookieName = "3x-ui" ) @@ -29,6 +30,7 @@ func SetLoginUser(c *gin.Context, user *model.User) error { } s := sessions.Default(c) s.Set(loginUserKey, user.Id) + s.Set(loginEpochKey, user.LoginEpoch) return s.Save() } @@ -53,6 +55,7 @@ func GetLoginUser(c *gin.Context) *model.User { userID, ok := sessionUserID(obj) if !ok { s.Delete(loginUserKey) + s.Delete(loginEpochKey) if err := s.Save(); err != nil { logger.Warning("session: failed to drop stale user payload:", err) } @@ -68,14 +71,41 @@ func GetLoginUser(c *gin.Context) *model.User { if err != nil { logger.Warning("session: failed to load user:", err) s.Delete(loginUserKey) + s.Delete(loginEpochKey) if saveErr := s.Save(); saveErr != nil { logger.Warning("session: failed to drop missing user:", saveErr) } return nil } + if !sessionEpochMatches(s.Get(loginEpochKey), user.LoginEpoch) { + s.Delete(loginUserKey) + s.Delete(loginEpochKey) + if saveErr := s.Save(); saveErr != nil { + logger.Warning("session: failed to drop stale epoch:", saveErr) + } + return nil + } return user } +func sessionEpochMatches(cookieVal any, userEpoch int64) bool { + var got int64 + switch v := cookieVal.(type) { + case nil: + case int64: + got = v + case int: + got = int64(v) + case int32: + got = int64(v) + case float64: + got = int64(v) + default: + return false + } + return got == userEpoch +} + func IsLogin(c *gin.Context) bool { return GetLoginUser(c) != nil }