* Fix AAD authority for sovereign clouds
* Update Azure SDK with scopes fix
* Credential initialization in cache
(cherry picked from commit a337f70469)
Co-authored-by: Sergey Kostrukov <sekost@microsoft.com>
This commit is contained in:
committed by
GitHub
parent
44204a745c
commit
d678865934
@@ -14,7 +14,9 @@ import (
|
||||
|
||||
type fakeCredential struct {
|
||||
key string
|
||||
initCalledTimes int
|
||||
calledTimes int
|
||||
initFunc func() error
|
||||
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
|
||||
}
|
||||
|
||||
@@ -22,6 +24,19 @@ func (c *fakeCredential) GetCacheKey() string {
|
||||
return c.key
|
||||
}
|
||||
|
||||
func (c *fakeCredential) Reset() {
|
||||
c.initCalledTimes = 0
|
||||
c.calledTimes = 0
|
||||
}
|
||||
|
||||
func (c *fakeCredential) Init() error {
|
||||
c.initCalledTimes = c.initCalledTimes + 1
|
||||
if c.initFunc != nil {
|
||||
return c.initFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
c.calledTimes = c.calledTimes + 1
|
||||
if c.getAccessTokenFunc != nil {
|
||||
@@ -103,12 +118,168 @@ func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
|
||||
t.Run("when credential init returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
return errors.New("unable to initialize")
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should return error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
err := cacheEntry.ensureInitialized()
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("should call init again each time and return error", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
return errors.New("unable to initialize")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again only while it returns error", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.Error(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
panic(errors.New("unable to initialize"))
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again each time", func(t *testing.T) {
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential init panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
initFunc: func() error {
|
||||
times = times + 1
|
||||
if times == 1 {
|
||||
panic(errors.New("unable to initialize"))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("should call credential init again only while it panics", func(t *testing.T) {
|
||||
cacheEntry := &credentialCacheEntry{
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.NotNil(t, recover(), "credential expected to panic")
|
||||
}()
|
||||
_ = cacheEntry.ensureInitialized()
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
assert.Nil(t, recover(), "credential not expected to panic")
|
||||
}()
|
||||
err = cacheEntry.ensureInitialized()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
scopes := []string{"Scope1"}
|
||||
|
||||
t.Run("when credential returns error", func(t *testing.T) {
|
||||
t.Run("when credential getAccessToken returns error", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
|
||||
@@ -130,7 +301,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should call credential again each time and return error", func(t *testing.T) {
|
||||
credential.calledTimes = 0
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
@@ -152,7 +323,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential returns error only once", func(t *testing.T) {
|
||||
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
@@ -191,7 +362,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential panics", func(t *testing.T) {
|
||||
t.Run("when credential getAccessToken panics", func(t *testing.T) {
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
panic(errors.New("unable to get access token"))
|
||||
@@ -199,7 +370,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("should call credential again each time", func(t *testing.T) {
|
||||
credential.calledTimes = 0
|
||||
credential.Reset()
|
||||
|
||||
cacheEntry := &scopesCacheEntry{
|
||||
credential: credential,
|
||||
@@ -232,7 +403,7 @@ func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("when credential panics only once", func(t *testing.T) {
|
||||
t.Run("when credential getAccessToken panics only once", func(t *testing.T) {
|
||||
var times = 0
|
||||
credential := &fakeCredential{
|
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
|
||||
|
||||
Reference in New Issue
Block a user