From 2e2b1cd9e4398b5c6f8dbe4bad25e5fb45805424 Mon Sep 17 00:00:00 2001 From: Mihai Doarna Date: Wed, 29 Nov 2023 18:02:04 +0200 Subject: [PATCH] Refactor SSOSettings to use types (#78675) * refactor SSOSettings to use types * test struct * refactor SSOSettings struct to use types * fix database tests * fix populateSSOSettings() to accept an SSOSettings param * fix all tests from the database layer * handle errors for converting to/from SSOSettings * add json tag on OAuthInfo fields * use continue instead of if/else * add the source field to SSOSettingsDTO conversion * remove omitempty from json tags in OAuthInfo struct --- pkg/login/social/azuread_oauth.go | 2 +- pkg/login/social/common.go | 4 +- pkg/login/social/commont_test.go | 4 +- pkg/login/social/generic_oauth.go | 2 +- pkg/login/social/github_oauth.go | 2 +- pkg/login/social/gitlab_oauth.go | 2 +- pkg/login/social/google_oauth.go | 2 +- pkg/login/social/grafana_com_oauth.go | 2 +- pkg/login/social/okta_oauth.go | 2 +- pkg/login/social/social.go | 66 +++--- pkg/services/ssosettings/api/api.go | 33 ++- pkg/services/ssosettings/database/database.go | 53 +++-- .../ssosettings/database/database_test.go | 218 +++++++++++------- pkg/services/ssosettings/models/models.go | 84 ++++--- pkg/services/ssosettings/ssosettings.go | 12 +- .../ssosettings/ssosettingsimpl/service.go | 50 ++-- .../ssosettingsimpl/service_test.go | 207 +++++++---------- .../ssosettingstests/service_mock.go | 30 +-- .../ssosettingstests/store_fake.go | 10 +- .../ssosettingstests/store_mock.go | 43 ++-- 20 files changed, 445 insertions(+), 383 deletions(-) diff --git a/pkg/login/social/azuread_oauth.go b/pkg/login/social/azuread_oauth.go index 8b395af484f..e392eadc874 100644 --- a/pkg/login/social/azuread_oauth.go +++ b/pkg/login/social/azuread_oauth.go @@ -65,7 +65,7 @@ type keySetJWKS struct { } func NewAzureADProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager, cache remotecache.CacheStorage) (*SocialAzureAD, error) { - info, err := createOAuthInfoFromKeyValues(settings) + info, err := CreateOAuthInfoFromKeyValues(settings) if err != nil { return nil, err } diff --git a/pkg/login/social/common.go b/pkg/login/social/common.go index 6a7eb3f50ff..ce7d3b6cd45 100644 --- a/pkg/login/social/common.go +++ b/pkg/login/social/common.go @@ -197,9 +197,9 @@ func convertIniSectionToMap(sec *ini.Section) map[string]any { return mappedSettings } -// createOAuthInfoFromKeyValues creates an OAuthInfo struct from a map[string]any using mapstructure +// CreateOAuthInfoFromKeyValues creates an OAuthInfo struct from a map[string]any using mapstructure // it puts all extra key values into OAuthInfo's Extra map -func createOAuthInfoFromKeyValues(settingsKV map[string]any) (*OAuthInfo, error) { +func CreateOAuthInfoFromKeyValues(settingsKV map[string]any) (*OAuthInfo, error) { emptyStrToSliceDecodeHook := func(from reflect.Type, to reflect.Type, data any) (any, error) { if from.Kind() == reflect.String && to.Kind() == reflect.Slice { strData, ok := data.(string) diff --git a/pkg/login/social/commont_test.go b/pkg/login/social/commont_test.go index 7acb78b3473..39f8ebb8dd7 100644 --- a/pkg/login/social/commont_test.go +++ b/pkg/login/social/commont_test.go @@ -33,7 +33,7 @@ token_url = test_token_url api_url = test_api_url teams_url = test_teams_url allowed_domains = domain1.com -allowed_groups = +allowed_groups = team_ids = first, second allowed_organizations = org1, org2 tls_skip_verify_insecure = true @@ -96,7 +96,7 @@ signout_redirect_url = https://oauth.com/signout?post_logout_redirect_uri=https: } settingsKVs := convertIniSectionToMap(iniFile.Section("test")) - oauthInfo, err := createOAuthInfoFromKeyValues(settingsKVs) + oauthInfo, err := CreateOAuthInfoFromKeyValues(settingsKVs) require.NoError(t, err) require.Equal(t, expectedOAuthInfo, oauthInfo) diff --git a/pkg/login/social/generic_oauth.go b/pkg/login/social/generic_oauth.go index 1dc7a6c1824..d4de940255f 100644 --- a/pkg/login/social/generic_oauth.go +++ b/pkg/login/social/generic_oauth.go @@ -37,7 +37,7 @@ type SocialGenericOAuth struct { } func NewGenericOAuthProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGenericOAuth, error) { - info, err := createOAuthInfoFromKeyValues(settings) + info, err := CreateOAuthInfoFromKeyValues(settings) if err != nil { return nil, err } diff --git a/pkg/login/social/github_oauth.go b/pkg/login/social/github_oauth.go index 4d01b8bd3de..97a868924f5 100644 --- a/pkg/login/social/github_oauth.go +++ b/pkg/login/social/github_oauth.go @@ -50,7 +50,7 @@ var ( ) func NewGitHubProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGithub, error) { - info, err := createOAuthInfoFromKeyValues(settings) + info, err := CreateOAuthInfoFromKeyValues(settings) if err != nil { return nil, err } diff --git a/pkg/login/social/gitlab_oauth.go b/pkg/login/social/gitlab_oauth.go index 08bed52a2b4..a985a29008a 100644 --- a/pkg/login/social/gitlab_oauth.go +++ b/pkg/login/social/gitlab_oauth.go @@ -50,7 +50,7 @@ type userData struct { } func NewGitLabProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGitlab, error) { - info, err := createOAuthInfoFromKeyValues(settings) + info, err := CreateOAuthInfoFromKeyValues(settings) if err != nil { return nil, err } diff --git a/pkg/login/social/google_oauth.go b/pkg/login/social/google_oauth.go index bc56b5e9b46..c9b559c9247 100644 --- a/pkg/login/social/google_oauth.go +++ b/pkg/login/social/google_oauth.go @@ -37,7 +37,7 @@ type googleUserData struct { } func NewGoogleProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGoogle, error) { - info, err := createOAuthInfoFromKeyValues(settings) + info, err := CreateOAuthInfoFromKeyValues(settings) if err != nil { return nil, err } diff --git a/pkg/login/social/grafana_com_oauth.go b/pkg/login/social/grafana_com_oauth.go index 24a1f908265..56a9b167b13 100644 --- a/pkg/login/social/grafana_com_oauth.go +++ b/pkg/login/social/grafana_com_oauth.go @@ -29,7 +29,7 @@ type OrgRecord struct { } func NewGrafanaComProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialGrafanaCom, error) { - info, err := createOAuthInfoFromKeyValues(settings) + info, err := CreateOAuthInfoFromKeyValues(settings) if err != nil { return nil, err } diff --git a/pkg/login/social/okta_oauth.go b/pkg/login/social/okta_oauth.go index 6d6c2a9de6e..6faf82dd721 100644 --- a/pkg/login/social/okta_oauth.go +++ b/pkg/login/social/okta_oauth.go @@ -44,7 +44,7 @@ type OktaClaims struct { } func NewOktaProvider(settings map[string]any, cfg *setting.Cfg, features *featuremgmt.FeatureManager) (*SocialOkta, error) { - info, err := createOAuthInfoFromKeyValues(settings) + info, err := CreateOAuthInfoFromKeyValues(settings) if err != nil { return nil, err } diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index 9994aa9fe7f..f189f8e5365 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -44,38 +44,38 @@ type SocialService struct { } type OAuthInfo struct { - ApiUrl string `mapstructure:"api_url" toml:"api_url"` - AuthUrl string `mapstructure:"auth_url" toml:"auth_url"` - AuthStyle string `mapstructure:"auth_style" toml:"auth_style"` - ClientId string `mapstructure:"client_id" toml:"client_id"` - ClientSecret string `mapstructure:"client_secret" toml:"-"` - EmailAttributeName string `mapstructure:"email_attribute_name" toml:"email_attribute_name"` - EmailAttributePath string `mapstructure:"email_attribute_path" toml:"email_attribute_path"` - EmptyScopes bool `mapstructure:"empty_scopes" toml:"empty_scopes"` - GroupsAttributePath string `mapstructure:"groups_attribute_path" toml:"groups_attribute_path"` - HostedDomain string `mapstructure:"hosted_domain" toml:"hosted_domain"` - Icon string `mapstructure:"icon" toml:"icon"` - Name string `mapstructure:"name" toml:"name"` - RoleAttributePath string `mapstructure:"role_attribute_path" toml:"role_attribute_path"` - TeamIdsAttributePath string `mapstructure:"team_ids_attribute_path" toml:"team_ids_attribute_path"` - TeamsUrl string `mapstructure:"teams_url" toml:"teams_url"` - TlsClientCa string `mapstructure:"tls_client_ca" toml:"tls_client_ca"` - TlsClientCert string `mapstructure:"tls_client_cert" toml:"tls_client_cert"` - TlsClientKey string `mapstructure:"tls_client_key" toml:"tls_client_key"` - TokenUrl string `mapstructure:"token_url" toml:"token_url"` - AllowedDomains []string `mapstructure:"allowed_domains" toml:"allowed_domains"` - AllowedGroups []string `mapstructure:"allowed_groups" toml:"allowed_groups"` - Scopes []string `mapstructure:"scopes" toml:"scopes"` - AllowAssignGrafanaAdmin bool `mapstructure:"allow_assign_grafana_admin" toml:"allow_assign_grafana_admin"` - AllowSignup bool `mapstructure:"allow_sign_up" toml:"allow_sign_up"` - AutoLogin bool `mapstructure:"auto_login" toml:"auto_login"` - Enabled bool `mapstructure:"enabled" toml:"enabled"` - RoleAttributeStrict bool `mapstructure:"role_attribute_strict" toml:"role_attribute_strict"` - TlsSkipVerify bool `mapstructure:"tls_skip_verify_insecure" toml:"tls_skip_verify_insecure"` - UsePKCE bool `mapstructure:"use_pkce" toml:"use_pkce"` - UseRefreshToken bool `mapstructure:"use_refresh_token" toml:"use_refresh_token"` - SignoutRedirectUrl string `mapstructure:"signout_redirect_url" toml:"signout_redirect_url"` - Extra map[string]string `mapstructure:",remain" toml:"extra,omitempty"` + ApiUrl string `mapstructure:"api_url" toml:"api_url" json:"apiUrl"` + AuthUrl string `mapstructure:"auth_url" toml:"auth_url" json:"authUrl"` + AuthStyle string `mapstructure:"auth_style" toml:"auth_style" json:"authStyle"` + ClientId string `mapstructure:"client_id" toml:"client_id" json:"clientId"` + ClientSecret string `mapstructure:"client_secret" toml:"-" json:"clientSecret"` + EmailAttributeName string `mapstructure:"email_attribute_name" toml:"email_attribute_name" json:"emailAttributeName"` + EmailAttributePath string `mapstructure:"email_attribute_path" toml:"email_attribute_path" json:"emailAttributePath"` + EmptyScopes bool `mapstructure:"empty_scopes" toml:"empty_scopes" json:"emptyScopes"` + GroupsAttributePath string `mapstructure:"groups_attribute_path" toml:"groups_attribute_path" json:"groupsAttributePath"` + HostedDomain string `mapstructure:"hosted_domain" toml:"hosted_domain" json:"hostedDomain"` + Icon string `mapstructure:"icon" toml:"icon" json:"icon"` + Name string `mapstructure:"name" toml:"name" json:"name"` + RoleAttributePath string `mapstructure:"role_attribute_path" toml:"role_attribute_path" json:"roleAttributePath"` + TeamIdsAttributePath string `mapstructure:"team_ids_attribute_path" toml:"team_ids_attribute_path" json:"teamIdsAttributePath"` + TeamsUrl string `mapstructure:"teams_url" toml:"teams_url" json:"teamsUrl"` + TlsClientCa string `mapstructure:"tls_client_ca" toml:"tls_client_ca" json:"tlsClientCa"` + TlsClientCert string `mapstructure:"tls_client_cert" toml:"tls_client_cert" json:"tlsClientCert"` + TlsClientKey string `mapstructure:"tls_client_key" toml:"tls_client_key" json:"tlsClientKey"` + TokenUrl string `mapstructure:"token_url" toml:"token_url" json:"tokenUrl"` + AllowedDomains []string `mapstructure:"allowed_domains" toml:"allowed_domains" json:"allowedDomains"` + AllowedGroups []string `mapstructure:"allowed_groups" toml:"allowed_groups" json:"allowedGroups"` + Scopes []string `mapstructure:"scopes" toml:"scopes" json:"scopes"` + AllowAssignGrafanaAdmin bool `mapstructure:"allow_assign_grafana_admin" toml:"allow_assign_grafana_admin" json:"allowAssignGrafanaAdmin"` + AllowSignup bool `mapstructure:"allow_sign_up" toml:"allow_sign_up" json:"allowSignup"` + AutoLogin bool `mapstructure:"auto_login" toml:"auto_login" json:"autoLogin"` + Enabled bool `mapstructure:"enabled" toml:"enabled" json:"enabled"` + RoleAttributeStrict bool `mapstructure:"role_attribute_strict" toml:"role_attribute_strict" json:"roleAttributeStrict"` + TlsSkipVerify bool `mapstructure:"tls_skip_verify_insecure" toml:"tls_skip_verify_insecure" json:"tlsSkipVerify"` + UsePKCE bool `mapstructure:"use_pkce" toml:"use_pkce" json:"usePKCE"` + UseRefreshToken bool `mapstructure:"use_refresh_token" toml:"use_refresh_token" json:"useRefreshToken"` + SignoutRedirectUrl string `mapstructure:"signout_redirect_url" toml:"signout_redirect_url" json:"signoutRedirectUrl"` + Extra map[string]string `mapstructure:",remain" toml:"extra,omitempty" json:"extra"` } func ProvideService(cfg *setting.Cfg, @@ -97,7 +97,7 @@ func ProvideService(cfg *setting.Cfg, sec := cfg.Raw.Section("auth." + name) settingsKVs := convertIniSectionToMap(sec) - info, err := createOAuthInfoFromKeyValues(settingsKVs) + info, err := CreateOAuthInfoFromKeyValues(settingsKVs) if err != nil { ss.log.Error("Failed to create OAuthInfo for provider", "error", err, "provider", name) continue diff --git a/pkg/services/ssosettings/api/api.go b/pkg/services/ssosettings/api/api.go index 926802854da..e6840ed471a 100644 --- a/pkg/services/ssosettings/api/api.go +++ b/pkg/services/ssosettings/api/api.go @@ -61,7 +61,18 @@ func (api *Api) listAllProvidersSettings(c *contextmodel.ReqContext) response.Re return response.Error(500, "Failed to get providers", err) } - return response.JSON(http.StatusOK, providers) + dtos := make([]*models.SSOSettingsDTO, 0) + for _, provider := range providers { + dto, err := provider.ToSSOSettingsDTO() + if err != nil { + api.Log.Warn("Failed to convert SSO Settings for provider " + provider.Provider) + continue + } + + dtos = append(dtos, dto) + } + + return response.JSON(http.StatusOK, dtos) } func (api *Api) getProviderSettings(c *contextmodel.ReqContext) response.Response { @@ -75,7 +86,12 @@ func (api *Api) getProviderSettings(c *contextmodel.ReqContext) response.Respons return response.Error(http.StatusNotFound, "The provider was not found", err) } - return response.JSON(http.StatusOK, settings) + dto, err := settings.ToSSOSettingsDTO() + if err != nil { + return response.Error(http.StatusInternalServerError, "The provider is invalid", err) + } + + return response.JSON(http.StatusOK, dto) } func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Response { @@ -84,12 +100,19 @@ func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Resp return response.Error(http.StatusBadRequest, "Missing key", nil) } - var newSettings models.SSOSetting - if err := web.Bind(c.Req, &newSettings); err != nil { + var settingsDTO models.SSOSettingsDTO + if err := web.Bind(c.Req, &settingsDTO); err != nil { return response.Error(http.StatusBadRequest, "Failed to parse request body", err) } - err := api.SSOSettingsService.Upsert(c.Req.Context(), key, newSettings.Settings) + settings, err := settingsDTO.ToSSOSettings() + if err != nil { + return response.Error(http.StatusBadRequest, "Invalid request body", err) + } + + settings.Provider = key + + err = api.SSOSettingsService.Upsert(c.Req.Context(), *settings) // TODO: first check whether the error is referring to validation errors // other error diff --git a/pkg/services/ssosettings/database/database.go b/pkg/services/ssosettings/database/database.go index 2923d2900e0..4d1da5ed6a4 100644 --- a/pkg/services/ssosettings/database/database.go +++ b/pkg/services/ssosettings/database/database.go @@ -31,8 +31,8 @@ func ProvideStore(sqlStore db.DB) *SSOSettingsStore { var _ ssosettings.Store = (*SSOSettingsStore)(nil) -func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) { - result := models.SSOSetting{Provider: provider} +func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) { + result := models.SSOSettingsDTO{Provider: provider} err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { var err error sess.Table("sso_setting") @@ -53,14 +53,19 @@ func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SS return nil, err } - return &result, nil + dto, err := result.ToSSOSettings() + if err != nil { + return nil, err + } + + return dto, nil } -func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSetting, error) { - result := make([]*models.SSOSetting, 0) +func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSettings, error) { + dtos := make([]*models.SSOSettingsDTO, 0) err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { sess.Table("sso_setting") - err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Find(&result) + err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Find(&dtos) if err != nil { return err @@ -73,13 +78,29 @@ func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSetting, erro return nil, err } - return result, nil + settings := make([]*models.SSOSettings, 0) + for _, dto := range dtos { + item, err := dto.ToSSOSettings() + if err != nil { + s.log.Warn("Failed to convert DB settings to SSOSettings for provider " + dto.Provider) + continue + } + + settings = append(settings, item) + } + + return settings, nil } -func (s *SSOSettingsStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error { +func (s *SSOSettingsStore) Upsert(ctx context.Context, settings models.SSOSettings) error { + dto, err := settings.ToSSOSettingsDTO() + if err != nil { + return err + } + return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { - existing := &models.SSOSetting{ - Provider: provider, + existing := &models.SSOSettingsDTO{ + Provider: dto.Provider, IsDeleted: false, } found, err := sess.UseBool("is_deleted").Exist(existing) @@ -90,17 +111,17 @@ func (s *SSOSettingsStore) Upsert(ctx context.Context, provider string, data map now := timeNow().UTC() if found { - updated := &models.SSOSetting{ - Settings: data, + updated := &models.SSOSettingsDTO{ + Settings: dto.Settings, Updated: now, IsDeleted: false, } _, err = sess.UseBool("is_deleted").Update(updated, existing) } else { - _, err = sess.Insert(&models.SSOSetting{ + _, err = sess.Insert(&models.SSOSettingsDTO{ ID: uuid.New().String(), - Provider: provider, - Settings: data, + Provider: dto.Provider, + Settings: dto.Settings, Created: now, Updated: now, }) @@ -116,7 +137,7 @@ func (s *SSOSettingsStore) Patch(ctx context.Context, provider string, data map[ func (s *SSOSettingsStore) Delete(ctx context.Context, provider string) error { return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { - existing := &models.SSOSetting{ + existing := &models.SSOSettingsDTO{ Provider: provider, IsDeleted: false, } diff --git a/pkg/services/ssosettings/database/database_test.go b/pkg/services/ssosettings/database/database_test.go index eb5831b399e..ee8902889a7 100644 --- a/pkg/services/ssosettings/database/database_test.go +++ b/pkg/services/ssosettings/database/database_test.go @@ -7,9 +7,9 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" "github.com/grafana/grafana/pkg/infra/db" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/ssosettings" "github.com/grafana/grafana/pkg/services/ssosettings/models" @@ -27,24 +27,27 @@ func TestIntegrationGetSSOSettings(t *testing.T) { sqlStore = db.InitTestDB(t) ssoSettingsStore = ProvideStore(sqlStore) - err := insertSSOSetting(ssoSettingsStore, "azuread", nil) + template := models.SSOSettings{ + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + }, + } + err := populateSSOSettings(sqlStore, template, "azuread") require.NoError(t, err) } t.Run("returns existing SSO settings", func(t *testing.T) { setup() - expected := &models.SSOSetting{ - Provider: "azuread", - Settings: map[string]interface{}{ - "enabled": true, - }, + expected := &models.SSOSettings{ + Provider: "azuread", + OAuthSettings: &social.OAuthInfo{Enabled: true}, } actual, err := ssoSettingsStore.Get(context.Background(), "azuread") require.NoError(t, err) - require.True(t, maps.Equal(expected.Settings, actual.Settings)) + require.Equal(t, expected.OAuthSettings, actual.OAuthSettings) }) t.Run("returns not found if the SSO setting is missing for the specified provider", func(t *testing.T) { @@ -83,18 +86,21 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { mockTimeNow(time.Now()) defer resetTimeNow() - provider := "azuread" - settings := map[string]interface{}{ - "enabled": true, - "client_id": "azuread-client", + settings := models.SSOSettings{ + Provider: "azuread", + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + ClientId: "azuread-client", + }, } - err := ssoSettingsStore.Upsert(context.Background(), provider, settings) + err := ssoSettingsStore.Upsert(context.Background(), settings) require.NoError(t, err) - actual, err := getSSOSettingsByProvider(sqlStore, provider, false) + actual, err := getSSOSettingsByProvider(sqlStore, settings.Provider, false) require.NoError(t, err) - require.Equal(t, settings, actual.Settings) + require.EqualValues(t, settings.OAuthSettings, actual.OAuthSettings) + require.NotEmpty(t, actual.ID) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Created)) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated)) @@ -111,25 +117,30 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { defer resetTimeNow() provider := "github" - settings := map[string]interface{}{ - "enabled": true, - "client_id": "github-client", - "client_secret": "this-is-a-secret", + template := models.SSOSettings{ + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + ClientId: "github-client", + ClientSecret: "this-is-a-secret", + }, } - err := populateSSOSettings(sqlStore, settings, false, provider) + err := populateSSOSettings(sqlStore, template, provider) require.NoError(t, err) - newSettings := map[string]interface{}{ - "enabled": true, - "client_id": "new-github-client", - "client_secret": "this-is-a-new-secret", + newSettings := models.SSOSettings{ + Provider: provider, + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + ClientId: "new-github-client", + ClientSecret: "this-is-a-new-secret", + }, } - err = ssoSettingsStore.Upsert(context.Background(), provider, newSettings) + err = ssoSettingsStore.Upsert(context.Background(), newSettings) require.NoError(t, err) actual, err := getSSOSettingsByProvider(sqlStore, provider, false) require.NoError(t, err) - require.Equal(t, newSettings, actual.Settings) + require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated)) deleted, notDeleted, err := getSSOSettingsCountByDeleted(sqlStore) @@ -145,32 +156,38 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { defer resetTimeNow() provider := "azuread" - settings := map[string]interface{}{ - "enabled": true, - "client_id": "azuread-client", - "client_secret": "this-is-a-secret", + template := models.SSOSettings{ + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + ClientId: "azuread-client", + ClientSecret: "this-is-a-secret", + }, + IsDeleted: true, } - err := populateSSOSettings(sqlStore, settings, true, provider) + err := populateSSOSettings(sqlStore, template, provider) require.NoError(t, err) - newSettings := map[string]interface{}{ - "enabled": true, - "client_id": "new-azuread-client", - "client_secret": "this-is-a-new-secret", + newSettings := models.SSOSettings{ + Provider: provider, + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + ClientId: "new-azuread-client", + ClientSecret: "this-is-a-new-secret", + }, } - err = ssoSettingsStore.Upsert(context.Background(), provider, newSettings) + err = ssoSettingsStore.Upsert(context.Background(), newSettings) require.NoError(t, err) actual, err := getSSOSettingsByProvider(sqlStore, provider, false) require.NoError(t, err) - require.Equal(t, newSettings, actual.Settings) + require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Created)) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated)) old, err := getSSOSettingsByProvider(sqlStore, provider, true) require.NoError(t, err) - require.Equal(t, settings, old.Settings) + require.Equal(t, template.OAuthSettings, old.OAuthSettings) }) t.Run("replaces the settings only for the specified provider leaving the other provider's settings unchanged", func(t *testing.T) { @@ -180,31 +197,36 @@ func TestIntegrationUpsertSSOSettings(t *testing.T) { defer resetTimeNow() providers := []string{"github", "gitlab", "google"} - settings := map[string]interface{}{ - "enabled": true, - "client_id": "my-client", - "client_secret": "this-is-a-secret", + template := models.SSOSettings{ + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + ClientId: "my-client", + ClientSecret: "this-is-a-secret", + }, } - err := populateSSOSettings(sqlStore, settings, false, providers...) + err := populateSSOSettings(sqlStore, template, providers...) require.NoError(t, err) - newSettings := map[string]interface{}{ - "enabled": true, - "client_id": "my-new-client", - "client_secret": "this-is-a-new-secret", + newSettings := models.SSOSettings{ + Provider: providers[0], + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + ClientId: "my-new-client", + ClientSecret: "this-is-my-new-secret", + }, } - err = ssoSettingsStore.Upsert(context.Background(), providers[0], newSettings) + err = ssoSettingsStore.Upsert(context.Background(), newSettings) require.NoError(t, err) actual, err := getSSOSettingsByProvider(sqlStore, providers[0], false) require.NoError(t, err) - require.Equal(t, newSettings, actual.Settings) + require.Equal(t, newSettings.OAuthSettings, actual.OAuthSettings) require.Equal(t, formatTime(timeNow().UTC()), formatTime(actual.Updated)) for index := 1; index < len(providers); index++ { existing, err := getSSOSettingsByProvider(sqlStore, providers[index], false) require.NoError(t, err) - require.Equal(t, settings, existing.Settings) + require.EqualValues(t, template.OAuthSettings, existing.OAuthSettings) } }) } @@ -221,14 +243,20 @@ func TestIntegrationListSSOSettings(t *testing.T) { sqlStore = db.InitTestDB(t) ssoSettingsStore = ProvideStore(sqlStore) - err := insertSSOSetting(ssoSettingsStore, "azuread", map[string]interface{}{ - "enabled": true, - }) + template := models.SSOSettings{ + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + }, + } + err := populateSSOSettings(sqlStore, template, "azuread") require.NoError(t, err) - err = insertSSOSetting(ssoSettingsStore, "okta", map[string]interface{}{ - "enabled": false, - }) + template = models.SSOSettings{ + OAuthSettings: &social.OAuthInfo{ + Enabled: false, + }, + } + err = populateSSOSettings(sqlStore, template, "okta") require.NoError(t, err) } @@ -259,8 +287,12 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { setup() providers := []string{"azuread", "github", "google"} - - err := populateSSOSettings(sqlStore, nil, false, providers...) + template := models.SSOSettings{ + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + }, + } + err := populateSSOSettings(sqlStore, template, providers...) require.NoError(t, err) err = ssoSettingsStore.Delete(context.Background(), providers[0]) @@ -277,8 +309,12 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { providers := []string{"github", "google", "okta"} invalidProvider := "azuread" - - err := populateSSOSettings(sqlStore, nil, false, providers...) + template := models.SSOSettings{ + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + }, + } + err := populateSSOSettings(sqlStore, template, providers...) require.NoError(t, err) err = ssoSettingsStore.Delete(context.Background(), invalidProvider) @@ -295,8 +331,13 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { setup() providers := []string{"azuread", "github", "google"} - - err := populateSSOSettings(sqlStore, nil, true, providers...) + template := models.SSOSettings{ + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + }, + IsDeleted: true, + } + err := populateSSOSettings(sqlStore, template, providers...) require.NoError(t, err) err = ssoSettingsStore.Delete(context.Background(), providers[0]) @@ -313,11 +354,15 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { setup() provider := "azuread" - + template := models.SSOSettings{ + OAuthSettings: &social.OAuthInfo{ + Enabled: true, + }, + } // insert sso for the same provider 2 times in the database - err := populateSSOSettings(sqlStore, nil, false, provider) + err := populateSSOSettings(sqlStore, template, provider) require.NoError(t, err) - err = populateSSOSettings(sqlStore, nil, false, provider) + err = populateSSOSettings(sqlStore, template, provider) require.NoError(t, err) err = ssoSettingsStore.Delete(context.Background(), provider) @@ -330,25 +375,19 @@ func TestIntegrationDeleteSSOSettings(t *testing.T) { }) } -func insertSSOSetting(ssoSettingsStore ssosettings.Store, provider string, settings map[string]interface{}) error { - if settings == nil { - settings = map[string]interface{}{ - "enabled": true, - } - } - return ssoSettingsStore.Upsert(context.Background(), provider, settings) -} - -func populateSSOSettings(sqlStore *sqlstore.SQLStore, settings map[string]interface{}, deleted bool, providers ...string) error { +func populateSSOSettings(sqlStore *sqlstore.SQLStore, template models.SSOSettings, providers ...string) error { return sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error { for _, provider := range providers { - _, err := sess.Insert(&models.SSOSetting{ - ID: uuid.New().String(), - Provider: provider, - Settings: settings, - Created: timeNow().UTC(), - IsDeleted: deleted, - }) + template.Provider = provider + template.ID = uuid.New().String() + template.Created = timeNow().UTC() + + dto, err := template.ToSSOSettingsDTO() + if err != nil { + return err + } + + _, err = sess.Insert(dto) if err != nil { return err } @@ -370,8 +409,8 @@ func getSSOSettingsCountByDeleted(sqlStore *sqlstore.SQLStore) (deleted, notDele return } -func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, deleted bool) (*models.SSOSetting, error) { - var model models.SSOSetting +func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, deleted bool) (*models.SSOSettings, error) { + var model models.SSOSettingsDTO var err error err = sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error { @@ -379,7 +418,16 @@ func getSSOSettingsByProvider(sqlStore *sqlstore.SQLStore, provider string, dele return err }) - return &model, err + if err != nil { + return nil, err + } + + settings, err := model.ToSSOSettings() + if err != nil { + return nil, err + } + + return settings, err } func mockTimeNow(timeSeed time.Time) { diff --git a/pkg/services/ssosettings/models/models.go b/pkg/services/ssosettings/models/models.go index ec58da555ef..e5abefd188a 100644 --- a/pkg/services/ssosettings/models/models.go +++ b/pkg/services/ssosettings/models/models.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - "github.com/grafana/grafana/pkg/services/featuremgmt/strcase" + "github.com/grafana/grafana/pkg/login/social" ) type SettingsSource int @@ -26,8 +26,18 @@ func (s SettingsSource) MarshalJSON() ([]byte, error) { } } -type SSOSetting struct { - ID string `xorm:"id pk" json:"-"` +type SSOSettings struct { + ID string + Provider string + OAuthSettings *social.OAuthInfo + Created time.Time + Updated time.Time + IsDeleted bool + Source SettingsSource +} + +type SSOSettingsDTO struct { + ID string `xorm:"id pk" json:"id"` Provider string `xorm:"provider" json:"provider"` Settings map[string]interface{} `xorm:"settings" json:"settings"` Created time.Time `xorm:"created" json:"-"` @@ -37,51 +47,51 @@ type SSOSetting struct { } // TableName returns the table name (needed for Xorm) -func (s SSOSetting) TableName() string { +func (s SSOSettingsDTO) TableName() string { return "sso_setting" } -// MarshalJSON implements the json.Marshaler interface and converts the s.Settings from map[string]any to map[string]any in camelCase -func (s SSOSetting) MarshalJSON() ([]byte, error) { - type Alias SSOSetting - aux := &struct { - *Alias - }{ - Alias: (*Alias)(&s), +func (s SSOSettingsDTO) ToSSOSettings() (*SSOSettings, error) { + settingsEncoded, err := json.Marshal(s.Settings) + if err != nil { + return nil, err } - settings := make(map[string]any) - for k, v := range aux.Settings { - settings[strcase.ToLowerCamel(k)] = v + var settings social.OAuthInfo + err = json.Unmarshal(settingsEncoded, &settings) + if err != nil { + return nil, err } - aux.Settings = settings - return json.Marshal(aux) + return &SSOSettings{ + ID: s.ID, + Provider: s.Provider, + OAuthSettings: &settings, + Created: s.Created, + Updated: s.Updated, + IsDeleted: s.IsDeleted, + }, nil } -// UnmarshalJSON implements the json.Unmarshaler interface and converts the settings from map[string]any camelCase to map[string]interface{} snake_case -func (s *SSOSetting) UnmarshalJSON(data []byte) error { - type Alias SSOSetting - aux := &struct { - *Alias - }{ - Alias: (*Alias)(s), +func (s SSOSettings) ToSSOSettingsDTO() (*SSOSettingsDTO, error) { + settingsEncoded, err := json.Marshal(s.OAuthSettings) + if err != nil { + return nil, err } - if err := json.Unmarshal(data, &aux); err != nil { - return err + var settings map[string]interface{} + err = json.Unmarshal(settingsEncoded, &settings) + if err != nil { + return nil, err } - settings := make(map[string]any) - for k, v := range aux.Settings { - settings[strcase.ToSnake(k)] = v - } - - s.Settings = settings - return nil -} - -type SSOSettingsResponse struct { - Settings map[string]interface{} `json:"settings"` - Provider string `json:"type"` + return &SSOSettingsDTO{ + ID: s.ID, + Provider: s.Provider, + Settings: settings, + Created: s.Created, + Updated: s.Updated, + IsDeleted: s.IsDeleted, + Source: s.Source, + }, nil } diff --git a/pkg/services/ssosettings/ssosettings.go b/pkg/services/ssosettings/ssosettings.go index 1b6994fc613..5130f852184 100644 --- a/pkg/services/ssosettings/ssosettings.go +++ b/pkg/services/ssosettings/ssosettings.go @@ -20,11 +20,11 @@ var ( //go:generate mockery --name Service --structname MockService --outpkg ssosettingstests --filename service_mock.go --output ./ssosettingstests/ type Service interface { // List returns all SSO settings from DB and config files - List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error) + List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error) // GetForProvider returns the SSO settings for a given provider (DB or config file) - GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error) + GetForProvider(ctx context.Context, provider string) (*models.SSOSettings, error) // Upsert creates or updates the SSO settings for a given provider - Upsert(ctx context.Context, provider string, data map[string]interface{}) error + Upsert(ctx context.Context, settings models.SSOSettings) error // Delete deletes the SSO settings for a given provider (soft delete) Delete(ctx context.Context, provider string) error // Patch updates the specified SSO settings (key-value pairs) for a given provider @@ -52,9 +52,9 @@ type FallbackStrategy interface { // //go:generate mockery --name Store --structname MockStore --outpkg ssosettingstests --filename store_mock.go --output ./ssosettingstests/ type Store interface { - Get(ctx context.Context, provider string) (*models.SSOSetting, error) - List(ctx context.Context) ([]*models.SSOSetting, error) - Upsert(ctx context.Context, provider string, data map[string]interface{}) error + Get(ctx context.Context, provider string) (*models.SSOSettings, error) + List(ctx context.Context) ([]*models.SSOSettings, error) + Upsert(ctx context.Context, settings models.SSOSettings) error Patch(ctx context.Context, provider string, data map[string]interface{}) error Delete(ctx context.Context, provider string) error } diff --git a/pkg/services/ssosettings/ssosettingsimpl/service.go b/pkg/services/ssosettings/ssosettingsimpl/service.go index b3d0d284f04..97cf4897484 100644 --- a/pkg/services/ssosettings/ssosettingsimpl/service.go +++ b/pkg/services/ssosettings/ssosettingsimpl/service.go @@ -7,6 +7,7 @@ import ( "github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/login/social" ac "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/featuremgmt" @@ -55,29 +56,29 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl, var _ ssosettings.Service = (*SSOSettingsService)(nil) -func (s *SSOSettingsService) GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error) { - dto, err := s.store.Get(ctx, provider) +func (s *SSOSettingsService) GetForProvider(ctx context.Context, provider string) (*models.SSOSettings, error) { + storeSettings, err := s.store.Get(ctx, provider) if errors.Is(err, ssosettings.ErrNotFound) { - setting, err := s.loadSettingsUsingFallbackStrategy(ctx, provider) + settings, err := s.loadSettingsUsingFallbackStrategy(ctx, provider) if err != nil { return nil, err } - return setting, nil + return settings, nil } if err != nil { return nil, err } - dto.Source = models.DB + storeSettings.Source = models.DB - return dto, nil + return storeSettings, nil } -func (s *SSOSettingsService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error) { - result := make([]*models.SSOSetting, 0, len(ssosettings.AllOAuthProviders)) +func (s *SSOSettingsService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error) { + result := make([]*models.SSOSettings, 0, len(ssosettings.AllOAuthProviders)) storedSettings, err := s.store.List(ctx) if err != nil { @@ -98,12 +99,12 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques settings := getSettingsByProvider(provider, storedSettings) if len(settings) == 0 { // If there is no data in the DB then we need to load the settings using the fallback strategy - setting, err := s.loadSettingsUsingFallbackStrategy(ctx, provider) + fallbackSettings, err := s.loadSettingsUsingFallbackStrategy(ctx, provider) if err != nil { return nil, err } - settings = append(settings, setting) + settings = append(settings, fallbackSettings) } result = append(result, settings...) } @@ -111,9 +112,9 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques return result, nil } -func (s *SSOSettingsService) Upsert(ctx context.Context, provider string, data map[string]interface{}) error { +func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error { // TODO: validation (configurable provider? Contains the required fields? etc) - err := s.store.Upsert(ctx, provider, data) + err := s.store.Upsert(ctx, settings) if err != nil { return err } @@ -140,7 +141,7 @@ func (s *SSOSettingsService) RegisterFallbackStrategy(providerRegex string, stra s.fbStrategies = append(s.fbStrategies, strategy) } -func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Context, provider string) (*models.SSOSetting, error) { +func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Context, provider string) (*models.SSOSettings, error) { loadStrategy, ok := s.getFallBackstrategyFor(provider) if !ok { return nil, errors.New("no fallback strategy found for provider: " + provider) @@ -151,18 +152,23 @@ func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Conte return nil, err } - return &models.SSOSetting{ - Provider: provider, - Source: models.System, - Settings: settingsFromSystem, + oAuthInfo, err := social.CreateOAuthInfoFromKeyValues(settingsFromSystem) + if err != nil { + return nil, err + } + + return &models.SSOSettings{ + Provider: provider, + Source: models.System, + OAuthSettings: oAuthInfo, }, nil } -func getSettingsByProvider(provider string, settings []*models.SSOSetting) []*models.SSOSetting { - result := make([]*models.SSOSetting, 0) - for _, setting := range settings { - if setting.Provider == provider { - result = append(result, setting) +func getSettingsByProvider(provider string, settings []*models.SSOSettings) []*models.SSOSettings { + result := make([]*models.SSOSettings, 0) + for _, item := range settings { + if item.Provider == provider { + result = append(result, item) } } return result diff --git a/pkg/services/ssosettings/ssosettingsimpl/service_test.go b/pkg/services/ssosettings/ssosettingsimpl/service_test.go index a25cba5b3fe..1915625b6d4 100644 --- a/pkg/services/ssosettings/ssosettingsimpl/service_test.go +++ b/pkg/services/ssosettings/ssosettingsimpl/service_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" "github.com/grafana/grafana/pkg/services/auth/identity" @@ -22,25 +23,21 @@ func TestSSOSettingsService_GetForProvider(t *testing.T) { testCases := []struct { name string setup func(env testEnv) - want *models.SSOSetting + want *models.SSOSettings wantErr bool }{ { name: "should return successfully", setup: func(env testEnv) { - env.store.ExpectedSSOSetting = &models.SSOSetting{ - Provider: "github", - Settings: map[string]interface{}{ - "enabled": true, - }, - Source: models.DB, + env.store.ExpectedSSOSetting = &models.SSOSettings{ + Provider: "github", + OAuthSettings: &social.OAuthInfo{Enabled: true}, + Source: models.DB, } }, - want: &models.SSOSetting{ - Provider: "github", - Settings: map[string]interface{}{ - "enabled": true, - }, + want: &models.SSOSettings{ + Provider: "github", + OAuthSettings: &social.OAuthInfo{Enabled: true}, }, wantErr: false, }, @@ -59,12 +56,10 @@ func TestSSOSettingsService_GetForProvider(t *testing.T) { "enabled": true, } }, - want: &models.SSOSetting{ - Provider: "github", - Settings: map[string]interface{}{ - "enabled": true, - }, - Source: models.System, + want: &models.SSOSettings{ + Provider: "github", + OAuthSettings: &social.OAuthInfo{Enabled: true}, + Source: models.System, }, wantErr: false, }, @@ -136,26 +131,22 @@ func TestSSOSettingsService_List(t *testing.T) { name string setup func(env testEnv) identity identity.Requester - want []*models.SSOSetting + want []*models.SSOSettings wantErr bool }{ { name: "should return successfully", setup: func(env testEnv) { - env.store.ExpectedSSOSettings = []*models.SSOSetting{ + env.store.ExpectedSSOSettings = []*models.SSOSettings{ { - Provider: "github", - Settings: map[string]interface{}{ - "enabled": true, - }, - Source: models.DB, + Provider: "github", + OAuthSettings: &social.OAuthInfo{Enabled: true}, + Source: models.DB, }, { - Provider: "okta", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.DB, + Provider: "okta", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.DB, }, } env.fallbackStrategy.ExpectedIsMatch = true @@ -164,55 +155,41 @@ func TestSSOSettingsService_List(t *testing.T) { } }, identity: defaultIdentity, - want: []*models.SSOSetting{ + want: []*models.SSOSettings{ { - Provider: "github", - Settings: map[string]interface{}{ - "enabled": true, - }, - Source: models.DB, + Provider: "github", + OAuthSettings: &social.OAuthInfo{Enabled: true}, + Source: models.DB, }, { - Provider: "okta", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.DB, + Provider: "okta", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.DB, }, { - Provider: "gitlab", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "gitlab", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, { - Provider: "generic_oauth", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "generic_oauth", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, { - Provider: "google", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "google", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, { - Provider: "azuread", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "azuread", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, { - Provider: "grafana_com", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "grafana_com", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, }, wantErr: false, @@ -220,20 +197,16 @@ func TestSSOSettingsService_List(t *testing.T) { { name: "should return the settings that the user has access to", setup: func(env testEnv) { - env.store.ExpectedSSOSettings = []*models.SSOSetting{ + env.store.ExpectedSSOSettings = []*models.SSOSettings{ { - Provider: "github", - Settings: map[string]interface{}{ - "enabled": true, - }, - Source: models.DB, + Provider: "github", + OAuthSettings: &social.OAuthInfo{Enabled: true}, + Source: models.DB, }, { - Provider: "okta", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.DB, + Provider: "okta", + OAuthSettings: &social.OAuthInfo{Enabled: true}, + Source: models.DB, }, } env.fallbackStrategy.ExpectedIsMatch = true @@ -242,20 +215,16 @@ func TestSSOSettingsService_List(t *testing.T) { } }, identity: scopedIdentity, - want: []*models.SSOSetting{ + want: []*models.SSOSettings{ { - Provider: "github", - Settings: map[string]interface{}{ - "enabled": true, - }, - Source: models.DB, + Provider: "github", + OAuthSettings: &social.OAuthInfo{Enabled: true}, + Source: models.DB, }, { - Provider: "azuread", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "azuread", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, }, wantErr: false, @@ -270,62 +239,48 @@ func TestSSOSettingsService_List(t *testing.T) { { name: "should use the fallback strategy if store returns empty list", setup: func(env testEnv) { - env.store.ExpectedSSOSettings = []*models.SSOSetting{} + env.store.ExpectedSSOSettings = []*models.SSOSettings{} env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfig = map[string]interface{}{ "enabled": false, } }, identity: defaultIdentity, - want: []*models.SSOSetting{ + want: []*models.SSOSettings{ { - Provider: "github", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "github", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, { - Provider: "okta", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "okta", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, { - Provider: "gitlab", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "gitlab", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, { - Provider: "generic_oauth", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "generic_oauth", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, { - Provider: "google", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "google", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, { - Provider: "azuread", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "azuread", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, { - Provider: "grafana_com", - Settings: map[string]interface{}{ - "enabled": false, - }, - Source: models.System, + Provider: "grafana_com", + OAuthSettings: &social.OAuthInfo{Enabled: false}, + Source: models.System, }, }, wantErr: false, @@ -333,7 +288,7 @@ func TestSSOSettingsService_List(t *testing.T) { { name: "should return error if any of the fallback strategies was not found", setup: func(env testEnv) { - env.store.ExpectedSSOSettings = []*models.SSOSetting{} + env.store.ExpectedSSOSettings = []*models.SSOSettings{} env.fallbackStrategy.ExpectedIsMatch = false }, identity: defaultIdentity, diff --git a/pkg/services/ssosettings/ssosettingstests/service_mock.go b/pkg/services/ssosettings/ssosettingstests/service_mock.go index ba4f0328ce4..6cb3edbf90a 100644 --- a/pkg/services/ssosettings/ssosettingstests/service_mock.go +++ b/pkg/services/ssosettings/ssosettingstests/service_mock.go @@ -33,19 +33,19 @@ func (_m *MockService) Delete(ctx context.Context, provider string) error { } // GetForProvider provides a mock function with given fields: ctx, provider -func (_m *MockService) GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error) { +func (_m *MockService) GetForProvider(ctx context.Context, provider string) (*models.SSOSettings, error) { ret := _m.Called(ctx, provider) - var r0 *models.SSOSetting + var r0 *models.SSOSettings var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSetting, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSettings, error)); ok { return rf(ctx, provider) } - if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSetting); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSettings); ok { r0 = rf(ctx, provider) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*models.SSOSetting) + r0 = ret.Get(0).(*models.SSOSettings) } } @@ -59,19 +59,19 @@ func (_m *MockService) GetForProvider(ctx context.Context, provider string) (*mo } // List provides a mock function with given fields: ctx, requester -func (_m *MockService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error) { +func (_m *MockService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSettings, error) { ret := _m.Called(ctx, requester) - var r0 []*models.SSOSetting + var r0 []*models.SSOSettings var r1 error - if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) ([]*models.SSOSetting, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) ([]*models.SSOSettings, error)); ok { return rf(ctx, requester) } - if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) []*models.SSOSetting); ok { + if rf, ok := ret.Get(0).(func(context.Context, identity.Requester) []*models.SSOSettings); ok { r0 = rf(ctx, requester) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*models.SSOSetting) + r0 = ret.Get(0).([]*models.SSOSettings) } } @@ -108,13 +108,13 @@ func (_m *MockService) Reload(ctx context.Context, provider string) { _m.Called(ctx, provider) } -// Upsert provides a mock function with given fields: ctx, provider, data -func (_m *MockService) Upsert(ctx context.Context, provider string, data map[string]interface{}) error { - ret := _m.Called(ctx, provider, data) +// Upsert provides a mock function with given fields: ctx, settings +func (_m *MockService) Upsert(ctx context.Context, settings models.SSOSettings) error { + ret := _m.Called(ctx, settings) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, map[string]interface{}) error); ok { - r0 = rf(ctx, provider, data) + if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok { + r0 = rf(ctx, settings) } else { r0 = ret.Error(0) } diff --git a/pkg/services/ssosettings/ssosettingstests/store_fake.go b/pkg/services/ssosettings/ssosettingstests/store_fake.go index 9642c35e33c..1467756a397 100644 --- a/pkg/services/ssosettings/ssosettingstests/store_fake.go +++ b/pkg/services/ssosettings/ssosettingstests/store_fake.go @@ -10,8 +10,8 @@ import ( var _ ssosettings.Store = (*FakeStore)(nil) type FakeStore struct { - ExpectedSSOSetting *models.SSOSetting - ExpectedSSOSettings []*models.SSOSetting + ExpectedSSOSetting *models.SSOSettings + ExpectedSSOSettings []*models.SSOSettings ExpectedError error } @@ -19,15 +19,15 @@ func NewFakeStore() *FakeStore { return &FakeStore{} } -func (f *FakeStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) { +func (f *FakeStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) { return f.ExpectedSSOSetting, f.ExpectedError } -func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSetting, error) { +func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSettings, error) { return f.ExpectedSSOSettings, f.ExpectedError } -func (f *FakeStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error { +func (f *FakeStore) Upsert(ctx context.Context, settings models.SSOSettings) error { return f.ExpectedError } diff --git a/pkg/services/ssosettings/ssosettingstests/store_mock.go b/pkg/services/ssosettings/ssosettingstests/store_mock.go index 9683c7b7d67..55214c18fe8 100644 --- a/pkg/services/ssosettings/ssosettingstests/store_mock.go +++ b/pkg/services/ssosettings/ssosettingstests/store_mock.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.27.1. DO NOT EDIT. +// Code generated by mockery v2.37.1. DO NOT EDIT. package ssosettingstests @@ -29,19 +29,19 @@ func (_m *MockStore) Delete(ctx context.Context, provider string) error { } // Get provides a mock function with given fields: ctx, provider -func (_m *MockStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) { +func (_m *MockStore) Get(ctx context.Context, provider string) (*models.SSOSettings, error) { ret := _m.Called(ctx, provider) - var r0 *models.SSOSetting + var r0 *models.SSOSettings var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSetting, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSettings, error)); ok { return rf(ctx, provider) } - if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSetting); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSettings); ok { r0 = rf(ctx, provider) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*models.SSOSetting) + r0 = ret.Get(0).(*models.SSOSettings) } } @@ -54,20 +54,20 @@ func (_m *MockStore) Get(ctx context.Context, provider string) (*models.SSOSetti return r0, r1 } -// GetAll provides a mock function with given fields: ctx -func (_m *MockStore) GetAll(ctx context.Context) ([]*models.SSOSetting, error) { +// List provides a mock function with given fields: ctx +func (_m *MockStore) List(ctx context.Context) ([]*models.SSOSettings, error) { ret := _m.Called(ctx) - var r0 []*models.SSOSetting + var r0 []*models.SSOSettings var r1 error - if rf, ok := ret.Get(0).(func(context.Context) ([]*models.SSOSetting, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context) ([]*models.SSOSettings, error)); ok { return rf(ctx) } - if rf, ok := ret.Get(0).(func(context.Context) []*models.SSOSetting); ok { + if rf, ok := ret.Get(0).(func(context.Context) []*models.SSOSettings); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*models.SSOSetting) + r0 = ret.Get(0).([]*models.SSOSettings) } } @@ -94,13 +94,13 @@ func (_m *MockStore) Patch(ctx context.Context, provider string, data map[string return r0 } -// Upsert provides a mock function with given fields: ctx, provider, data -func (_m *MockStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error { - ret := _m.Called(ctx, provider, data) +// Upsert provides a mock function with given fields: ctx, settings +func (_m *MockStore) Upsert(ctx context.Context, settings models.SSOSettings) error { + ret := _m.Called(ctx, settings) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, map[string]interface{}) error); ok { - r0 = rf(ctx, provider, data) + if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok { + r0 = rf(ctx, settings) } else { r0 = ret.Error(0) } @@ -108,13 +108,12 @@ func (_m *MockStore) Upsert(ctx context.Context, provider string, data map[strin return r0 } -type mockConstructorTestingTNewMockStore interface { +// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockStore(t interface { mock.TestingT Cleanup(func()) -} - -// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewMockStore(t mockConstructorTestingTNewMockStore) *MockStore { +}) *MockStore { mock := &MockStore{} mock.Mock.Test(t)