Prometheus: Implement Streaming JSON Parser (#48477)
use `prometheusStreamingJSONParser` feature toggle to enable
This commit is contained in:
@@ -0,0 +1,55 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
)
|
||||
|
||||
type ProviderCache struct {
|
||||
provider promClientProvider
|
||||
cache *lru.Cache
|
||||
}
|
||||
|
||||
type promClientProvider interface {
|
||||
GetClient(map[string]string) (*Client, error)
|
||||
}
|
||||
|
||||
func NewProviderCache(p promClientProvider) (*ProviderCache, error) {
|
||||
cache, err := lru.New(500)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ProviderCache{
|
||||
provider: p,
|
||||
cache: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *ProviderCache) GetClient(headers map[string]string) (*Client, error) {
|
||||
key := c.key(headers)
|
||||
if client, ok := c.cache.Get(key); ok {
|
||||
return client.(*Client), nil
|
||||
}
|
||||
|
||||
client, err := c.provider.GetClient(headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.cache.Add(key, client)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *ProviderCache) key(headers map[string]string) string {
|
||||
vals := make([]string, len(headers))
|
||||
var i int
|
||||
for _, v := range headers {
|
||||
vals[i] = v
|
||||
i++
|
||||
}
|
||||
sort.Strings(vals)
|
||||
return strings.Join(vals, "")
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/client"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCache_GetClient(t *testing.T) {
|
||||
t.Run("it caches the client for a set of auth headers", func(t *testing.T) {
|
||||
tc := setupCacheContext()
|
||||
|
||||
c, err := tc.providerCache.GetClient(headers)
|
||||
require.Nil(t, err)
|
||||
|
||||
c2, err := tc.providerCache.GetClient(headers)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, c, c2)
|
||||
require.Equal(t, 1, tc.clientProvider.numCalls)
|
||||
})
|
||||
|
||||
t.Run("it returns different clients when the headers differ", func(t *testing.T) {
|
||||
tc := setupCacheContext()
|
||||
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
|
||||
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"}
|
||||
|
||||
c, err := tc.providerCache.GetClient(h1)
|
||||
require.Nil(t, err)
|
||||
|
||||
c2, err := tc.providerCache.GetClient(h2)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.NotEqual(t, c, c2)
|
||||
require.Equal(t, 2, tc.clientProvider.numCalls)
|
||||
})
|
||||
|
||||
t.Run("it returns from the cache when headers are the same", func(t *testing.T) {
|
||||
tc := setupCacheContext()
|
||||
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
|
||||
h2 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
|
||||
|
||||
c, err := tc.providerCache.GetClient(h1)
|
||||
require.Nil(t, err)
|
||||
|
||||
c2, err := tc.providerCache.GetClient(h2)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, c, c2)
|
||||
require.Equal(t, 1, tc.clientProvider.numCalls)
|
||||
})
|
||||
|
||||
t.Run("it doesn't cache anything when an error occurs", func(t *testing.T) {
|
||||
tc := setupCacheContext()
|
||||
tc.clientProvider.errors <- errors.New("something bad")
|
||||
|
||||
_, err := tc.providerCache.GetClient(headers)
|
||||
require.EqualError(t, err, "something bad")
|
||||
|
||||
c, err := tc.providerCache.GetClient(headers)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.NotNil(t, c)
|
||||
require.Equal(t, 2, tc.clientProvider.numCalls)
|
||||
})
|
||||
}
|
||||
|
||||
type cacheTestContext struct {
|
||||
providerCache *client.ProviderCache
|
||||
clientProvider *fakeClientProvider
|
||||
}
|
||||
|
||||
func setupCacheContext() *cacheTestContext {
|
||||
fp := newFakePromClientProvider()
|
||||
p, err := client.NewProviderCache(fp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &cacheTestContext{
|
||||
providerCache: p,
|
||||
clientProvider: fp,
|
||||
}
|
||||
}
|
||||
|
||||
func newFakePromClientProvider() *fakeClientProvider {
|
||||
return &fakeClientProvider{
|
||||
errors: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClientProvider struct {
|
||||
headers map[string]string
|
||||
numCalls int
|
||||
errors chan error
|
||||
}
|
||||
|
||||
func (p *fakeClientProvider) GetClient(h map[string]string) (*client.Client, error) {
|
||||
p.headers = h
|
||||
p.numCalls++
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-p.errors:
|
||||
default:
|
||||
}
|
||||
|
||||
var config []string
|
||||
for _, v := range h {
|
||||
config = append(config, v)
|
||||
}
|
||||
sort.Strings(config) //because map
|
||||
res := &http.Response{
|
||||
StatusCode: 200,
|
||||
Header: http.Header{},
|
||||
Body: ioutil.NopCloser(strings.NewReader(strings.Join(config, ","))),
|
||||
}
|
||||
c := &fakeClient{res: res}
|
||||
return client.NewClient(c, "GET", "http://localhost:9090/"), err
|
||||
}
|
||||
|
||||
type fakeClient struct {
|
||||
res *http.Response
|
||||
}
|
||||
|
||||
func (c *fakeClient) Do(req *http.Request) (*http.Response, error) {
|
||||
return c.res, nil
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/models"
|
||||
)
|
||||
|
||||
type doer interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
doer doer
|
||||
method string
|
||||
baseUrl string
|
||||
}
|
||||
|
||||
func NewClient(d doer, method, baseUrl string) *Client {
|
||||
return &Client{doer: d, method: method, baseUrl: baseUrl}
|
||||
}
|
||||
|
||||
func (c *Client) QueryRange(ctx context.Context, q *models.Query) (*http.Response, error) {
|
||||
u, err := url.ParseRequestURI(c.baseUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u.Path = path.Join(u.Path, "api/v1/query_range")
|
||||
|
||||
qs := u.Query()
|
||||
qs.Set("query", q.Expr)
|
||||
tr := q.TimeRange()
|
||||
qs.Set("start", formatTime(tr.Start))
|
||||
qs.Set("end", formatTime(tr.End))
|
||||
qs.Set("step", strconv.FormatFloat(tr.Step.Seconds(), 'f', -1, 64))
|
||||
|
||||
return c.fetch(ctx, u, qs)
|
||||
}
|
||||
|
||||
func (c *Client) QueryInstant(ctx context.Context, q *models.Query) (*http.Response, error) {
|
||||
u, err := url.ParseRequestURI(c.baseUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u.Path = path.Join(u.Path, "api/v1/query")
|
||||
|
||||
qs := u.Query()
|
||||
qs.Set("query", q.Expr)
|
||||
tr := q.TimeRange()
|
||||
if !tr.End.IsZero() {
|
||||
qs.Set("time", formatTime(tr.End))
|
||||
}
|
||||
|
||||
return c.fetch(ctx, u, qs)
|
||||
}
|
||||
|
||||
func (c *Client) QueryExemplars(ctx context.Context, q *models.Query) (*http.Response, error) {
|
||||
u, err := url.ParseRequestURI(c.baseUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u.Path = path.Join(u.Path, "api/v1/query_exemplars")
|
||||
|
||||
qs := u.Query()
|
||||
tr := q.TimeRange()
|
||||
qs.Set("query", q.Expr)
|
||||
qs.Set("start", formatTime(tr.Start))
|
||||
qs.Set("end", formatTime(tr.End))
|
||||
|
||||
return c.fetch(ctx, u, qs)
|
||||
}
|
||||
|
||||
func (c *Client) fetch(ctx context.Context, u *url.URL, qs url.Values) (*http.Response, error) {
|
||||
if strings.ToUpper(c.method) == http.MethodGet {
|
||||
u.RawQuery = qs.Encode()
|
||||
}
|
||||
|
||||
r, err := http.NewRequestWithContext(ctx, c.method, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if strings.ToUpper(c.method) == http.MethodPost {
|
||||
r.Body = ioutil.NopCloser(strings.NewReader(qs.Encode()))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
|
||||
return c.doer.Do(r)
|
||||
}
|
||||
|
||||
func formatTime(t time.Time) string {
|
||||
return strconv.FormatFloat(float64(t.Unix())+float64(t.Nanosecond())/1e9, 'f', -1, 64)
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/middleware"
|
||||
"github.com/grafana/grafana/pkg/util/maputil"
|
||||
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
settings backend.DataSourceInstanceSettings
|
||||
jsonData map[string]interface{}
|
||||
httpMethod string
|
||||
clientProvider httpclient.Provider
|
||||
cfg *setting.Cfg
|
||||
features featuremgmt.FeatureToggles
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
func NewProvider(
|
||||
settings backend.DataSourceInstanceSettings,
|
||||
jsonData map[string]interface{},
|
||||
clientProvider httpclient.Provider,
|
||||
cfg *setting.Cfg,
|
||||
features featuremgmt.FeatureToggles,
|
||||
log log.Logger,
|
||||
) *Provider {
|
||||
httpMethod, _ := maputil.GetStringOptional(jsonData, "httpMethod")
|
||||
return &Provider{
|
||||
settings: settings,
|
||||
jsonData: jsonData,
|
||||
httpMethod: httpMethod,
|
||||
clientProvider: clientProvider,
|
||||
cfg: cfg,
|
||||
features: features,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) GetClient(headers map[string]string) (*Client, error) {
|
||||
opts, err := p.settings.HTTPClientOptions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts.Middlewares = p.middlewares()
|
||||
opts.Headers = reqHeaders(headers)
|
||||
|
||||
// Set SigV4 service namespace
|
||||
if opts.SigV4 != nil {
|
||||
opts.SigV4.Service = "aps"
|
||||
}
|
||||
|
||||
// Azure authentication
|
||||
err = p.configureAzureAuthentication(&opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpClient, err := p.clientProvider.New(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewClient(httpClient, p.httpMethod, p.settings.URL), nil
|
||||
}
|
||||
|
||||
func (p *Provider) middlewares() []sdkhttpclient.Middleware {
|
||||
middlewares := []sdkhttpclient.Middleware{
|
||||
middleware.CustomQueryParameters(p.log),
|
||||
sdkhttpclient.CustomHeadersMiddleware(),
|
||||
}
|
||||
return middlewares
|
||||
}
|
||||
|
||||
func reqHeaders(headers map[string]string) map[string]string {
|
||||
// copy to avoid changing the original map
|
||||
h := make(map[string]string, len(headers))
|
||||
for k, v := range headers {
|
||||
h[k] = v
|
||||
}
|
||||
return h
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azcredentials"
|
||||
"github.com/grafana/grafana-azure-sdk-go/azhttpclient"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/util/maputil"
|
||||
)
|
||||
|
||||
func (p *Provider) configureAzureAuthentication(opts *sdkhttpclient.Options) error {
|
||||
// Azure authentication is experimental (#35857)
|
||||
if !p.features.IsEnabled(featuremgmt.FlagPrometheusAzureAuth) {
|
||||
return nil
|
||||
}
|
||||
|
||||
credentials, err := azcredentials.FromDatasourceData(p.jsonData, p.settings.DecryptedSecureJSONData)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid Azure credentials: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if credentials != nil {
|
||||
resourceIdStr, err := maputil.GetStringOptional(p.jsonData, "azureEndpointResourceId")
|
||||
if err != nil {
|
||||
return err
|
||||
} else if resourceIdStr == "" {
|
||||
err := fmt.Errorf("endpoint resource ID (audience) not provided")
|
||||
return err
|
||||
}
|
||||
|
||||
resourceId, err := url.Parse(resourceIdStr)
|
||||
if err != nil || resourceId.Scheme == "" || resourceId.Host == "" {
|
||||
err := fmt.Errorf("endpoint resource ID (audience) '%s' invalid", resourceIdStr)
|
||||
return err
|
||||
}
|
||||
|
||||
resourceId.Path = path.Join(resourceId.Path, ".default")
|
||||
scopes := []string{resourceId.String()}
|
||||
|
||||
azhttpclient.AddAzureAuthentication(opts, p.cfg.Azure, credentials, scopes)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfigureAzureAuthentication(t *testing.T) {
|
||||
cfg := &setting.Cfg{}
|
||||
settings := backend.DataSourceInstanceSettings{}
|
||||
|
||||
t.Run("given feature flag enabled", func(t *testing.T) {
|
||||
features := featuremgmt.WithFeatures(featuremgmt.FlagPrometheusAzureAuth)
|
||||
|
||||
t.Run("should set Azure middleware when JsonData contains valid credentials", func(t *testing.T) {
|
||||
jsonData := map[string]interface{}{
|
||||
"httpMethod": "POST",
|
||||
"azureCredentials": map[string]interface{}{
|
||||
"authType": "msi",
|
||||
},
|
||||
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5",
|
||||
}
|
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil)
|
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}}
|
||||
|
||||
err := p.configureAzureAuthentication(opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, opts.Middlewares)
|
||||
assert.Len(t, opts.Middlewares, 1)
|
||||
})
|
||||
|
||||
t.Run("should not set Azure middleware when JsonData doesn't contain valid credentials", func(t *testing.T) {
|
||||
jsonData := map[string]interface{}{
|
||||
"httpMethod": "POST",
|
||||
}
|
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil)
|
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}}
|
||||
|
||||
err := p.configureAzureAuthentication(opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotContains(t, opts.CustomOptions, "_azureCredentials")
|
||||
})
|
||||
|
||||
t.Run("should return error when JsonData contains invalid credentials", func(t *testing.T) {
|
||||
jsonData := map[string]interface{}{
|
||||
"httpMethod": "POST",
|
||||
"azureCredentials": "invalid",
|
||||
}
|
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil)
|
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}}
|
||||
|
||||
err := p.configureAzureAuthentication(opts)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("should set Azure middleware when JsonData contains credentials and valid audience", func(t *testing.T) {
|
||||
jsonData := map[string]interface{}{
|
||||
"httpMethod": "POST",
|
||||
"azureCredentials": map[string]interface{}{
|
||||
"authType": "msi",
|
||||
},
|
||||
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5",
|
||||
}
|
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil)
|
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}}
|
||||
|
||||
err := p.configureAzureAuthentication(opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, opts.Middlewares)
|
||||
assert.Len(t, opts.Middlewares, 1)
|
||||
})
|
||||
|
||||
t.Run("should not set Azure middleware when JsonData doesn't contain credentials", func(t *testing.T) {
|
||||
jsonData := map[string]interface{}{
|
||||
"httpMethod": "POST",
|
||||
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5",
|
||||
}
|
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil)
|
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}}
|
||||
|
||||
err := p.configureAzureAuthentication(opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
if opts.Middlewares != nil {
|
||||
assert.Len(t, opts.Middlewares, 0)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should return error when JsonData contains invalid audience", func(t *testing.T) {
|
||||
jsonData := map[string]interface{}{
|
||||
"httpMethod": "POST",
|
||||
"azureCredentials": map[string]interface{}{
|
||||
"authType": "msi",
|
||||
},
|
||||
"azureEndpointResourceId": "invalid",
|
||||
}
|
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil)
|
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}}
|
||||
|
||||
err := p.configureAzureAuthentication(opts)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("given feature flag not enabled", func(t *testing.T) {
|
||||
features := featuremgmt.WithFeatures()
|
||||
|
||||
t.Run("should not set Azure Credentials even when JsonData contains credentials", func(t *testing.T) {
|
||||
jsonData := map[string]interface{}{
|
||||
"httpMethod": "POST",
|
||||
"azureCredentials": map[string]interface{}{
|
||||
"authType": "msi",
|
||||
},
|
||||
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5",
|
||||
}
|
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil)
|
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}}
|
||||
|
||||
err := p.configureAzureAuthentication(opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
if opts.Middlewares != nil {
|
||||
assert.Len(t, opts.Middlewares, 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
"github.com/grafana/grafana/pkg/infra/httpclient"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/client"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var headers = map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
|
||||
|
||||
func TestGetClient(t *testing.T) {
|
||||
t.Run("it sets the SigV4 service if it exists", func(t *testing.T) {
|
||||
tc := setup(`{"sigV4Auth":true}`)
|
||||
|
||||
setting.SigV4AuthEnabled = true
|
||||
defer func() { setting.SigV4AuthEnabled = false }()
|
||||
|
||||
_, err := tc.clientProvider.GetClient(headers)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, "aps", tc.httpProvider.opts.SigV4.Service)
|
||||
})
|
||||
|
||||
t.Run("it always uses the custom params and custom headers middlewares", func(t *testing.T) {
|
||||
tc := setup()
|
||||
|
||||
_, err := tc.clientProvider.GetClient(headers)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Len(t, tc.httpProvider.middlewares(), 2)
|
||||
require.Contains(t, tc.httpProvider.middlewares(), "prom-custom-query-parameters")
|
||||
require.Contains(t, tc.httpProvider.middlewares(), "CustomHeaders")
|
||||
})
|
||||
|
||||
t.Run("extra headers", func(t *testing.T) {
|
||||
t.Run("it sets the headers when 'oauthPassThru' is true and auth headers are passed", func(t *testing.T) {
|
||||
tc := setup(`{"oauthPassThru":true}`)
|
||||
_, err := tc.clientProvider.GetClient(headers)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, headers, tc.httpProvider.opts.Headers)
|
||||
})
|
||||
|
||||
t.Run("it sets all headers", func(t *testing.T) {
|
||||
withNonAuth := map[string]string{"X-Not-Auth": "stuff"}
|
||||
|
||||
tc := setup(`{"oauthPassThru":true}`)
|
||||
_, err := tc.clientProvider.GetClient(withNonAuth)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, map[string]string{"X-Not-Auth": "stuff"}, tc.httpProvider.opts.Headers)
|
||||
})
|
||||
|
||||
t.Run("it does not error when headers are nil", func(t *testing.T) {
|
||||
tc := setup(`{"oauthPassThru":true}`)
|
||||
|
||||
_, err := tc.clientProvider.GetClient(nil)
|
||||
require.Nil(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func setup(jsonData ...string) *testContext {
|
||||
var rawData []byte
|
||||
if len(jsonData) > 0 {
|
||||
rawData = []byte(jsonData[0])
|
||||
}
|
||||
|
||||
var jd map[string]interface{}
|
||||
_ = json.Unmarshal(rawData, &jd)
|
||||
|
||||
cfg := &setting.Cfg{}
|
||||
settings := backend.DataSourceInstanceSettings{URL: "test-url", JSONData: rawData}
|
||||
features := featuremgmt.WithFeatures()
|
||||
hp := &fakeHttpClientProvider{}
|
||||
p := client.NewProvider(settings, jd, hp, cfg, features, nil)
|
||||
|
||||
return &testContext{
|
||||
httpProvider: hp,
|
||||
clientProvider: p,
|
||||
}
|
||||
}
|
||||
|
||||
type testContext struct {
|
||||
httpProvider *fakeHttpClientProvider
|
||||
clientProvider *client.Provider
|
||||
}
|
||||
|
||||
type fakeHttpClientProvider struct {
|
||||
httpclient.Provider
|
||||
|
||||
opts sdkhttpclient.Options
|
||||
}
|
||||
|
||||
func (p *fakeHttpClientProvider) New(opts ...sdkhttpclient.Options) (*http.Client, error) {
|
||||
p.opts = opts[0]
|
||||
return sdkhttpclient.New(opts[0])
|
||||
}
|
||||
|
||||
func (p *fakeHttpClientProvider) GetTransport(opts ...sdkhttpclient.Options) (http.RoundTripper, error) {
|
||||
p.opts = opts[0]
|
||||
return http.DefaultTransport, nil
|
||||
}
|
||||
|
||||
func (p *fakeHttpClientProvider) middlewares() []string {
|
||||
var middlewareNames []string
|
||||
for _, m := range p.opts.Middlewares {
|
||||
mw, ok := m.(sdkhttpclient.MiddlewareName)
|
||||
if !ok {
|
||||
panic("unexpected middleware type")
|
||||
}
|
||||
|
||||
middlewareNames = append(middlewareNames, mw.MiddlewareName())
|
||||
}
|
||||
return middlewareNames
|
||||
}
|
||||
Reference in New Issue
Block a user