diff --git a/internal/controller/httpapi/v1/login.go b/internal/controller/httpapi/v1/login.go index 5b0bdb46e..630333ad7 100644 --- a/internal/controller/httpapi/v1/login.go +++ b/internal/controller/httpapi/v1/login.go @@ -2,6 +2,8 @@ package v1 import ( "context" + "errors" + "fmt" "net/http" "strings" "time" @@ -15,7 +17,10 @@ import ( "github.com/device-management-toolkit/console/pkg/consoleerrors" ) -var ErrLogin = consoleerrors.CreateConsoleError("LoginHandler") +var ( + ErrLogin = consoleerrors.CreateConsoleError("LoginHandler") + ErrUnexpectedSigningMethod = errors.New("unexpected signing method") +) type LoginRoute struct { Config *config.Config @@ -99,11 +104,17 @@ func (lr LoginRoute) JWTAuthMiddleware() gin.HandlerFunc { if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid access token"}) c.Abort() + + return } } else { claims := &jwt.MapClaims{} - token, err := jwt.ParseWithClaims(tokenString, claims, func(_ *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("%w: %v", ErrUnexpectedSigningMethod, token.Header["alg"]) + } + return []byte(lr.Config.JWTKey), nil }) diff --git a/internal/controller/ws/v1/redirect.go b/internal/controller/ws/v1/redirect.go index d917bb9b2..75d918336 100644 --- a/internal/controller/ws/v1/redirect.go +++ b/internal/controller/ws/v1/redirect.go @@ -2,6 +2,8 @@ package v1 import ( "compress/flate" + "errors" + "fmt" "net/http" "github.com/gin-gonic/gin" @@ -13,6 +15,8 @@ import ( "github.com/device-management-toolkit/console/pkg/logger" ) +var ErrUnexpectedSigningMethod = errors.New("unexpected signing method") + type RedirectRoutes struct { d devices.Feature l logger.Interface @@ -41,7 +45,11 @@ func (r *RedirectRoutes) websocketHandler(c *gin.Context) { claims := &jwt.MapClaims{} - token, err := jwt.ParseWithClaims(tokenString, claims, func(_ *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("%w: %v", ErrUnexpectedSigningMethod, token.Header["alg"]) + } + return []byte(config.ConsoleConfig.JWTKey), nil })