diff --git a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go index 7a942eb4661..5c93e1da565 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go +++ b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware.go @@ -67,3 +67,33 @@ func (m *ClearAuthHeadersMiddleware) CheckHealth(ctx context.Context, req *backe return m.BaseHandler.CheckHealth(ctx, req) } + +func (m *ClearAuthHeadersMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + if req == nil { + return m.BaseHandler.SubscribeStream(ctx, req) + } + + m.clearHeaders(ctx, req) + + return m.BaseHandler.SubscribeStream(ctx, req) +} + +func (m *ClearAuthHeadersMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + if req == nil { + return m.BaseHandler.PublishStream(ctx, req) + } + + m.clearHeaders(ctx, req) + + return m.BaseHandler.PublishStream(ctx, req) +} + +func (m *ClearAuthHeadersMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + if req == nil { + return m.BaseHandler.RunStream(ctx, req, sender) + } + + m.clearHeaders(ctx, req) + + return m.BaseHandler.RunStream(ctx, req, sender) +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go index a02e3d3500d..5ca047881e3 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go +++ b/pkg/services/pluginsintegration/clientmiddleware/clear_auth_headers_middleware_test.go @@ -32,7 +32,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}, } - t.Run("Should not attach delete headers middleware when calling QueryData", func(t *testing.T) { + t.Run("No auth headers to clear when calling QueryData", func(t *testing.T) { _, err = cdt.MiddlewareHandler.QueryData(req.Context(), &backend.QueryDataRequest{ PluginContext: pluginCtx, Headers: map[string]string{otherHeader: "test"}, @@ -40,9 +40,10 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) + require.Empty(t, cdt.QueryDataReq.GetHTTPHeaders()) }) - t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) { + t.Run("No auth headers to clear when calling CallResource", func(t *testing.T) { err = cdt.MiddlewareHandler.CallResource(req.Context(), &backend.CallResourceRequest{ PluginContext: pluginCtx, Headers: map[string][]string{otherHeader: {"test"}}, @@ -50,9 +51,10 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) + require.Equal(t, http.Header{http.CanonicalHeaderKey(otherHeader): {"test"}}, cdt.CallResourceReq.GetHTTPHeaders()) }) - t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) { + t.Run("No auth headers to clear when calling CheckHealth", func(t *testing.T) { _, err = cdt.MiddlewareHandler.CheckHealth(req.Context(), &backend.CheckHealthRequest{ PluginContext: pluginCtx, Headers: map[string]string{otherHeader: "test"}, @@ -60,6 +62,40 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) + require.Empty(t, cdt.CheckHealthReq.GetHTTPHeaders()) + }) + + t.Run("No auth headers to clear when calling SubscribeStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.SubscribeStream(req.Context(), &backend.SubscribeStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.SubscribeStreamReq) + require.Len(t, cdt.SubscribeStreamReq.Headers, 1) + require.Empty(t, cdt.SubscribeStreamReq.GetHTTPHeaders()) + }) + + t.Run("No auth headers to clear when calling PublishStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.PublishStream(req.Context(), &backend.PublishStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.PublishStreamReq) + require.Len(t, cdt.PublishStreamReq.Headers, 1) + require.Empty(t, cdt.PublishStreamReq.GetHTTPHeaders()) + }) + + t.Run("No auth headers to clear when calling RunStream", func(t *testing.T) { + err = cdt.MiddlewareHandler.RunStream(req.Context(), &backend.RunStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.NotNil(t, cdt.RunStreamReq) + require.Len(t, cdt.RunStreamReq.Headers, 1) + require.Empty(t, cdt.RunStreamReq.GetHTTPHeaders()) }) }) @@ -73,7 +109,7 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { AppInstanceSettings: &backend.AppInstanceSettings{}, } - t.Run("Should not attach delete headers middleware when calling QueryData", func(t *testing.T) { + t.Run("No auth headers to clear when calling QueryData", func(t *testing.T) { _, err = cdt.MiddlewareHandler.QueryData(req.Context(), &backend.QueryDataRequest{ PluginContext: pluginCtx, Headers: map[string]string{otherHeader: "test"}, @@ -81,9 +117,11 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) + require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) + require.Empty(t, cdt.QueryDataReq.GetHTTPHeaders()) }) - t.Run("Should not attach delete headers middleware when calling CallResource", func(t *testing.T) { + t.Run("No auth headers to clear when calling CallResource", func(t *testing.T) { err = cdt.MiddlewareHandler.CallResource(req.Context(), &backend.CallResourceRequest{ PluginContext: pluginCtx, Headers: map[string][]string{otherHeader: {"test"}}, @@ -91,9 +129,11 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) + require.Equal(t, []string{"test"}, cdt.CallResourceReq.Headers[otherHeader]) + require.Equal(t, http.Header{http.CanonicalHeaderKey(otherHeader): {"test"}}, cdt.CallResourceReq.GetHTTPHeaders()) }) - t.Run("Should not attach delete headers middleware when calling CheckHealth", func(t *testing.T) { + t.Run("No auth headers to clear when calling CheckHealth", func(t *testing.T) { _, err = cdt.MiddlewareHandler.CheckHealth(req.Context(), &backend.CheckHealthRequest{ PluginContext: pluginCtx, Headers: map[string]string{otherHeader: "test"}, @@ -101,6 +141,44 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { require.NoError(t, err) require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) + require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) + require.Empty(t, cdt.CheckHealthReq.GetHTTPHeaders()) + }) + + t.Run("No auth headers to clear when calling SubscribeStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.SubscribeStream(req.Context(), &backend.SubscribeStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.SubscribeStreamReq) + require.Len(t, cdt.SubscribeStreamReq.Headers, 1) + require.Equal(t, "test", cdt.SubscribeStreamReq.Headers[otherHeader]) + require.Empty(t, cdt.SubscribeStreamReq.GetHTTPHeaders()) + }) + + t.Run("No auth headers to clear when calling PublishStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.PublishStream(req.Context(), &backend.PublishStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.PublishStreamReq) + require.Len(t, cdt.PublishStreamReq.Headers, 1) + require.Equal(t, "test", cdt.PublishStreamReq.Headers[otherHeader]) + require.Empty(t, cdt.PublishStreamReq.GetHTTPHeaders()) + }) + + t.Run("No auth headers to clear when calling RunStream", func(t *testing.T) { + err = cdt.MiddlewareHandler.RunStream(req.Context(), &backend.RunStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{otherHeader: "test"}, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.NotNil(t, cdt.RunStreamReq) + require.Len(t, cdt.RunStreamReq.Headers, 1) + require.Equal(t, "test", cdt.RunStreamReq.Headers[otherHeader]) + require.Empty(t, cdt.RunStreamReq.GetHTTPHeaders()) }) }) }) @@ -116,46 +194,105 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { ) req := req.WithContext(contexthandler.WithAuthHTTPHeaders(req.Context(), setting.NewCfg())) - req.Header.Set("Authorization", "val") - - const otherHeader = "X-Other" - req.Header.Set(otherHeader, "test") pluginCtx := backend.PluginContext{ DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}, } - t.Run("Should attach delete headers middleware when calling QueryData", func(t *testing.T) { + t.Run("Should clear auth headers when calling QueryData", func(t *testing.T) { _, err = cdt.MiddlewareHandler.QueryData(req.Context(), &backend.QueryDataRequest{ PluginContext: pluginCtx, - Headers: map[string]string{otherHeader: "test"}, + Headers: map[string]string{ + otherHeader: "test", + "Authorization": "secret", + "X-Grafana-Device-Id": "secret", + }, }) require.NoError(t, err) require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) + require.Empty(t, cdt.QueryDataReq.GetHTTPHeaders()) }) - t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) { + t.Run("Should clear auth headers when calling CallResource", func(t *testing.T) { err = cdt.MiddlewareHandler.CallResource(req.Context(), &backend.CallResourceRequest{ PluginContext: pluginCtx, - Headers: map[string][]string{otherHeader: {"test"}}, + Headers: map[string][]string{ + otherHeader: {"test"}, + "Authorization": {"secret"}, + "X-Grafana-Device-Id": {"secret"}, + }, }, nopCallResourceSender) require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) require.Equal(t, []string{"test"}, cdt.CallResourceReq.Headers[otherHeader]) + require.Equal(t, "test", cdt.CallResourceReq.GetHTTPHeader(otherHeader)) }) - t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) { + t.Run("Should clear auth headers when calling CheckHealth", func(t *testing.T) { _, err = cdt.MiddlewareHandler.CheckHealth(req.Context(), &backend.CheckHealthRequest{ PluginContext: pluginCtx, - Headers: map[string]string{otherHeader: "test"}, + Headers: map[string]string{ + otherHeader: "test", + "Authorization": "secret", + "X-Grafana-Device-Id": "secret", + }, }) require.NoError(t, err) require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) + require.Empty(t, cdt.CheckHealthReq.GetHTTPHeaders()) + }) + + t.Run("Should clear auth headers when calling SubscribeStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.SubscribeStream(req.Context(), &backend.SubscribeStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{ + otherHeader: "test", + "Authorization": "secret", + "X-Grafana-Device-Id": "secret", + }, + }) + require.NoError(t, err) + require.NotNil(t, cdt.SubscribeStreamReq) + require.Len(t, cdt.SubscribeStreamReq.Headers, 1) + require.Equal(t, "test", cdt.SubscribeStreamReq.Headers[otherHeader]) + require.Empty(t, cdt.SubscribeStreamReq.GetHTTPHeaders()) + }) + + t.Run("Should clear auth headers when calling PublishStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.PublishStream(req.Context(), &backend.PublishStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{ + otherHeader: "test", + "Authorization": "secret", + "X-Grafana-Device-Id": "secret", + }, + }) + require.NoError(t, err) + require.NotNil(t, cdt.PublishStreamReq) + require.Len(t, cdt.PublishStreamReq.Headers, 1) + require.Equal(t, "test", cdt.PublishStreamReq.Headers[otherHeader]) + require.Empty(t, cdt.PublishStreamReq.GetHTTPHeaders()) + }) + + t.Run("Should clear auth headers when calling RunStream", func(t *testing.T) { + err = cdt.MiddlewareHandler.RunStream(req.Context(), &backend.RunStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{ + otherHeader: "test", + "Authorization": "secret", + "X-Grafana-Device-Id": "secret", + }, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.NotNil(t, cdt.RunStreamReq) + require.Len(t, cdt.RunStreamReq.Headers, 1) + require.Equal(t, "test", cdt.RunStreamReq.Headers[otherHeader]) + require.Empty(t, cdt.RunStreamReq.GetHTTPHeaders()) }) }) @@ -175,37 +312,100 @@ func TestClearAuthHeadersMiddleware(t *testing.T) { AppInstanceSettings: &backend.AppInstanceSettings{}, } - t.Run("Should attach delete headers middleware when calling QueryData", func(t *testing.T) { + t.Run("Should clear auth headers when calling QueryData", func(t *testing.T) { _, err = cdt.MiddlewareHandler.QueryData(req.Context(), &backend.QueryDataRequest{ PluginContext: pluginCtx, - Headers: map[string]string{otherHeader: "test"}, + Headers: map[string]string{ + otherHeader: "test", + "Authorization": "secret", + "X-Grafana-Device-Id": "secret", + }, }) require.NoError(t, err) require.NotNil(t, cdt.QueryDataReq) require.Len(t, cdt.QueryDataReq.Headers, 1) require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) + require.Empty(t, cdt.QueryDataReq.GetHTTPHeaders()) }) - t.Run("Should attach delete headers middleware when calling CallResource", func(t *testing.T) { + t.Run("Should clear auth headers when calling CallResource", func(t *testing.T) { err = cdt.MiddlewareHandler.CallResource(req.Context(), &backend.CallResourceRequest{ PluginContext: pluginCtx, - Headers: map[string][]string{otherHeader: {"test"}}, + Headers: map[string][]string{ + otherHeader: {"test"}, + "Authorization": {"secret"}, + "X-Grafana-Device-Id": {"secret"}, + }, }, nopCallResourceSender) require.NoError(t, err) require.NotNil(t, cdt.CallResourceReq) require.Len(t, cdt.CallResourceReq.Headers, 1) require.Equal(t, []string{"test"}, cdt.CallResourceReq.Headers[otherHeader]) + require.Equal(t, "test", cdt.CallResourceReq.GetHTTPHeader(otherHeader)) }) - t.Run("Should attach delete headers middleware when calling CheckHealth", func(t *testing.T) { + t.Run("Should clear auth headers when calling CheckHealth", func(t *testing.T) { _, err = cdt.MiddlewareHandler.CheckHealth(req.Context(), &backend.CheckHealthRequest{ PluginContext: pluginCtx, - Headers: map[string]string{otherHeader: "test"}, + Headers: map[string]string{ + otherHeader: "test", + "Authorization": "secret", + "X-Grafana-Device-Id": "secret", + }, }) require.NoError(t, err) require.NotNil(t, cdt.CheckHealthReq) require.Len(t, cdt.CheckHealthReq.Headers, 1) require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) + require.Empty(t, cdt.CheckHealthReq.GetHTTPHeaders()) + }) + + t.Run("Should clear auth headers when calling SubscribeStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.SubscribeStream(req.Context(), &backend.SubscribeStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{ + otherHeader: "test", + "Authorization": "secret", + "X-Grafana-Device-Id": "secret", + }, + }) + require.NoError(t, err) + require.NotNil(t, cdt.SubscribeStreamReq) + require.Len(t, cdt.SubscribeStreamReq.Headers, 1) + require.Equal(t, "test", cdt.SubscribeStreamReq.Headers[otherHeader]) + require.Empty(t, cdt.SubscribeStreamReq.GetHTTPHeaders()) + }) + + t.Run("Should clear auth headers when calling PublishStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.PublishStream(req.Context(), &backend.PublishStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{ + otherHeader: "test", + "Authorization": "secret", + "X-Grafana-Device-Id": "secret", + }, + }) + require.NoError(t, err) + require.NotNil(t, cdt.PublishStreamReq) + require.Len(t, cdt.PublishStreamReq.Headers, 1) + require.Equal(t, "test", cdt.PublishStreamReq.Headers[otherHeader]) + require.Empty(t, cdt.PublishStreamReq.GetHTTPHeaders()) + }) + + t.Run("Should clear auth headers when calling RunStream", func(t *testing.T) { + err = cdt.MiddlewareHandler.RunStream(req.Context(), &backend.RunStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{ + otherHeader: "test", + "Authorization": "secret", + "X-Grafana-Device-Id": "secret", + }, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.NotNil(t, cdt.RunStreamReq) + require.Len(t, cdt.RunStreamReq.Headers, 1) + require.Equal(t, "test", cdt.RunStreamReq.Headers[otherHeader]) + require.Empty(t, cdt.RunStreamReq.GetHTTPHeaders()) }) }) }) diff --git a/pkg/services/pluginsintegration/clientmiddleware/forward_id_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/forward_id_middleware.go index 42764a7641f..311673a4877 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/forward_id_middleware.go +++ b/pkg/services/pluginsintegration/clientmiddleware/forward_id_middleware.go @@ -76,3 +76,42 @@ func (m *ForwardIDMiddleware) CheckHealth(ctx context.Context, req *backend.Chec return m.BaseHandler.CheckHealth(ctx, req) } + +func (m *ForwardIDMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + if req == nil { + return m.BaseHandler.SubscribeStream(ctx, req) + } + + err := m.applyToken(ctx, req.PluginContext, req) + if err != nil { + return nil, err + } + + return m.BaseHandler.SubscribeStream(ctx, req) +} + +func (m *ForwardIDMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + if req == nil { + return m.BaseHandler.PublishStream(ctx, req) + } + + err := m.applyToken(ctx, req.PluginContext, req) + if err != nil { + return nil, err + } + + return m.BaseHandler.PublishStream(ctx, req) +} + +func (m *ForwardIDMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + if req == nil { + return m.BaseHandler.RunStream(ctx, req, sender) + } + + err := m.applyToken(ctx, req.PluginContext, req) + if err != nil { + return err + } + + return m.BaseHandler.RunStream(ctx, req, sender) +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/forward_id_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/forward_id_middleware_test.go index bc8532e7f75..b8acf0f60b3 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/forward_id_middleware_test.go +++ b/pkg/services/pluginsintegration/clientmiddleware/forward_id_middleware_test.go @@ -16,11 +16,68 @@ import ( ) func TestForwardIDMiddleware(t *testing.T) { - pluginContext := backend.PluginContext{ - DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}, - } + t.Run("When not signed in", func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, handlertest.WithMiddlewares(NewForwardIDMiddleware())) + ctx := context.WithValue(context.Background(), ctxkey.Key{}, &contextmodel.ReqContext{ + Context: &web.Context{Req: &http.Request{}}, + }) - t.Run("Should set forwarded id header if present", func(t *testing.T) { + t.Run("And requests are for a datasource", func(t *testing.T) { + pluginContext := backend.PluginContext{ + DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}, + } + + t.Run("Should not set forwarded id header if not present for QueryData", func(t *testing.T) { + _, err := cdt.MiddlewareHandler.QueryData(ctx, &backend.QueryDataRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Empty(t, cdt.QueryDataReq.GetHTTPHeaders()) + }) + + t.Run("Should not set forwarded id header if not present for CallResource", func(t *testing.T) { + err := cdt.MiddlewareHandler.CallResource(ctx, &backend.CallResourceRequest{ + PluginContext: pluginContext, + }, nopCallResourceSender) + require.NoError(t, err) + require.Empty(t, cdt.CallResourceReq.GetHTTPHeaders()) + }) + + t.Run("Should not set forwarded id header if not present for CheckHealth", func(t *testing.T) { + _, err := cdt.MiddlewareHandler.CheckHealth(ctx, &backend.CheckHealthRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Empty(t, cdt.CheckHealthReq.GetHTTPHeaders()) + }) + + t.Run("Should not set forwarded id header if not present for SubscribeStream", func(t *testing.T) { + _, err := cdt.MiddlewareHandler.SubscribeStream(ctx, &backend.SubscribeStreamRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Empty(t, cdt.SubscribeStreamReq.GetHTTPHeaders()) + }) + + t.Run("Should not set forwarded id header if not present for PublishStream", func(t *testing.T) { + _, err := cdt.MiddlewareHandler.PublishStream(ctx, &backend.PublishStreamRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Empty(t, cdt.PublishStreamReq.GetHTTPHeaders()) + }) + + t.Run("Should not set forwarded id header if not present for RunStream", func(t *testing.T) { + err := cdt.MiddlewareHandler.RunStream(ctx, &backend.RunStreamRequest{ + PluginContext: pluginContext, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.Empty(t, cdt.RunStreamReq.GetHTTPHeaders()) + }) + }) + }) + + t.Run("When signed in", func(t *testing.T) { cdt := handlertest.NewHandlerMiddlewareTest(t, handlertest.WithMiddlewares(NewForwardIDMiddleware())) ctx := context.WithValue(context.Background(), ctxkey.Key{}, &contextmodel.ReqContext{ @@ -28,47 +85,154 @@ func TestForwardIDMiddleware(t *testing.T) { SignedInUser: &user.SignedInUser{IDToken: "some-token"}, }) - err := cdt.MiddlewareHandler.CallResource(ctx, &backend.CallResourceRequest{ - PluginContext: pluginContext, - }, nopCallResourceSender) - require.NoError(t, err) + t.Run("And requests are for a datasource", func(t *testing.T) { + pluginContext := backend.PluginContext{ + DataSourceInstanceSettings: &backend.DataSourceInstanceSettings{}, + } - require.Equal(t, "some-token", cdt.CallResourceReq.Headers[forwardIDHeaderName][0]) - }) + t.Run("Should set forwarded id header if present for QueryData", func(t *testing.T) { + _, err := cdt.MiddlewareHandler.QueryData(ctx, &backend.QueryDataRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.QueryDataReq.GetHTTPHeader(forwardIDHeaderName)) + }) - t.Run("Should not set forwarded id header if not present", func(t *testing.T) { - cdt := handlertest.NewHandlerMiddlewareTest(t, handlertest.WithMiddlewares(NewForwardIDMiddleware())) + t.Run("Should set forwarded id header if present for CallResource", func(t *testing.T) { + err := cdt.MiddlewareHandler.CallResource(ctx, &backend.CallResourceRequest{ + PluginContext: pluginContext, + }, nopCallResourceSender) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.CallResourceReq.GetHTTPHeader(forwardIDHeaderName)) + }) - ctx := context.WithValue(context.Background(), ctxkey.Key{}, &contextmodel.ReqContext{ - Context: &web.Context{Req: &http.Request{}}, - SignedInUser: &user.SignedInUser{}, + t.Run("Should set forwarded id header if present for CheckHealth", func(t *testing.T) { + _, err := cdt.MiddlewareHandler.CheckHealth(ctx, &backend.CheckHealthRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.CheckHealthReq.GetHTTPHeader(forwardIDHeaderName)) + }) + + t.Run("Should set forwarded id header if present for SubscribeStream", func(t *testing.T) { + _, err := cdt.MiddlewareHandler.SubscribeStream(ctx, &backend.SubscribeStreamRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.SubscribeStreamReq.GetHTTPHeader(forwardIDHeaderName)) + }) + + t.Run("Should set forwarded id header if present for PublishStream", func(t *testing.T) { + _, err := cdt.MiddlewareHandler.PublishStream(ctx, &backend.PublishStreamRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.PublishStreamReq.GetHTTPHeader(forwardIDHeaderName)) + }) + + t.Run("Should set forwarded id header if present for RunStream", func(t *testing.T) { + err := cdt.MiddlewareHandler.RunStream(ctx, &backend.RunStreamRequest{ + PluginContext: pluginContext, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.RunStreamReq.GetHTTPHeader(forwardIDHeaderName)) + }) }) - err := cdt.MiddlewareHandler.CallResource(ctx, &backend.CallResourceRequest{ - PluginContext: pluginContext, - }, nopCallResourceSender) - require.NoError(t, err) + t.Run("And requests are for an app", func(t *testing.T) { + pluginContext := backend.PluginContext{ + AppInstanceSettings: &backend.AppInstanceSettings{}, + } - require.Len(t, cdt.CallResourceReq.Headers[forwardIDHeaderName], 0) - }) + t.Run("Should set forwarded id header to app plugin if present for QueryData", func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, handlertest.WithMiddlewares(NewForwardIDMiddleware())) - pluginContext = backend.PluginContext{ - AppInstanceSettings: &backend.AppInstanceSettings{}, - } + ctx := context.WithValue(context.Background(), ctxkey.Key{}, &contextmodel.ReqContext{ + Context: &web.Context{Req: &http.Request{}}, + SignedInUser: &user.SignedInUser{IDToken: "some-token"}, + }) - t.Run("Should set forwarded id header to app plugin if present", func(t *testing.T) { - cdt := handlertest.NewHandlerMiddlewareTest(t, handlertest.WithMiddlewares(NewForwardIDMiddleware())) + _, err := cdt.MiddlewareHandler.QueryData(ctx, &backend.QueryDataRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.QueryDataReq.GetHTTPHeader(forwardIDHeaderName)) + }) - ctx := context.WithValue(context.Background(), ctxkey.Key{}, &contextmodel.ReqContext{ - Context: &web.Context{Req: &http.Request{}}, - SignedInUser: &user.SignedInUser{IDToken: "some-token"}, + t.Run("Should set forwarded id header to app plugin if present for CallResource", func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, handlertest.WithMiddlewares(NewForwardIDMiddleware())) + + ctx := context.WithValue(context.Background(), ctxkey.Key{}, &contextmodel.ReqContext{ + Context: &web.Context{Req: &http.Request{}}, + SignedInUser: &user.SignedInUser{IDToken: "some-token"}, + }) + + err := cdt.MiddlewareHandler.CallResource(ctx, &backend.CallResourceRequest{ + PluginContext: pluginContext, + }, nopCallResourceSender) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.CallResourceReq.GetHTTPHeader(forwardIDHeaderName)) + }) + + t.Run("Should set forwarded id header to app plugin if present for CheckHealth", func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, handlertest.WithMiddlewares(NewForwardIDMiddleware())) + + ctx := context.WithValue(context.Background(), ctxkey.Key{}, &contextmodel.ReqContext{ + Context: &web.Context{Req: &http.Request{}}, + SignedInUser: &user.SignedInUser{IDToken: "some-token"}, + }) + + _, err := cdt.MiddlewareHandler.CheckHealth(ctx, &backend.CheckHealthRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.CheckHealthReq.GetHTTPHeader(forwardIDHeaderName)) + }) + + t.Run("Should set forwarded id header to app plugin if present for SubscribeStream", func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, handlertest.WithMiddlewares(NewForwardIDMiddleware())) + + ctx := context.WithValue(context.Background(), ctxkey.Key{}, &contextmodel.ReqContext{ + Context: &web.Context{Req: &http.Request{}}, + SignedInUser: &user.SignedInUser{IDToken: "some-token"}, + }) + + _, err := cdt.MiddlewareHandler.SubscribeStream(ctx, &backend.SubscribeStreamRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.SubscribeStreamReq.GetHTTPHeader(forwardIDHeaderName)) + }) + + t.Run("Should set forwarded id header to app plugin if present for PublishStream", func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, handlertest.WithMiddlewares(NewForwardIDMiddleware())) + + ctx := context.WithValue(context.Background(), ctxkey.Key{}, &contextmodel.ReqContext{ + Context: &web.Context{Req: &http.Request{}}, + SignedInUser: &user.SignedInUser{IDToken: "some-token"}, + }) + + _, err := cdt.MiddlewareHandler.PublishStream(ctx, &backend.PublishStreamRequest{ + PluginContext: pluginContext, + }) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.PublishStreamReq.GetHTTPHeader(forwardIDHeaderName)) + }) + + t.Run("Should set forwarded id header to app plugin if present for RunStream", func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, handlertest.WithMiddlewares(NewForwardIDMiddleware())) + + ctx := context.WithValue(context.Background(), ctxkey.Key{}, &contextmodel.ReqContext{ + Context: &web.Context{Req: &http.Request{}}, + SignedInUser: &user.SignedInUser{IDToken: "some-token"}, + }) + + err := cdt.MiddlewareHandler.RunStream(ctx, &backend.RunStreamRequest{ + PluginContext: pluginContext, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.Equal(t, "some-token", cdt.RunStreamReq.GetHTTPHeader(forwardIDHeaderName)) + }) }) - - err := cdt.MiddlewareHandler.CallResource(ctx, &backend.CallResourceRequest{ - PluginContext: pluginContext, - }, nopCallResourceSender) - require.NoError(t, err) - - require.Equal(t, "some-token", cdt.CallResourceReq.Headers[forwardIDHeaderName][0]) }) } diff --git a/pkg/services/pluginsintegration/clientmiddleware/tracing_header_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/tracing_header_middleware.go index 035ae7f48cb..165c6abc417 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/tracing_header_middleware.go +++ b/pkg/services/pluginsintegration/clientmiddleware/tracing_header_middleware.go @@ -65,3 +65,30 @@ func (m *TracingHeaderMiddleware) CheckHealth(ctx context.Context, req *backend. m.applyHeaders(ctx, req) return m.BaseHandler.CheckHealth(ctx, req) } + +func (m *TracingHeaderMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + if req == nil { + return m.BaseHandler.SubscribeStream(ctx, req) + } + + m.applyHeaders(ctx, req) + return m.BaseHandler.SubscribeStream(ctx, req) +} + +func (m *TracingHeaderMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + if req == nil { + return m.BaseHandler.PublishStream(ctx, req) + } + + m.applyHeaders(ctx, req) + return m.BaseHandler.PublishStream(ctx, req) +} + +func (m *TracingHeaderMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + if req == nil { + return m.BaseHandler.RunStream(ctx, req, sender) + } + + m.applyHeaders(ctx, req) + return m.BaseHandler.RunStream(ctx, req, sender) +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/tracing_header_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/tracing_header_middleware_test.go index 7b1df5a8065..60435b2b557 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/tracing_header_middleware_test.go +++ b/pkg/services/pluginsintegration/clientmiddleware/tracing_header_middleware_test.go @@ -166,5 +166,77 @@ func TestTracingHeaderMiddleware(t *testing.T) { require.Equal(t, `d26e337d-cb53-481a-9212-0112537b3c1a`, cdt.CheckHealthReq.GetHTTPHeader(`X-Query-Group-Id`)) require.Equal(t, `true`, cdt.CheckHealthReq.GetHTTPHeader(`X-Grafana-From-Expr`)) }) + + t.Run("tracing headers are set for subscribe stream", func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, + WithReqContext(req, &user.SignedInUser{ + IsAnonymous: true, + Login: "anonymous"}, + ), + handlertest.WithMiddlewares(NewTracingHeaderMiddleware()), + ) + + _, err = cdt.MiddlewareHandler.SubscribeStream(req.Context(), &backend.SubscribeStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + + require.Len(t, cdt.SubscribeStreamReq.GetHTTPHeaders(), 6) + require.Equal(t, `lN53lOcVk`, cdt.SubscribeStreamReq.GetHTTPHeader(`X-Dashboard-Uid`)) + require.Equal(t, `aIyC_OcVz`, cdt.SubscribeStreamReq.GetHTTPHeader(`X-Datasource-Uid`)) + require.Equal(t, `1`, cdt.SubscribeStreamReq.GetHTTPHeader(`X-Grafana-Org-Id`)) + require.Equal(t, `2`, cdt.SubscribeStreamReq.GetHTTPHeader(`X-Panel-Id`)) + require.Equal(t, `d26e337d-cb53-481a-9212-0112537b3c1a`, cdt.SubscribeStreamReq.GetHTTPHeader(`X-Query-Group-Id`)) + require.Equal(t, `true`, cdt.SubscribeStreamReq.GetHTTPHeader(`X-Grafana-From-Expr`)) + }) + + t.Run("tracing headers are set for publish stream", func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, + WithReqContext(req, &user.SignedInUser{ + IsAnonymous: true, + Login: "anonymous"}, + ), + handlertest.WithMiddlewares(NewTracingHeaderMiddleware()), + ) + + _, err = cdt.MiddlewareHandler.PublishStream(req.Context(), &backend.PublishStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + + require.Len(t, cdt.PublishStreamReq.GetHTTPHeaders(), 6) + require.Equal(t, `lN53lOcVk`, cdt.PublishStreamReq.GetHTTPHeader(`X-Dashboard-Uid`)) + require.Equal(t, `aIyC_OcVz`, cdt.PublishStreamReq.GetHTTPHeader(`X-Datasource-Uid`)) + require.Equal(t, `1`, cdt.PublishStreamReq.GetHTTPHeader(`X-Grafana-Org-Id`)) + require.Equal(t, `2`, cdt.PublishStreamReq.GetHTTPHeader(`X-Panel-Id`)) + require.Equal(t, `d26e337d-cb53-481a-9212-0112537b3c1a`, cdt.PublishStreamReq.GetHTTPHeader(`X-Query-Group-Id`)) + require.Equal(t, `true`, cdt.PublishStreamReq.GetHTTPHeader(`X-Grafana-From-Expr`)) + }) + + t.Run("tracing headers are set for run stream", func(t *testing.T) { + cdt := handlertest.NewHandlerMiddlewareTest(t, + WithReqContext(req, &user.SignedInUser{ + IsAnonymous: true, + Login: "anonymous"}, + ), + handlertest.WithMiddlewares(NewTracingHeaderMiddleware()), + ) + + err = cdt.MiddlewareHandler.RunStream(req.Context(), &backend.RunStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }, &backend.StreamSender{}) + require.NoError(t, err) + + require.Len(t, cdt.RunStreamReq.GetHTTPHeaders(), 6) + require.Equal(t, `lN53lOcVk`, cdt.RunStreamReq.GetHTTPHeader(`X-Dashboard-Uid`)) + require.Equal(t, `aIyC_OcVz`, cdt.RunStreamReq.GetHTTPHeader(`X-Datasource-Uid`)) + require.Equal(t, `1`, cdt.RunStreamReq.GetHTTPHeader(`X-Grafana-Org-Id`)) + require.Equal(t, `2`, cdt.RunStreamReq.GetHTTPHeader(`X-Panel-Id`)) + require.Equal(t, `d26e337d-cb53-481a-9212-0112537b3c1a`, cdt.RunStreamReq.GetHTTPHeader(`X-Query-Group-Id`)) + require.Equal(t, `true`, cdt.RunStreamReq.GetHTTPHeader(`X-Grafana-From-Expr`)) + }) }) } diff --git a/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware.go b/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware.go index 89b4fa4af25..36370d45648 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware.go +++ b/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware.go @@ -66,3 +66,33 @@ func (m *UserHeaderMiddleware) CheckHealth(ctx context.Context, req *backend.Che return m.BaseHandler.CheckHealth(ctx, req) } + +func (m *UserHeaderMiddleware) SubscribeStream(ctx context.Context, req *backend.SubscribeStreamRequest) (*backend.SubscribeStreamResponse, error) { + if req == nil { + return m.BaseHandler.SubscribeStream(ctx, req) + } + + m.applyUserHeader(ctx, req) + + return m.BaseHandler.SubscribeStream(ctx, req) +} + +func (m *UserHeaderMiddleware) PublishStream(ctx context.Context, req *backend.PublishStreamRequest) (*backend.PublishStreamResponse, error) { + if req == nil { + return m.BaseHandler.PublishStream(ctx, req) + } + + m.applyUserHeader(ctx, req) + + return m.BaseHandler.PublishStream(ctx, req) +} + +func (m *UserHeaderMiddleware) RunStream(ctx context.Context, req *backend.RunStreamRequest, sender *backend.StreamSender) error { + if req == nil { + return m.BaseHandler.RunStream(ctx, req, sender) + } + + m.applyUserHeader(ctx, req) + + return m.BaseHandler.RunStream(ctx, req, sender) +} diff --git a/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware_test.go b/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware_test.go index 327705bd1f6..61140f44481 100644 --- a/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware_test.go +++ b/pkg/services/pluginsintegration/clientmiddleware/user_header_middleware_test.go @@ -12,7 +12,7 @@ import ( ) func TestUserHeaderMiddleware(t *testing.T) { - t.Run("When anononymous user in reqContext", func(t *testing.T) { + t.Run("When anonymous user in reqContext", func(t *testing.T) { req, err := http.NewRequest(http.MethodGet, "/some/thing", nil) require.NoError(t, err) @@ -58,6 +58,36 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NotNil(t, cdt.CheckHealthReq) require.Empty(t, cdt.CheckHealthReq.Headers) }) + + t.Run("Should not forward user header when calling SubscribeStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.SubscribeStream(req.Context(), &backend.SubscribeStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.SubscribeStreamReq) + require.Empty(t, cdt.SubscribeStreamReq.Headers) + }) + + t.Run("Should not forward user header when calling PublishStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.PublishStream(req.Context(), &backend.PublishStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.PublishStreamReq) + require.Empty(t, cdt.PublishStreamReq.Headers) + }) + + t.Run("Should not forward user header when calling RunStream", func(t *testing.T) { + err = cdt.MiddlewareHandler.RunStream(req.Context(), &backend.RunStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.NotNil(t, cdt.RunStreamReq) + require.Empty(t, cdt.RunStreamReq.Headers) + }) }) t.Run("And requests are for an app", func(t *testing.T) { @@ -102,6 +132,36 @@ func TestUserHeaderMiddleware(t *testing.T) { require.NotNil(t, cdt.CheckHealthReq) require.Empty(t, cdt.CheckHealthReq.Headers) }) + + t.Run("Should not forward user header when calling SubscribeStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.SubscribeStream(req.Context(), &backend.SubscribeStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.SubscribeStreamReq) + require.Empty(t, cdt.SubscribeStreamReq.Headers) + }) + + t.Run("Should not forward user header when calling PublishStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.PublishStream(req.Context(), &backend.PublishStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.PublishStreamReq) + require.Empty(t, cdt.PublishStreamReq.Headers) + }) + + t.Run("Should not forward user header when calling RunStream", func(t *testing.T) { + err = cdt.MiddlewareHandler.RunStream(req.Context(), &backend.RunStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.NotNil(t, cdt.RunStreamReq) + require.Empty(t, cdt.RunStreamReq.Headers) + }) }) }) @@ -153,6 +213,39 @@ func TestUserHeaderMiddleware(t *testing.T) { require.Len(t, cdt.CheckHealthReq.Headers, 1) require.Equal(t, "admin", cdt.CheckHealthReq.GetHTTPHeader(proxyutil.UserHeaderName)) }) + + t.Run("Should forward user header when calling SubscribeStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.SubscribeStream(req.Context(), &backend.SubscribeStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.SubscribeStreamReq) + require.Len(t, cdt.SubscribeStreamReq.Headers, 1) + require.Equal(t, "admin", cdt.SubscribeStreamReq.GetHTTPHeader(proxyutil.UserHeaderName)) + }) + + t.Run("Should forward user header when calling PublishStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.PublishStream(req.Context(), &backend.PublishStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.PublishStreamReq) + require.Len(t, cdt.PublishStreamReq.Headers, 1) + require.Equal(t, "admin", cdt.PublishStreamReq.GetHTTPHeader(proxyutil.UserHeaderName)) + }) + + t.Run("Should forward user header when calling RunStream", func(t *testing.T) { + err = cdt.MiddlewareHandler.RunStream(req.Context(), &backend.RunStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.NotNil(t, cdt.RunStreamReq) + require.Len(t, cdt.RunStreamReq.Headers, 1) + require.Equal(t, "admin", cdt.RunStreamReq.GetHTTPHeader(proxyutil.UserHeaderName)) + }) }) t.Run("And requests are for an app", func(t *testing.T) { @@ -199,6 +292,39 @@ func TestUserHeaderMiddleware(t *testing.T) { require.Len(t, cdt.CheckHealthReq.Headers, 1) require.Equal(t, "admin", cdt.CheckHealthReq.GetHTTPHeader(proxyutil.UserHeaderName)) }) + + t.Run("Should forward user header when calling SubscribeStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.SubscribeStream(req.Context(), &backend.SubscribeStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.SubscribeStreamReq) + require.Len(t, cdt.SubscribeStreamReq.Headers, 1) + require.Equal(t, "admin", cdt.SubscribeStreamReq.GetHTTPHeader(proxyutil.UserHeaderName)) + }) + + t.Run("Should forward user header when calling PublishStream", func(t *testing.T) { + _, err = cdt.MiddlewareHandler.PublishStream(req.Context(), &backend.PublishStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }) + require.NoError(t, err) + require.NotNil(t, cdt.PublishStreamReq) + require.Len(t, cdt.PublishStreamReq.Headers, 1) + require.Equal(t, "admin", cdt.PublishStreamReq.GetHTTPHeader(proxyutil.UserHeaderName)) + }) + + t.Run("Should forward user header when calling RunStream", func(t *testing.T) { + err = cdt.MiddlewareHandler.RunStream(req.Context(), &backend.RunStreamRequest{ + PluginContext: pluginCtx, + Headers: map[string]string{}, + }, &backend.StreamSender{}) + require.NoError(t, err) + require.NotNil(t, cdt.RunStreamReq) + require.Len(t, cdt.RunStreamReq.Headers, 1) + require.Equal(t, "admin", cdt.RunStreamReq.GetHTTPHeader(proxyutil.UserHeaderName)) + }) }) }) }