AzureMonitor: Fix Azure token provider national clouds (#34615) (#34617)

* Fix AAD authority for sovereign clouds

* Update Azure SDK with scopes fix

* Credential initialization in cache

(cherry picked from commit a337f70469)

Co-authored-by: Sergey Kostrukov <sekost@microsoft.com>
This commit is contained in:
Grot (@grafanabot)
2021-05-25 07:46:03 +01:00
committed by GitHub
parent 44204a745c
commit d678865934
5 changed files with 260 additions and 89 deletions
+177 -6
View File
@@ -14,7 +14,9 @@ import (
type fakeCredential struct {
key string
initCalledTimes int
calledTimes int
initFunc func() error
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
}
@@ -22,6 +24,19 @@ func (c *fakeCredential) GetCacheKey() string {
return c.key
}
func (c *fakeCredential) Reset() {
c.initCalledTimes = 0
c.calledTimes = 0
}
func (c *fakeCredential) Init() error {
c.initCalledTimes = c.initCalledTimes + 1
if c.initFunc != nil {
return c.initFunc()
}
return nil
}
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
c.calledTimes = c.calledTimes + 1
if c.getAccessTokenFunc != nil {
@@ -103,12 +118,168 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
})
}
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
t.Run("when credential init returns error", func(t *testing.T) {
credential := &fakeCredential{
initFunc: func() error {
return errors.New("unable to initialize")
},
}
t.Run("should return error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
err := cacheEntry.ensureInitialized()
assert.Error(t, err)
})
t.Run("should call init again each time and return error", func(t *testing.T) {
credential.Reset()
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
assert.Equal(t, 3, credential.initCalledTimes)
})
})
t.Run("when credential init returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
initFunc: func() error {
times = times + 1
if times == 1 {
return errors.New("unable to initialize")
}
return nil
},
}
t.Run("should call credential init again only while it returns error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
assert.Equal(t, 2, credential.initCalledTimes)
})
})
t.Run("when credential init panics", func(t *testing.T) {
credential := &fakeCredential{
initFunc: func() error {
panic(errors.New("unable to initialize"))
},
}
t.Run("should call credential init again each time", func(t *testing.T) {
credential.Reset()
cacheEntry := &credentialCacheEntry{
credential: credential,
}
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
assert.Equal(t, 3, credential.initCalledTimes)
})
})
t.Run("when credential init panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
initFunc: func() error {
times = times + 1
if times == 1 {
panic(errors.New("unable to initialize"))
}
return nil
},
}
t.Run("should call credential init again only while it panics", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
assert.Equal(t, 2, credential.initCalledTimes)
})
})
}
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
ctx := context.Background()
scopes := []string{"Scope1"}
t.Run("when credential returns error", func(t *testing.T) {
t.Run("when credential getAccessToken returns error", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
@@ -130,7 +301,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
})
t.Run("should call credential again each time and return error", func(t *testing.T) {
credential.calledTimes = 0
credential.Reset()
cacheEntry := &scopesCacheEntry{
credential: credential,
@@ -152,7 +323,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
})
})
t.Run("when credential returns error only once", func(t *testing.T) {
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
@@ -191,7 +362,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
})
})
t.Run("when credential panics", func(t *testing.T) {
t.Run("when credential getAccessToken panics", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
panic(errors.New("unable to get access token"))
@@ -199,7 +370,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
}
t.Run("should call credential again each time", func(t *testing.T) {
credential.calledTimes = 0
credential.Reset()
cacheEntry := &scopesCacheEntry{
credential: credential,
@@ -232,7 +403,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
})
})
t.Run("when credential panics only once", func(t *testing.T) {
t.Run("when credential getAccessToken panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {