Files
grafana/apps/provisioning/pkg/auth/round_tripper_test.go
2025-09-10 09:04:10 +03:00

107 lines
3.2 KiB
Go

package auth
import (
"context"
"io"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/grafana/authlib/authn"
"github.com/grafana/grafana/apps/provisioning/pkg/apis/provisioning/v0alpha1"
"github.com/stretchr/testify/require"
)
// fakeExchanger is a test double for the token exchanger passed to
// NewRoundTripper. Every Exchange call replies with the canned resp/err pair
// and remembers the request it was given.
type fakeExchanger struct {
	resp   *authn.TokenExchangeResponse // returned from every Exchange call
	err    error                        // returned alongside resp
	gotReq *authn.TokenExchangeRequest  // copy of the most recent request; nil until Exchange is called
}

// Exchange stores a copy of req in gotReq and returns the configured response
// and error unchanged.
func (fx *fakeExchanger) Exchange(_ context.Context, req authn.TokenExchangeRequest) (*authn.TokenExchangeResponse, error) {
	captured := req
	fx.gotReq = &captured
	return fx.resp, fx.err
}
// roundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface, so tests can build a stub transport inline.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip calls the underlying function with r and returns its result as-is.
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return f(r)
}
// TestRoundTripper_SetsAccessTokenHeader verifies that the round tripper
// forwards the request to the wrapped transport with the exchanged token set
// as "X-Access-Token: Bearer <token>".
func TestRoundTripper_SetsAccessTokenHeader(t *testing.T) {
	var (
		transportCalled bool
		gotToken        string
	)
	tr := NewRoundTripper(&fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "abc123"}}, roundTripperFunc(func(r *http.Request) (*http.Response, error) {
		// Only record what the transport saw. The assertions run on the test
		// goroutine below, because t.FailNow must not be called from inside
		// the transport.
		transportCalled = true
		gotToken = r.Header.Get("X-Access-Token")
		// Return a minimal response; body must be non-nil per http.RoundTripper contract.
		rr := httptest.NewRecorder()
		rr.WriteHeader(http.StatusOK)
		return rr.Result(), nil
	}), "example-audience")

	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example", nil)
	require.NoError(t, err)

	resp, err := tr.RoundTrip(req)
	require.NoError(t, err)
	require.NotNil(t, resp)
	// Drain and close the body so the response is fully released.
	_, _ = io.Copy(io.Discard, resp.Body)
	_ = resp.Body.Close()

	require.True(t, transportCalled, "wrapped transport was not called")
	require.Equal(t, "Bearer abc123", gotToken, "unexpected X-Access-Token header")
}
// TestRoundTripper_PropagatesExchangeError verifies that when the token
// exchange fails, RoundTrip returns an error and never forwards the request
// to the wrapped transport.
func TestRoundTripper_PropagatesExchangeError(t *testing.T) {
	transportCalled := false
	tr := NewRoundTripper(&fakeExchanger{err: io.EOF}, roundTripperFunc(func(_ *http.Request) (*http.Response, error) {
		// Record the call instead of failing here; the assertion runs on the
		// test goroutine after RoundTrip returns.
		transportCalled = true
		return nil, nil
	}), "example-audience")

	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example", nil)
	require.NoError(t, err)

	resp, err := tr.RoundTrip(req)
	// Close any body that came back, even alongside an error, so nothing leaks.
	if resp != nil && resp.Body != nil {
		_ = resp.Body.Close()
	}

	require.False(t, transportCalled, "transport should not be called on exchange error")
	require.Error(t, err, "expected exchange error to be propagated")
}
// TestRoundTripper_AudiencesAndNamespace verifies the audiences sent in the
// token exchange request. The provisioning group (v0alpha1.GROUP) is appended
// to a custom audience, and it is not duplicated when it is already the
// configured audience.
//
// NOTE(review): despite the name, this test asserts only Audiences. The
// Namespace field of the exchange request is not checked here.
func TestRoundTripper_AudiencesAndNamespace(t *testing.T) {
	tests := []struct {
		name          string
		audience      string
		wantAudiences []string
	}{
		{
			name:          "adds group when custom audience",
			audience:      "example-audience",
			wantAudiences: []string{"example-audience", v0alpha1.GROUP},
		},
		{
			name:          "no duplicate when group audience",
			audience:      v0alpha1.GROUP,
			wantAudiences: []string{v0alpha1.GROUP},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fx := &fakeExchanger{resp: &authn.TokenExchangeResponse{Token: "abc123"}}
			tr := NewRoundTripper(fx, roundTripperFunc(func(_ *http.Request) (*http.Response, error) {
				rr := httptest.NewRecorder()
				rr.WriteHeader(http.StatusOK)
				return rr.Result(), nil
			}), tt.audience)

			req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example", nil)
			require.NoError(t, err)

			resp, err := tr.RoundTrip(req)
			require.NoError(t, err)
			_, _ = io.Copy(io.Discard, resp.Body)
			_ = resp.Body.Close()

			require.NotNil(t, fx.gotReq, "exchanger was not called")
			// Report both slices on failure; a bare boolean check gives no diagnostics.
			require.True(t, reflect.DeepEqual(fx.gotReq.Audiences, tt.wantAudiences),
				"audiences mismatch: got %v, want %v", fx.gotReq.Audiences, tt.wantAudiences)
		})
	}
}