diff --git a/pkg/middleware/middleware_jwt_auth_test.go b/pkg/middleware/middleware_jwt_auth_test.go index 3af4eaf88b6..3517b2ec0d5 100644 --- a/pkg/middleware/middleware_jwt_auth_test.go +++ b/pkg/middleware/middleware_jwt_auth_test.go @@ -79,7 +79,7 @@ func TestMiddlewareJWTAuth(t *testing.T) { assert.Equal(t, myUsername, sc.context.Login) list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context()) require.NotNil(t, list) - require.EqualValues(t, []string{sc.cfg.JWTAuthHeaderName}, list.Items) + require.EqualValues(t, []string{"Authorization", sc.cfg.JWTAuthHeaderName}, list.Items) }, configure, configureUsernameClaim) middlewareScenario(t, "Valid token with bearer in authorization header", func(t *testing.T, sc *scenarioContext) { diff --git a/pkg/services/contexthandler/auth_jwt.go b/pkg/services/contexthandler/auth_jwt.go index a0a5b9750a9..04cc6d034f1 100644 --- a/pkg/services/contexthandler/auth_jwt.go +++ b/pkg/services/contexthandler/auth_jwt.go @@ -150,9 +150,6 @@ func (h *ContextHandler) initContextWithJWT(ctx *contextmodel.ReqContext, orgId return true } - newCtx := WithAuthHTTPHeader(ctx.Req.Context(), h.Cfg.JWTAuthHeaderName) - *ctx.Req = *ctx.Req.WithContext(newCtx) - ctx.SignedInUser = queryResult ctx.IsSignedIn = true diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index 530047f2e25..0087ca02e9c 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -129,6 +129,8 @@ func (h *ContextHandler) Middleware(next http.Handler) http.Handler { // Inject ReqContext into http.Request.Context *r = *r.WithContext(context.WithValue(ctx, reqContextKey{}, reqContext)) + // store list of possible auth header in context + *reqContext.Req = *reqContext.Req.WithContext(WithAuthHTTPHeaders(reqContext.Req.Context(), h.Cfg)) traceID := tracing.TraceIDFromContext(mContext.Req.Context(), false) if traceID != "" { @@ -150,7 +152,6 @@ func (h *ContextHandler) Middleware(next http.Handler) http.Handler { reqContext.SignedInUser = identity.SignedInUser() reqContext.AllowAnonymous = identity.IsAnonymous reqContext.IsRenderCall = identity.AuthModule == login.RenderModule - // FIXME (kallep): Add auth headers used to context } } else { const headerName = "X-Grafana-Org-Id" @@ -306,9 +307,6 @@ func (h *ContextHandler) initContextWithAPIKey(reqContext *contextmodel.ReqConte _, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithAPIKey") defer span.End() - ctx := WithAuthHTTPHeader(reqContext.Req.Context(), "Authorization") - *reqContext.Req = *reqContext.Req.WithContext(ctx) - var ( apiKey *apikey.APIKey errKey error @@ -419,15 +417,12 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *contextmodel.ReqCo return true } - ctx := WithAuthHTTPHeader(reqContext.Req.Context(), "Authorization") - *reqContext.Req = *reqContext.Req.WithContext(ctx) - authQuery := login.LoginUserQuery{ Username: username, Password: password, Cfg: h.Cfg, } - if err := h.authenticator.AuthenticateUser(ctx, &authQuery); err != nil { + if err := h.authenticator.AuthenticateUser(reqContext.Req.Context(), &authQuery); err != nil { reqContext.Logger.Debug( "Failed to authorize the user", "username", username, @@ -444,7 +439,7 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *contextmodel.ReqCo usr := authQuery.User query := user.GetSignedInUserQuery{UserID: usr.ID, OrgID: orgID} - queryResult, err := h.userService.GetSignedInUserWithCacheCtx(ctx, &query) + queryResult, err := h.userService.GetSignedInUserWithCacheCtx(reqContext.Req.Context(), &query) if err != nil { reqContext.Logger.Error( "Failed at user signed in", @@ -713,15 +708,6 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *contextmodel.ReqCo logger.Debug("Successfully got user info", "userID", user.UserID, "username", user.Login) - ctx := WithAuthHTTPHeader(reqContext.Req.Context(), h.Cfg.AuthProxyHeaderName) - for _, header := range h.Cfg.AuthProxyHeaders { - if header != "" { - ctx = WithAuthHTTPHeader(ctx, header) - } - } - - *reqContext.Req = *reqContext.Req.WithContext(ctx) - // Add user info to context reqContext.SignedInUser = user reqContext.IsSignedIn = true @@ -752,20 +738,35 @@ type AuthHTTPHeaderList struct { Items []string } -// WithAuthHTTPHeader returns a copy of parent in which the named HTTP header will be included +// WithAuthHTTPHeaders returns a new context in which all possible configured auth header will be included // and later retrievable by AuthHTTPHeaderListFromContext. -func WithAuthHTTPHeader(parent context.Context, name string) context.Context { - list := AuthHTTPHeaderListFromContext(parent) - +func WithAuthHTTPHeaders(ctx context.Context, cfg *setting.Cfg) context.Context { + list := AuthHTTPHeaderListFromContext(ctx) if list == nil { list = &AuthHTTPHeaderList{ Items: []string{}, } } - list.Items = append(list.Items, name) + // used by basic auth, api keys and potentially jwt auth + list.Items = append(list.Items, "Authorization") - return context.WithValue(parent, authHTTPHeaderListKey, list) + // if jwt is enabled we add it to the list. We can ignore in case it is set to Authorization + if cfg.JWTAuthEnabled && cfg.JWTAuthHeaderName != "" && cfg.JWTAuthHeaderName != "Authorization" { + list.Items = append(list.Items, cfg.JWTAuthHeaderName) + } + + // if auth proxy is enabled add the main proxy header and all configured headers + if cfg.AuthProxyEnabled { + list.Items = append(list.Items, cfg.AuthProxyHeaderName) + for _, header := range cfg.AuthProxyHeaders { + if header != "" { + list.Items = append(list.Items, header) + } + } + } + + return context.WithValue(ctx, authHTTPHeaderListKey, list) } // AuthHTTPHeaderListFromContext returns the AuthHTTPHeaderList in a context.Context, if any, diff --git a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go index 6a4975203ca..2aa709ff987 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go +++ b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go @@ -4,11 +4,13 @@ import ( "net/http" "testing" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana/pkg/plugins/manager/client/clienttest" "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/user" - "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/setting" ) func TestClearAuthHeadersMiddleware(t *testing.T) { @@ -113,10 +115,8 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()), ) - const customHeader = "X-Custom" - req.Header.Set(customHeader, "val") - ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader) - req = req.WithContext(ctx) + req := req.WithContext(contexthandler.WithAuthHTTPHeaders(req.Context(), setting.NewCfg())) + req.Header.Set("Authorization", "val") const otherHeader = "X-Other" req.Header.Set(otherHeader, "test") @@ -165,10 +165,8 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { clienttest.WithMiddlewares(NewClearAuthHeadersMiddleware()), ) - const customHeader = "x-Custom" - req.Header.Set(customHeader, "val") - ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader) - req = req.WithContext(ctx) + req := req.WithContext(contexthandler.WithAuthHTTPHeaders(req.Context(), setting.NewCfg())) + req.Header.Set("Authorization", "val") const otherHeader = "x-Other" req.Header.Set(otherHeader, "test") diff --git a/pkg/util/proxyutil/reverse_proxy_test.go b/pkg/util/proxyutil/reverse_proxy_test.go index 0bab66b9b75..602575d660f 100644 --- a/pkg/util/proxyutil/reverse_proxy_test.go +++ b/pkg/util/proxyutil/reverse_proxy_test.go @@ -12,6 +12,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/setting" ) func TestReverseProxy(t *testing.T) { @@ -32,10 +33,8 @@ func TestReverseProxy(t *testing.T) { req.Header.Set("Referer", "https://test.com/api") req.RemoteAddr = "10.0.0.1" - const customHeader = "X-CUSTOM" - req.Header.Set(customHeader, "val") - ctx := contexthandler.WithAuthHTTPHeader(req.Context(), customHeader) - req = req.WithContext(ctx) + req = req.WithContext(contexthandler.WithAuthHTTPHeaders(req.Context(), setting.NewCfg())) + req.Header.Set("Authorization", "val") rp := NewReverseProxy(log.New("test"), func(req *http.Request) { req.Header.Set("X-KEY", "value") @@ -57,7 +56,7 @@ func TestReverseProxy(t *testing.T) { require.Empty(t, resp.Cookies()) require.Equal(t, "sandbox", resp.Header.Get("Content-Security-Policy")) require.NoError(t, resp.Body.Close()) - require.Empty(t, actualReq.Header.Get(customHeader)) + require.Empty(t, actualReq.Header.Get("Authorization")) }) t.Run("When proxying a request using WithModifyResponse should call it before default ModifyResponse func", func(t *testing.T) {