diff --git a/web/middleware/auth.go b/web/middleware/auth.go new file mode 100644 index 00000000..fd872068 --- /dev/null +++ b/web/middleware/auth.go @@ -0,0 +1,61 @@ +package middleware + +import ( + "net/http" + "os" + "strings" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" +) + +func AuthRequired() gin.HandlerFunc { + secret := os.Getenv("JWT_SECRET") + if secret == "" { + secret = "dev-secret-change-me" + } + return func(c *gin.Context) { + auth := c.GetHeader("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + c.AbortWithStatus(http.StatusUnauthorized) + return + } + tokenStr := strings.TrimPrefix(auth, "Bearer ") + token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) { + return []byte(secret), nil + }) + if err != nil || !token.Valid { + c.AbortWithStatus(http.StatusUnauthorized) + return + } + if claims, ok := token.Claims.(jwt.MapClaims); ok { + if v, ok := claims["role"].(string); ok { + c.Set("role", v) + } + if v, ok := claims["id"].(float64); ok { + c.Set("user_id", int(v)) + } + c.Next() + return + } + c.AbortWithStatus(http.StatusUnauthorized) + } +} + +func RequireRole(roles ...string) gin.HandlerFunc { + return func(c *gin.Context) { + roleVal, ok := c.Get("role") + if !ok { + c.AbortWithStatus(http.StatusUnauthorized) + return + } + role := roleVal.(string) + for _, r := range roles { + if r == role { + c.Next() + return + } + } + c.AbortWithStatus(http.StatusForbidden) + } +}