From 982e095f85f74ee114f11842a4c2e0a64673ffec Mon Sep 17 00:00:00 2001 From: Daniel Lee Date: Fri, 14 Sep 2018 11:13:09 +0200 Subject: [PATCH] dsproxy: add mutex protection to the token caches --- pkg/api/pluginproxy/access_token_provider.go | 45 ++++++++++++-------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/pkg/api/pluginproxy/access_token_provider.go b/pkg/api/pluginproxy/access_token_provider.go index c590b0aee8c..0f07195cfcf 100644 --- a/pkg/api/pluginproxy/access_token_provider.go +++ b/pkg/api/pluginproxy/access_token_provider.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "strconv" + "sync" "time" "golang.org/x/oauth2" @@ -17,10 +18,24 @@ import ( ) var ( - tokenCache = map[string]*jwtToken{} - oauthJwtTokenCache = map[string]*oauth2.Token{} + tokenCache = tokenCacheType{ + cache: map[string]*jwtToken{}, + } + oauthJwtTokenCache = oauthJwtTokenCacheType{ + cache: map[string]*oauth2.Token{}, + } ) +type tokenCacheType struct { + cache map[string]*jwtToken + sync.Mutex +} + +type oauthJwtTokenCacheType struct { + cache map[string]*oauth2.Token + sync.Mutex +} + type accessTokenProvider struct { route *plugins.AppPluginRoute datasourceID int64 @@ -40,7 +55,9 @@ func newAccessTokenProvider(dsID int64, pluginRoute *plugins.AppPluginRoute) *ac } func (provider *accessTokenProvider) getAccessToken(data templateData) (string, error) { - if cachedToken, found := tokenCache[provider.getAccessTokenCacheKey()]; found { + tokenCache.Lock() + defer tokenCache.Unlock() + if cachedToken, found := tokenCache.cache[provider.getAccessTokenCacheKey()]; found { if cachedToken.ExpiresOn.After(time.Now().Add(time.Second * 10)) { logger.Info("Using token from cache") return cachedToken.AccessToken, nil @@ -79,7 +96,7 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string, expiresOnEpoch, _ := strconv.ParseInt(token.ExpiresOnString, 10, 64) token.ExpiresOn = time.Unix(expiresOnEpoch, 0) - tokenCache[provider.getAccessTokenCacheKey()] = &token + tokenCache.cache[provider.getAccessTokenCacheKey()] = &token logger.Info("Got new access token", "ExpiresOn", token.ExpiresOn) @@ -87,7 +104,9 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string, } func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data templateData) (string, error) { - if cachedToken, found := oauthJwtTokenCache[provider.getAccessTokenCacheKey()]; found { + oauthJwtTokenCache.Lock() + defer oauthJwtTokenCache.Unlock() + if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found { if cachedToken.Expiry.After(time.Now().Add(time.Second * 10)) { logger.Info("Using token from cache") return cachedToken.AccessToken, nil @@ -127,7 +146,9 @@ func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data return "", err } - oauthJwtTokenCache[provider.getAccessTokenCacheKey()] = token + oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()] = token + + logger.Info("Got new access token", "ExpiresOn", token.Expiry) return token.AccessToken, nil } @@ -139,21 +160,9 @@ var getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, return nil, err } - // logger.Info("interpolatedVal", "token.AccessToken", token.AccessToken) - return token, nil } func (provider *accessTokenProvider) getAccessTokenCacheKey() string { return fmt.Sprintf("%v_%v_%v", provider.datasourceID, provider.route.Path, provider.route.Method) } - -//Export access token lookup -func GetAccessTokenFromCache(datasourceID int64, path string, method string) (string, error) { - key := fmt.Sprintf("%v_%v_%v", datasourceID, path, method) - if cachedToken, found := oauthJwtTokenCache[key]; found { - return cachedToken.AccessToken, nil - } else { - return "", fmt.Errorf("Key doesnt exist") - } -}