mirror of
https://github.com/MHSanaei/3x-ui.git
synced 2026-06-06 05:04:22 +00:00
refactor(session): store user ID in session instead of full struct
Replaces storing the full User object in the session cookie with just the user ID. GetLoginUser now re-fetches the user from the database on every request so credential/permission changes take effect immediately without requiring a re-login. Includes a backward-compatible migration path for existing sessions that still carry the old struct payload.
This commit is contained in:
parent
cb962175c2
commit
ce88b0b432
2 changed files with 115 additions and 3 deletions
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mhsanaei/3x-ui/v3/database"
|
||||||
"github.com/mhsanaei/3x-ui/v3/database/model"
|
"github.com/mhsanaei/3x-ui/v3/database/model"
|
||||||
"github.com/mhsanaei/3x-ui/v3/logger"
|
"github.com/mhsanaei/3x-ui/v3/logger"
|
||||||
|
|
||||||
|
|
@ -27,7 +28,7 @@ func SetLoginUser(c *gin.Context, user *model.User) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
s := sessions.Default(c)
|
s := sessions.Default(c)
|
||||||
s.Set(loginUserKey, *user)
|
s.Set(loginUserKey, user.Id)
|
||||||
return s.Save()
|
return s.Save()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -49,7 +50,7 @@ func GetLoginUser(c *gin.Context) *model.User {
|
||||||
if obj == nil {
|
if obj == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
user, ok := obj.(model.User)
|
userID, ok := sessionUserID(obj)
|
||||||
if !ok {
|
if !ok {
|
||||||
s.Delete(loginUserKey)
|
s.Delete(loginUserKey)
|
||||||
if err := s.Save(); err != nil {
|
if err := s.Save(); err != nil {
|
||||||
|
|
@ -57,13 +58,77 @@ func GetLoginUser(c *gin.Context) *model.User {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &user
|
if legacyUserID, ok := legacySessionUserID(obj); ok {
|
||||||
|
s.Set(loginUserKey, legacyUserID)
|
||||||
|
if err := s.Save(); err != nil {
|
||||||
|
logger.Warning("session: failed to migrate legacy user payload:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
user, err := getUserByID(userID)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning("session: failed to load user:", err)
|
||||||
|
s.Delete(loginUserKey)
|
||||||
|
if saveErr := s.Save(); saveErr != nil {
|
||||||
|
logger.Warning("session: failed to drop missing user:", saveErr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return user
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsLogin(c *gin.Context) bool {
|
func IsLogin(c *gin.Context) bool {
|
||||||
return GetLoginUser(c) != nil
|
return GetLoginUser(c) != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sessionUserID(obj any) (int, bool) {
|
||||||
|
switch v := obj.(type) {
|
||||||
|
case int:
|
||||||
|
return v, v > 0
|
||||||
|
case int64:
|
||||||
|
return int(v), v > 0
|
||||||
|
case int32:
|
||||||
|
return int(v), v > 0
|
||||||
|
case float64:
|
||||||
|
id := int(v)
|
||||||
|
return id, v == float64(id) && id > 0
|
||||||
|
case model.User:
|
||||||
|
return v.Id, v.Id > 0
|
||||||
|
case *model.User:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return v.Id, v.Id > 0
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func legacySessionUserID(obj any) (int, bool) {
|
||||||
|
switch v := obj.(type) {
|
||||||
|
case model.User:
|
||||||
|
return v.Id, v.Id > 0
|
||||||
|
case *model.User:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return v.Id, v.Id > 0
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUserByID(id int) (*model.User, error) {
|
||||||
|
db := database.GetDB()
|
||||||
|
if db == nil {
|
||||||
|
return nil, http.ErrServerClosed
|
||||||
|
}
|
||||||
|
user := &model.User{}
|
||||||
|
if err := db.Model(model.User{}).Where("id = ?", id).First(user).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ClearSession(c *gin.Context) error {
|
func ClearSession(c *gin.Context) error {
|
||||||
s := sessions.Default(c)
|
s := sessions.Default(c)
|
||||||
s.Clear()
|
s.Clear()
|
||||||
|
|
|
||||||
47
web/session/session_test.go
Normal file
47
web/session/session_test.go
Normal file
|
|
@ -0,0 +1,47 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mhsanaei/3x-ui/v3/database/model"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-contrib/sessions/cookie"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetLoginUserStoresOnlyUserID(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(sessions.Sessions(sessionCookieName, cookie.NewStore([]byte("01234567890123456789012345678901"))))
|
||||||
|
router.GET("/", func(c *gin.Context) {
|
||||||
|
if err := SetLoginUser(c, &model.User{Id: 7, Username: "admin", Password: "hash"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
got := sessions.Default(c).Get(loginUserKey)
|
||||||
|
if got != 7 {
|
||||||
|
t.Fatalf("stored session payload = %#v, want user id only", got)
|
||||||
|
}
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionUserIDSupportsLegacyUserPayload(t *testing.T) {
|
||||||
|
id, ok := sessionUserID(model.User{Id: 11, Username: "admin", Password: "hash"})
|
||||||
|
if !ok || id != 11 {
|
||||||
|
t.Fatalf("legacy session payload resolved to (%d, %v), want (11, true)", id, ok)
|
||||||
|
}
|
||||||
|
id, ok = sessionUserID(&model.User{Id: 12, Username: "admin", Password: "hash"})
|
||||||
|
if !ok || id != 12 {
|
||||||
|
t.Fatalf("legacy pointer session payload resolved to (%d, %v), want (12, true)", id, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue