From 5285e9503be5702680acb2b52a6bda0632f4603d Mon Sep 17 00:00:00 2001 From: Misi Date: Wed, 8 Nov 2023 10:50:01 +0100 Subject: [PATCH] Auth: SSO settings foundations (#77724) * inital changes, db migration * changes * Implement basic GetAll, Delete * Add first batch of tests * Add more tests * Add service tests for GetForProvider, List * Update http_server.go + wire.go * Lint + update fixed role * Update CODEOWNERS * Change API init * Change roles, rename * Review with @kalleep * Revert a mistakenly changed part * Updates based on @dmihai 's feedback --------- Co-authored-by: Karl Persson --- .github/CODEOWNERS | 1 + .../feature-toggles/index.md | 1 + .../src/types/featureToggles.gen.ts | 1 + .../backgroundsvcs/background_services.go | 3 +- pkg/server/wire.go | 4 + pkg/services/accesscontrol/roles.go | 4 +- pkg/services/featuremgmt/registry.go | 8 + pkg/services/featuremgmt/toggles_gen.csv | 1 + pkg/services/featuremgmt/toggles_gen.go | 4 + .../sqlstore/migrations/migrations.go | 3 + .../migrations/ssosettings/migrations.go | 19 + pkg/services/ssosettings/api/api.go | 114 ++++++ pkg/services/ssosettings/api/api_test.go | 3 + pkg/services/ssosettings/database/database.go | 126 ++++++ .../ssosettings/database/database_test.go | 215 ++++++++++ pkg/services/ssosettings/errors.go | 7 + pkg/services/ssosettings/models/models.go | 45 ++ pkg/services/ssosettings/ssosettings.go | 58 +++ .../ssosettings/ssosettingsimpl/service.go | 178 ++++++++ .../ssosettingsimpl/service_test.go | 387 ++++++++++++++++++ .../fallback_strategy_fake.go | 22 + .../ssosettingstests/store_fake.go | 40 ++ .../ssosettingstests/store_mock.go | 124 ++++++ .../ssosettings/strategies/oauth_strategy.go | 71 ++++ 24 files changed, 1436 insertions(+), 3 deletions(-) create mode 100644 pkg/services/sqlstore/migrations/ssosettings/migrations.go create mode 100644 pkg/services/ssosettings/api/api.go create mode 100644 pkg/services/ssosettings/api/api_test.go create mode 100644 pkg/services/ssosettings/database/database.go create mode 100644 pkg/services/ssosettings/database/database_test.go create mode 100644 pkg/services/ssosettings/errors.go create mode 100644 pkg/services/ssosettings/models/models.go create mode 100644 pkg/services/ssosettings/ssosettings.go create mode 100644 pkg/services/ssosettings/ssosettingsimpl/service.go create mode 100644 pkg/services/ssosettings/ssosettingsimpl/service_test.go create mode 100644 pkg/services/ssosettings/ssosettingstests/fallback_strategy_fake.go create mode 100644 pkg/services/ssosettings/ssosettingstests/store_fake.go create mode 100644 pkg/services/ssosettings/ssosettingstests/store_mock.go create mode 100644 pkg/services/ssosettings/strategies/oauth_strategy.go diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 22d11ec4328..c37bd586f41 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -127,6 +127,7 @@ /pkg/services/secrets/ @grafana/backend-platform /pkg/services/shorturls/ @grafana/backend-platform /pkg/services/sqlstore/ @grafana/backend-platform +/pkg/services/ssosettings/ @grafana/identity-access-team /pkg/services/star/ @grafana/backend-platform /pkg/services/stats/ @grafana/backend-platform /pkg/services/tag/ @grafana/backend-platform diff --git a/docs/sources/setup-grafana/configure-grafana/feature-toggles/index.md b/docs/sources/setup-grafana/configure-grafana/feature-toggles/index.md index 80cd6398105..fd74f62cce2 100644 --- a/docs/sources/setup-grafana/configure-grafana/feature-toggles/index.md +++ b/docs/sources/setup-grafana/configure-grafana/feature-toggles/index.md @@ -176,3 +176,4 @@ The following toggles require explicitly setting Grafana's [app mode]({{< relref | `idForwarding` | Generate signed id token for identity that can be forwarded to plugins and external services | | `externalServiceAccounts` | Automatic service account and token setup for plugins | | `panelTitleSearchInV1` | Enable searching for dashboards using panel title in search v1 | +| `ssoSettingsApi` | Enables the SSO settings API | diff --git a/packages/grafana-data/src/types/featureToggles.gen.ts b/packages/grafana-data/src/types/featureToggles.gen.ts index e38ee6933b3..7ff2582a88a 100644 --- a/packages/grafana-data/src/types/featureToggles.gen.ts +++ b/packages/grafana-data/src/types/featureToggles.gen.ts @@ -160,4 +160,5 @@ export interface FeatureToggles { dashboardSceneForViewers?: boolean; panelFilterVariable?: boolean; pdfTables?: boolean; + ssoSettingsApi?: boolean; } diff --git a/pkg/registry/backgroundsvcs/background_services.go b/pkg/registry/backgroundsvcs/background_services.go index 8e316697aa4..dd5b693a0b9 100644 --- a/pkg/registry/backgroundsvcs/background_services.go +++ b/pkg/registry/backgroundsvcs/background_services.go @@ -36,6 +36,7 @@ import ( secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager" "github.com/grafana/grafana/pkg/services/serviceaccounts" samanager "github.com/grafana/grafana/pkg/services/serviceaccounts/manager" + "github.com/grafana/grafana/pkg/services/ssosettings" "github.com/grafana/grafana/pkg/services/store" "github.com/grafana/grafana/pkg/services/store/entity" "github.com/grafana/grafana/pkg/services/store/sanitizer" @@ -63,7 +64,7 @@ func ProvideBackgroundServiceRegistry( _ serviceaccounts.Service, _ *guardian.Provider, _ *plugindashboardsservice.DashboardUpdater, _ *sanitizer.Provider, _ *grpcserver.HealthService, _ entity.EntityStoreServer, _ *grpcserver.ReflectionService, _ *ldapapi.Service, - _ *apiregistry.Service, _ auth.IDService, _ *teamapi.TeamAPI, + _ *apiregistry.Service, _ auth.IDService, _ *teamapi.TeamAPI, _ ssosettings.Service, ) *BackgroundServiceRegistry { return NewBackgroundServiceRegistry( httpServer, diff --git a/pkg/server/wire.go b/pkg/server/wire.go index 61bebfdf71f..0822df70718 100644 --- a/pkg/server/wire.go +++ b/pkg/server/wire.go @@ -130,6 +130,8 @@ import ( "github.com/grafana/grafana/pkg/services/signingkeys" "github.com/grafana/grafana/pkg/services/signingkeys/signingkeysimpl" "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/services/ssosettings" + ssoSettingsImpl "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingsimpl" starApi "github.com/grafana/grafana/pkg/services/star/api" "github.com/grafana/grafana/pkg/services/star/starimpl" "github.com/grafana/grafana/pkg/services/stats/statsimpl" @@ -377,6 +379,8 @@ var wireBasicSet = wire.NewSet( loggermw.Provide, signingkeysimpl.ProvideEmbeddedSigningKeysService, wire.Bind(new(signingkeys.Service), new(*signingkeysimpl.Service)), + ssoSettingsImpl.ProvideService, + wire.Bind(new(ssosettings.Service), new(*ssoSettingsImpl.SSOSettingsService)), idimpl.ProvideService, wire.Bind(new(auth.IDService), new(*idimpl.Service)), grafanaapiserver.WireSet, diff --git a/pkg/services/accesscontrol/roles.go b/pkg/services/accesscontrol/roles.go index e4fc58ce243..a781fb6f8ca 100644 --- a/pkg/services/accesscontrol/roles.go +++ b/pkg/services/accesscontrol/roles.go @@ -202,11 +202,11 @@ var ( Permissions: []Permission{ { Action: ActionSettingsRead, - Scope: ScopeSettingsSAML, + Scope: ScopeSettingsAuth, }, { Action: ActionSettingsWrite, - Scope: ScopeSettingsSAML, + Scope: ScopeSettingsAuth, }, }, } diff --git a/pkg/services/featuremgmt/registry.go b/pkg/services/featuremgmt/registry.go index cf9c374a9bc..d4d59da8614 100644 --- a/pkg/services/featuremgmt/registry.go +++ b/pkg/services/featuremgmt/registry.go @@ -1032,6 +1032,14 @@ var ( FrontendOnly: false, Owner: grafanaSharingSquad, }, + { + Name: "ssoSettingsApi", + Description: "Enables the SSO settings API", + RequiresDevMode: true, + Stage: FeatureStageExperimental, + FrontendOnly: false, + Owner: identityAccessTeam, + }, } ) diff --git a/pkg/services/featuremgmt/toggles_gen.csv b/pkg/services/featuremgmt/toggles_gen.csv index aee638498e0..5e0b28750fe 100644 --- a/pkg/services/featuremgmt/toggles_gen.csv +++ b/pkg/services/featuremgmt/toggles_gen.csv @@ -141,3 +141,4 @@ extractFieldsNameDeduplication,experimental,@grafana/grafana-bi-squad,false,fals dashboardSceneForViewers,experimental,@grafana/dashboards-squad,false,false,false,true panelFilterVariable,experimental,@grafana/dashboards-squad,false,false,false,true pdfTables,privatePreview,@grafana/sharing-squad,false,false,false,false +ssoSettingsApi,experimental,@grafana/identity-access-team,true,false,false,false diff --git a/pkg/services/featuremgmt/toggles_gen.go b/pkg/services/featuremgmt/toggles_gen.go index 70f2a403d79..b6dc6a7222a 100644 --- a/pkg/services/featuremgmt/toggles_gen.go +++ b/pkg/services/featuremgmt/toggles_gen.go @@ -574,4 +574,8 @@ const ( // FlagPdfTables // Enables generating table data as PDF in reporting FlagPdfTables = "pdfTables" + + // FlagSsoSettingsApi + // Enables the SSO settings API + FlagSsoSettingsApi = "ssoSettingsApi" ) diff --git a/pkg/services/sqlstore/migrations/migrations.go b/pkg/services/sqlstore/migrations/migrations.go index e16117c5df5..7373c7c93ae 100644 --- a/pkg/services/sqlstore/migrations/migrations.go +++ b/pkg/services/sqlstore/migrations/migrations.go @@ -7,6 +7,7 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore/migrations/anonservice" "github.com/grafana/grafana/pkg/services/sqlstore/migrations/oauthserver" "github.com/grafana/grafana/pkg/services/sqlstore/migrations/signingkeys" + "github.com/grafana/grafana/pkg/services/sqlstore/migrations/ssosettings" "github.com/grafana/grafana/pkg/services/sqlstore/migrations/ualert" . "github.com/grafana/grafana/pkg/services/sqlstore/migrator" ) @@ -106,6 +107,8 @@ func (*OSSMigrations) AddMigration(mg *Migrator) { ualert.CreatedFoldersMigration(mg) dashboardFolderMigrations.AddDashboardFolderMigrations(mg) + + ssosettings.AddMigration(mg) } func addStarMigrations(mg *Migrator) { diff --git a/pkg/services/sqlstore/migrations/ssosettings/migrations.go b/pkg/services/sqlstore/migrations/ssosettings/migrations.go new file mode 100644 index 00000000000..822c441d604 --- /dev/null +++ b/pkg/services/sqlstore/migrations/ssosettings/migrations.go @@ -0,0 +1,19 @@ +package ssosettings + +import "github.com/grafana/grafana/pkg/services/sqlstore/migrator" + +func AddMigration(mg *migrator.Migrator) { + var ssoSettingV1 = migrator.Table{ + Name: "sso_setting", + Columns: []*migrator.Column{ + {Name: "id", Type: migrator.DB_NVarchar, Length: 40, IsPrimaryKey: true}, // Store uuidv4 + {Name: "provider", Type: migrator.DB_NVarchar, Length: 255, Nullable: false}, + {Name: "settings", Type: migrator.DB_Text, Nullable: false}, + {Name: "created", Type: migrator.DB_DateTime, Nullable: false}, + {Name: "updated", Type: migrator.DB_DateTime, Nullable: false}, + {Name: "is_deleted", Type: migrator.DB_Bool, Nullable: false, Default: "0"}, + }, + } + + mg.AddMigration("create sso_setting table", migrator.NewAddTableMigration(ssoSettingV1)) +} diff --git a/pkg/services/ssosettings/api/api.go b/pkg/services/ssosettings/api/api.go new file mode 100644 index 00000000000..7b16f1aef30 --- /dev/null +++ b/pkg/services/ssosettings/api/api.go @@ -0,0 +1,114 @@ +package api + +import ( + "github.com/grafana/grafana/pkg/api/response" + "github.com/grafana/grafana/pkg/api/routing" + "github.com/grafana/grafana/pkg/infra/log" + ac "github.com/grafana/grafana/pkg/services/accesscontrol" + contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" + "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings" + "github.com/grafana/grafana/pkg/services/ssosettings/models" + "github.com/grafana/grafana/pkg/web" +) + +type Api struct { + Log log.Logger + RouteRegister routing.RouteRegister + AccessControl ac.AccessControl + Features *featuremgmt.FeatureManager + SSOSettingsService ssosettings.Service +} + +func ProvideApi( + ssoSettingsSvc ssosettings.Service, + routeRegister routing.RouteRegister, + ac ac.AccessControl, +) *Api { + api := &Api{ + SSOSettingsService: ssoSettingsSvc, + RouteRegister: routeRegister, + AccessControl: ac, + Log: log.New("ssosettings.api"), + } + + return api +} + +// RegisterAPIEndpoints Registers Endpoints on Grafana Router +func (api *Api) RegisterAPIEndpoints() { + api.RouteRegister.Group("/api/v1/sso-settings", func(router routing.RouteRegister) { + auth := ac.Middleware(api.AccessControl) + + scopeKey := ac.Parameter(":key") + settingsScope := ac.Scope("settings", "auth."+scopeKey, "*") + + reqWriteAccess := auth(ac.EvalAny( + ac.EvalPermission(ac.ActionSettingsWrite, ac.ScopeSettingsAuth), + ac.EvalPermission(ac.ActionSettingsWrite, settingsScope))) + + router.Get("/", auth(ac.EvalPermission(ac.ActionSettingsRead, ac.ScopeSettingsAuth)), routing.Wrap(api.listAllProvidersSettings)) + router.Get("/:key", auth(ac.EvalPermission(ac.ActionSettingsRead, settingsScope)), routing.Wrap(api.getProviderSettings)) + router.Put("/:key", reqWriteAccess, routing.Wrap(api.updateProviderSettings)) + router.Delete("/:key", reqWriteAccess, routing.Wrap(api.removeProviderSettings)) + }) +} + +func (api *Api) listAllProvidersSettings(c *contextmodel.ReqContext) response.Response { + providers, err := api.SSOSettingsService.List(c.Req.Context(), c.SignedInUser) + if err != nil { + return response.Error(500, "Failed to get providers", err) + } + + return response.JSON(200, providers) +} + +func (api *Api) getProviderSettings(c *contextmodel.ReqContext) response.Response { + key, ok := web.Params(c.Req)[":key"] + if !ok { + return response.Error(400, "Missing key", nil) + } + + settings, err := api.SSOSettingsService.GetForProvider(c.Req.Context(), key) + if err != nil { + return response.Error(404, "The provider was not found", err) + } + + return response.JSON(200, settings) +} + +func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Response { + key, ok := web.Params(c.Req)[":key"] + if !ok { + return response.Error(400, "Missing key", nil) + } + + var newSettings models.SSOSetting + if err := web.Bind(c.Req, &newSettings); err != nil { + return response.Error(400, "Failed to parse request body", err) + } + + err := api.SSOSettingsService.Upsert(c.Req.Context(), key, newSettings.Settings) + // TODO: first check whether the error is referring to validation errors + + // other error + if err != nil { + return response.Error(500, "Failed to update provider settings", err) + } + + return response.JSON(204, nil) +} + +func (api *Api) removeProviderSettings(c *contextmodel.ReqContext) response.Response { + key, ok := web.Params(c.Req)[":key"] + if !ok { + return response.Error(400, "Missing key", nil) + } + + err := api.SSOSettingsService.Delete(c.Req.Context(), key) + if err != nil { + return response.Error(500, "Failed to delete provider settings", err) + } + + return response.JSON(204, nil) +} diff --git a/pkg/services/ssosettings/api/api_test.go b/pkg/services/ssosettings/api/api_test.go new file mode 100644 index 00000000000..cca1fa9ee63 --- /dev/null +++ b/pkg/services/ssosettings/api/api_test.go @@ -0,0 +1,3 @@ +package api + +// TODO: add tests when you implement the final version of the API endpoint diff --git a/pkg/services/ssosettings/database/database.go b/pkg/services/ssosettings/database/database.go new file mode 100644 index 00000000000..fb239b795e5 --- /dev/null +++ b/pkg/services/ssosettings/database/database.go @@ -0,0 +1,126 @@ +package database + +import ( + "context" + "time" + + "github.com/google/uuid" + "github.com/grafana/grafana/pkg/infra/db" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/ssosettings" + "github.com/grafana/grafana/pkg/services/ssosettings/models" +) + +type SSOSettingsStore struct { + sqlStore db.DB + log log.Logger +} + +func ProvideStore(sqlStore db.DB) *SSOSettingsStore { + return &SSOSettingsStore{ + sqlStore: sqlStore, + log: log.New("ssosettings.store"), + } +} + +var _ ssosettings.Store = (*SSOSettingsStore)(nil) + +func (s *SSOSettingsStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) { + result := models.SSOSetting{Provider: provider} + err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { + var err error + sess.Table("sso_setting") + found, err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Get(&result) + + if err != nil { + return err + } + + if !found { + return ssosettings.ErrNotFound + } + + return nil + }) + + if err != nil { + return nil, err + } + + return &result, nil +} + +func (s *SSOSettingsStore) List(ctx context.Context) ([]*models.SSOSetting, error) { + result := make([]*models.SSOSetting, 0) + err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { + sess.Table("sso_setting") + err := sess.Where("is_deleted = ?", s.sqlStore.GetDialect().BooleanStr(false)).Find(&result) + + if err != nil { + return err + } + + return nil + }) + + if err != nil { + return nil, err + } + + return result, nil +} + +func (s *SSOSettingsStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error { + err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { + var err error + found, err := sess.Where("provider = ? AND is_deleted = ?", provider, s.sqlStore.GetDialect().BooleanStr(false)).Exist(&models.SSOSetting{}) + + if err != nil { + return err + } + + if found { + _, err = sess.Where("provider = ? AND is_deleted = ?", provider, s.sqlStore.GetDialect().BooleanStr(false)).Update(&models.SSOSetting{ + Settings: data, + Updated: time.Now().UTC(), + }) + } else { + _, err = sess.Insert(&models.SSOSetting{ + ID: uuid.New().String(), + Provider: provider, + Settings: data, + Created: time.Now().UTC(), + Updated: time.Now().UTC(), + }) + } + + return err + }) + + return err +} + +func (s *SSOSettingsStore) Patch(ctx context.Context, provider string, data map[string]interface{}) error { + panic("not implemented") // TODO: Implement +} + +func (s *SSOSettingsStore) Delete(ctx context.Context, provider string) error { + err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { + existing := new(models.SSOSetting) + found, err := sess.Where("provider = ? AND is_deleted = ?", provider, s.sqlStore.GetDialect().BooleanStr(false)).Get(existing) + if err != nil { + return err + } + + if !found { + return nil // nothing to delete + } + + existing.Updated = time.Now().UTC() + existing.IsDeleted = true + + _, err = sess.ID(existing.ID).MustCols("updated", "is_deleted").Update(existing) + return err + }) + return err +} diff --git a/pkg/services/ssosettings/database/database_test.go b/pkg/services/ssosettings/database/database_test.go new file mode 100644 index 00000000000..6539e8351cb --- /dev/null +++ b/pkg/services/ssosettings/database/database_test.go @@ -0,0 +1,215 @@ +package database + +import ( + "context" + "testing" + + "golang.org/x/exp/maps" + + "github.com/grafana/grafana/pkg/infra/db" + "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/services/ssosettings" + "github.com/grafana/grafana/pkg/services/ssosettings/models" + "github.com/stretchr/testify/require" +) + +func TestIntegrationGetSSOSettings(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + + var sqlStore *sqlstore.SQLStore + var ssoSettingsStore *SSOSettingsStore + + setup := func() { + sqlStore = db.InitTestDB(t) + ssoSettingsStore = ProvideStore(sqlStore) + + err := insertSSOSetting(ssoSettingsStore, "azuread", nil) + require.NoError(t, err) + } + + t.Run("returns existing SSO settings", func(t *testing.T) { + setup() + + expected := &models.SSOSetting{ + Provider: "azuread", + Settings: map[string]interface{}{ + "enabled": true, + }, + } + + actual, err := ssoSettingsStore.Get(context.Background(), "azuread") + require.NoError(t, err) + + require.True(t, maps.Equal(expected.Settings, actual.Settings)) + }) + + t.Run("returns not found if the SSO setting is missing for the specified provider", func(t *testing.T) { + setup() + + _, err := ssoSettingsStore.Get(context.Background(), "okta") + require.ErrorAs(t, err, &ssosettings.ErrNotFound) + }) + + t.Run("returns not found if the SSO setting is soft deleted for the specified provider", func(t *testing.T) { + setup() + err := ssoSettingsStore.Delete(context.Background(), "azuread") + require.NoError(t, err) + + _, err = ssoSettingsStore.Get(context.Background(), "azuread") + require.ErrorAs(t, err, &ssosettings.ErrNotFound) + }) +} + +func TestIntegrationUpsertSSOSettings(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + + var sqlStore *sqlstore.SQLStore + var ssoSettingsStore *SSOSettingsStore + + setup := func() { + sqlStore = db.InitTestDB(t) + ssoSettingsStore = ProvideStore(sqlStore) + } + + t.Run("insert a new SSO setting successfully", func(t *testing.T) { + setup() + + expected := &models.SSOSetting{ + Provider: "azuread", + Settings: map[string]interface{}{ + "enabled": true, + }, + } + + err := ssoSettingsStore.Upsert(context.Background(), "azuread", map[string]interface{}{ + "enabled": true, + }) + require.NoError(t, err) + + actual, err := ssoSettingsStore.Get(context.Background(), "azuread") + require.NoError(t, err) + + require.True(t, maps.Equal(expected.Settings, actual.Settings)) + }) + + t.Run("replaces an existing SSO setting for the specified provider", func(t *testing.T) { + setup() + + err := ssoSettingsStore.Upsert(context.Background(), "azuread", map[string]interface{}{ + "enabled": true, + }) + require.NoError(t, err) + + err = ssoSettingsStore.Upsert(context.Background(), "azuread", map[string]interface{}{ + "enabled": false, + }) + require.NoError(t, err) + + actual, err := ssoSettingsStore.Get(context.Background(), "azuread") + require.NoError(t, err) + + list, err := ssoSettingsStore.List(context.Background()) + require.NoError(t, err) + + require.Equal(t, 1, len(list)) + require.Equal(t, false, actual.Settings["enabled"]) + }) +} + +func TestIntegrationListSSOSettings(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + + var sqlStore *sqlstore.SQLStore + var ssoSettingsStore *SSOSettingsStore + + setup := func() { + sqlStore = db.InitTestDB(t) + ssoSettingsStore = ProvideStore(sqlStore) + + err := insertSSOSetting(ssoSettingsStore, "azuread", map[string]interface{}{ + "enabled": true, + }) + require.NoError(t, err) + + err = insertSSOSetting(ssoSettingsStore, "okta", map[string]interface{}{ + "enabled": false, + }) + require.NoError(t, err) + } + + t.Run("returns every SSO settings successfully", func(t *testing.T) { + setup() + + list, err := ssoSettingsStore.List(context.Background()) + + require.NoError(t, err) + require.Equal(t, 2, len(list)) + }) +} + +func TestIntegrationDeleteSSOSettings(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + + var sqlStore *sqlstore.SQLStore + var ssoSettingsStore *SSOSettingsStore + + setup := func() { + sqlStore = db.InitTestDB(t) + ssoSettingsStore = ProvideStore(sqlStore) + } + + t.Run("soft deletes the settings successfully", func(t *testing.T) { + setup() + + err := insertSSOSetting(ssoSettingsStore, "azuread", map[string]interface{}{ + "enabled": true, + }) + require.NoError(t, err) + + err = ssoSettingsStore.Delete(context.Background(), "azuread") + + require.NoError(t, err) + + var count int64 + err = sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error { + count, err = sess.Table("sso_setting").Where("is_deleted = ?", sqlStore.GetDialect().BooleanStr(true)).Count() + return err + }) + require.NoError(t, err) + + require.Equal(t, int64(1), count) + }) + + t.Run("return without error if the integration was not found", func(t *testing.T) { + setup() + + err := ssoSettingsStore.Delete(context.Background(), "azuread") + require.NoError(t, err) + + var count int64 + err = sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error { + count, err = sess.Table("sso_setting").Where("is_deleted = ?", sqlStore.GetDialect().BooleanStr(true)).Count() + return err + }) + require.NoError(t, err) + + require.Equal(t, int64(0), count) + }) +} + +func insertSSOSetting(ssoSettingsStore ssosettings.Store, provider string, settings map[string]interface{}) error { + if settings == nil { + settings = map[string]interface{}{ + "enabled": true, + } + } + return ssoSettingsStore.Upsert(context.Background(), provider, settings) +} diff --git a/pkg/services/ssosettings/errors.go b/pkg/services/ssosettings/errors.go new file mode 100644 index 00000000000..3bc22582ec0 --- /dev/null +++ b/pkg/services/ssosettings/errors.go @@ -0,0 +1,7 @@ +package ssosettings + +import "errors" + +var ( + ErrNotFound = errors.New("not found") +) diff --git a/pkg/services/ssosettings/models/models.go b/pkg/services/ssosettings/models/models.go new file mode 100644 index 00000000000..89974ef327e --- /dev/null +++ b/pkg/services/ssosettings/models/models.go @@ -0,0 +1,45 @@ +package models + +import ( + "encoding/json" + "fmt" + "time" +) + +type SettingsSource int + +const ( + DB = iota + System +) + +func (s SettingsSource) MarshalJSON() ([]byte, error) { + switch s { + case DB: + return json.Marshal("database") + case System: + return json.Marshal("system") + default: + return nil, fmt.Errorf("unknown source: %d", s) + } +} + +type SSOSetting struct { + ID string `xorm:"id pk" json:"-"` + Provider string `xorm:"provider" json:"provider"` + Settings map[string]interface{} `xorm:"settings" json:"settings"` + Created time.Time `xorm:"created" json:"-"` + Updated time.Time `xorm:"updated" json:"-"` + IsDeleted bool `xorm:"is_deleted" json:"-"` + Source SettingsSource `xorm:"-" json:"source"` +} + +// TableName returns the table name (needed for Xorm) +func (s SSOSetting) TableName() string { + return "sso_setting" +} + +type SSOSettingsResponse struct { + Settings map[string]interface{} `json:"settings"` + Provider string `json:"type"` +} diff --git a/pkg/services/ssosettings/ssosettings.go b/pkg/services/ssosettings/ssosettings.go new file mode 100644 index 00000000000..f7ecc7cbae8 --- /dev/null +++ b/pkg/services/ssosettings/ssosettings.go @@ -0,0 +1,58 @@ +package ssosettings + +import ( + "context" + + "github.com/grafana/grafana/pkg/services/auth/identity" + "github.com/grafana/grafana/pkg/services/ssosettings/models" +) + +var ( + // ConfigurableOAuthProviders is a list of OAuth providers that can be configured from the API + // TODO: make it configurable + ConfigurableOAuthProviders = []string{"github", "gitlab", "google", "generic_oauth", "azuread", "okta"} + + AllOAuthProviders = []string{"github", "gitlab", "google", "generic_oauth", "grafana_com", "azuread", "okta"} +) + +// Service is a SSO settings service +type Service interface { + // List returns all SSO settings from DB and config files + List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error) + // GetForProvider returns the SSO settings for a given provider (DB or config file) + GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error) + // Upsert creates or updates the SSO settings for a given provider + Upsert(ctx context.Context, provider string, data map[string]interface{}) error + // Delete deletes the SSO settings for a given provider (soft delete) + Delete(ctx context.Context, provider string) error + // Patch updates the specified SSO settings (key-value pairs) for a given provider + Patch(ctx context.Context, provider string, data map[string]interface{}) error + // RegisterReloadable registers a reloadable provider + RegisterReloadable(ctx context.Context, provider string, reloadable Reloadable) + // Reload implements ssosettings.Reloadable interface + Reload(ctx context.Context, provider string) +} + +// Reloadable is an interface that can be implemented by a provider to allow it to be reloaded +type Reloadable interface { + Reload(ctx context.Context) error +} + +// FallbackStrategy is an interface that can be implemented to allow a provider to load settings from a different source +// than the database. This is useful for providers that are not configured in the database, but instead are configured +// using the config file and/or environment variables. Used mostly for backwards compatibility. +type FallbackStrategy interface { + IsMatch(provider string) bool + ParseConfigFromSystem(ctx context.Context) (map[string]interface{}, error) +} + +// Store is a SSO settings store +// +//go:generate mockery --name Store --structname MockStore --outpkg ssosettingstests --filename store_mock.go --output ./ssosettingstests/ +type Store interface { + Get(ctx context.Context, provider string) (*models.SSOSetting, error) + List(ctx context.Context) ([]*models.SSOSetting, error) + Upsert(ctx context.Context, provider string, data map[string]interface{}) error + Patch(ctx context.Context, provider string, data map[string]interface{}) error + Delete(ctx context.Context, provider string) error +} diff --git a/pkg/services/ssosettings/ssosettingsimpl/service.go b/pkg/services/ssosettings/ssosettingsimpl/service.go new file mode 100644 index 00000000000..20bbdfe7626 --- /dev/null +++ b/pkg/services/ssosettings/ssosettingsimpl/service.go @@ -0,0 +1,178 @@ +package ssosettingsimpl + +import ( + "context" + "errors" + + "github.com/grafana/grafana/pkg/api/routing" + "github.com/grafana/grafana/pkg/infra/db" + "github.com/grafana/grafana/pkg/infra/log" + ac "github.com/grafana/grafana/pkg/services/accesscontrol" + "github.com/grafana/grafana/pkg/services/auth/identity" + "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/ssosettings" + "github.com/grafana/grafana/pkg/services/ssosettings/api" + "github.com/grafana/grafana/pkg/services/ssosettings/database" + "github.com/grafana/grafana/pkg/services/ssosettings/models" + "github.com/grafana/grafana/pkg/services/ssosettings/strategies" + "github.com/grafana/grafana/pkg/setting" +) + +var _ ssosettings.Service = (*SSOSettingsService)(nil) + +type SSOSettingsService struct { + log log.Logger + cfg *setting.Cfg + store ssosettings.Store + ac ac.AccessControl + fbStrategies []ssosettings.FallbackStrategy +} + +func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl, + routeRegister routing.RouteRegister, features *featuremgmt.FeatureManager) *SSOSettingsService { + strategies := []ssosettings.FallbackStrategy{ + strategies.NewOAuthStrategy(cfg), + // register other strategies here, for example SAML + } + + store := database.ProvideStore(sqlStore) + + svc := &SSOSettingsService{ + log: log.New("ssosettings.service"), + cfg: cfg, + store: store, + ac: ac, + fbStrategies: strategies, + } + + if features.IsEnabled(featuremgmt.FlagSsoSettingsApi) { + ssoSettingsApi := api.ProvideApi(svc, routeRegister, ac) + ssoSettingsApi.RegisterAPIEndpoints() + } + + return svc +} + +var _ ssosettings.Service = (*SSOSettingsService)(nil) + +func (s *SSOSettingsService) GetForProvider(ctx context.Context, provider string) (*models.SSOSetting, error) { + dto, err := s.store.Get(ctx, provider) + + if errors.Is(err, ssosettings.ErrNotFound) { + setting, err := s.loadSettingsUsingFallbackStrategy(ctx, provider) + if err != nil { + return nil, err + } + + return setting, nil + } + + if err != nil { + return nil, err + } + + dto.Source = models.DB + + return dto, nil +} + +func (s *SSOSettingsService) List(ctx context.Context, requester identity.Requester) ([]*models.SSOSetting, error) { + result := make([]*models.SSOSetting, 0, len(ssosettings.AllOAuthProviders)) + storedSettings, err := s.store.List(ctx) + + if err != nil { + return nil, err + } + + for _, provider := range ssosettings.AllOAuthProviders { + ev := ac.EvalPermission(ac.ActionSettingsRead, ac.Scope("settings", "auth."+provider, "*")) + hasAccess, err := s.ac.Evaluate(ctx, requester, ev) + if err != nil { + return nil, err + } + + if !hasAccess { + continue + } + + settings := getSettingsByProvider(provider, storedSettings) + if len(settings) == 0 { + // If there is no data in the DB then we need to load the settings using the fallback strategy + setting, err := s.loadSettingsUsingFallbackStrategy(ctx, provider) + if err != nil { + return nil, err + } + + settings = append(settings, setting) + } + result = append(result, settings...) + } + + return result, nil +} + +func (s *SSOSettingsService) Upsert(ctx context.Context, provider string, data map[string]interface{}) error { + // TODO: validation (configurable provider? Contains the required fields? etc) + err := s.store.Upsert(ctx, provider, data) + if err != nil { + return err + } + return nil +} + +func (s *SSOSettingsService) Patch(ctx context.Context, provider string, data map[string]interface{}) error { + panic("not implemented") // TODO: Implement +} + +func (s *SSOSettingsService) Delete(ctx context.Context, provider string) error { + return s.store.Delete(ctx, provider) +} + +func (s *SSOSettingsService) Reload(ctx context.Context, provider string) { + panic("not implemented") // TODO: Implement +} + +func (s *SSOSettingsService) RegisterReloadable(ctx context.Context, provider string, reloadable ssosettings.Reloadable) { + panic("not implemented") // TODO: Implement +} + +func (s *SSOSettingsService) RegisterFallbackStrategy(providerRegex string, strategy ssosettings.FallbackStrategy) { + s.fbStrategies = append(s.fbStrategies, strategy) +} + +func (s *SSOSettingsService) loadSettingsUsingFallbackStrategy(ctx context.Context, provider string) (*models.SSOSetting, error) { + loadStrategy, ok := s.getFallBackstrategyFor(provider) + if !ok { + return nil, errors.New("no fallback strategy found for provider: " + provider) + } + + settingsFromSystem, err := loadStrategy.ParseConfigFromSystem(ctx) + if err != nil { + return nil, err + } + + return &models.SSOSetting{ + Provider: provider, + Source: models.System, + Settings: settingsFromSystem, + }, nil +} + +func getSettingsByProvider(provider string, settings []*models.SSOSetting) []*models.SSOSetting { + result := make([]*models.SSOSetting, 0) + for _, setting := range settings { + if setting.Provider == provider { + result = append(result, setting) + } + } + return result +} + +func (s *SSOSettingsService) getFallBackstrategyFor(provider string) (ssosettings.FallbackStrategy, bool) { + for _, strategy := range s.fbStrategies { + if strategy.IsMatch(provider) { + return strategy, true + } + } + return nil, false +} diff --git a/pkg/services/ssosettings/ssosettingsimpl/service_test.go b/pkg/services/ssosettings/ssosettingsimpl/service_test.go new file mode 100644 index 00000000000..a5bd60fe5d1 --- /dev/null +++ b/pkg/services/ssosettings/ssosettingsimpl/service_test.go @@ -0,0 +1,387 @@ +package ssosettingsimpl + +import ( + "context" + "fmt" + "testing" + + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/accesscontrol" + "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" + "github.com/grafana/grafana/pkg/services/auth/identity" + "github.com/grafana/grafana/pkg/services/ssosettings" + "github.com/grafana/grafana/pkg/services/ssosettings/models" + "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" + "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/require" +) + +func TestSSOSettingsService_GetForProvider(t *testing.T) { + testCases := []struct { + name string + setup func(env testEnv) + want *models.SSOSetting + wantErr bool + }{ + { + name: "should return successfully", + setup: func(env testEnv) { + env.store.ExpectedSSOSetting = &models.SSOSetting{ + Provider: "github", + Settings: map[string]interface{}{ + "enabled": true, + }, + Source: models.DB, + } + }, + want: &models.SSOSetting{ + Provider: "github", + Settings: map[string]interface{}{ + "enabled": true, + }, + }, + wantErr: false, + }, + { + name: "should return error if store returns an error different than not found", + setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") }, + want: nil, + wantErr: true, + }, + { + name: "should fallback to strategy if store returns not found", + setup: func(env testEnv) { + env.store.ExpectedError = ssosettings.ErrNotFound + env.fallbackStrategy.ExpectedIsMatch = true + env.fallbackStrategy.ExpectedConfig = map[string]interface{}{ + "enabled": true, + } + }, + want: &models.SSOSetting{ + Provider: "github", + Settings: map[string]interface{}{ + "enabled": true, + }, + Source: models.System, + }, + wantErr: false, + }, + { + name: "should return error if the fallback strategy was not found", + setup: func(env testEnv) { + env.store.ExpectedError = ssosettings.ErrNotFound + env.fallbackStrategy.ExpectedIsMatch = false + }, + want: nil, + wantErr: true, + }, + { + name: "should return error if fallback strategy returns error", + setup: func(env testEnv) { + env.store.ExpectedError = ssosettings.ErrNotFound + env.fallbackStrategy.ExpectedIsMatch = true + env.fallbackStrategy.ExpectedError = fmt.Errorf("error") + }, + want: nil, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + env := setupTestEnv(t) + if tc.setup != nil { + tc.setup(env) + } + + actual, err := env.service.GetForProvider(context.Background(), "github") + + if tc.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tc.want, actual) + }) + } +} + +func TestSSOSettingsService_List(t *testing.T) { + defaultIdentity := &user.SignedInUser{ + UserID: 1, + OrgID: 1, + Permissions: map[int64]map[string][]string{ + 1: { + accesscontrol.ActionSettingsRead: {accesscontrol.ScopeSettingsAll}, + }, + }, + } + + scopedIdentity := &user.SignedInUser{ + UserID: 1, + OrgID: 1, + Permissions: map[int64]map[string][]string{ + 1: { + accesscontrol.ActionSettingsRead: []string{ + accesscontrol.Scope("settings", "auth.azuread", "*"), + accesscontrol.Scope("settings", "auth.github", "*"), + }, + }, + }, + } + testCases := []struct { + name string + setup func(env testEnv) + identity identity.Requester + want []*models.SSOSetting + wantErr bool + }{ + { + name: "should return successfully", + setup: func(env testEnv) { + env.store.ExpectedSSOSettings = []*models.SSOSetting{ + { + Provider: "github", + Settings: map[string]interface{}{ + "enabled": true, + }, + Source: models.DB, + }, + { + Provider: "okta", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.DB, + }, + } + env.fallbackStrategy.ExpectedIsMatch = true + env.fallbackStrategy.ExpectedConfig = map[string]interface{}{ + "enabled": false, + } + }, + identity: defaultIdentity, + want: []*models.SSOSetting{ + { + Provider: "github", + Settings: map[string]interface{}{ + "enabled": true, + }, + Source: models.DB, + }, + { + Provider: "okta", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.DB, + }, + { + Provider: "gitlab", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + { + Provider: "generic_oauth", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + { + Provider: "google", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + { + Provider: "azuread", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + { + Provider: "grafana_com", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + }, + wantErr: false, + }, + { + name: "should return the settings that the user has access to", + setup: func(env testEnv) { + env.store.ExpectedSSOSettings = []*models.SSOSetting{ + { + Provider: "github", + Settings: map[string]interface{}{ + "enabled": true, + }, + Source: models.DB, + }, + { + Provider: "okta", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.DB, + }, + } + env.fallbackStrategy.ExpectedIsMatch = true + env.fallbackStrategy.ExpectedConfig = map[string]interface{}{ + "enabled": false, + } + }, + identity: scopedIdentity, + want: []*models.SSOSetting{ + { + Provider: "github", + Settings: map[string]interface{}{ + "enabled": true, + }, + Source: models.DB, + }, + { + Provider: "azuread", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + }, + wantErr: false, + }, + { + name: "should return error if store returns an error", + setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") }, + identity: defaultIdentity, + want: nil, + wantErr: true, + }, + { + name: "should use the fallback strategy if store returns empty list", + setup: func(env testEnv) { + env.store.ExpectedSSOSettings = []*models.SSOSetting{} + env.fallbackStrategy.ExpectedIsMatch = true + env.fallbackStrategy.ExpectedConfig = map[string]interface{}{ + "enabled": false, + } + }, + identity: defaultIdentity, + want: []*models.SSOSetting{ + { + Provider: "github", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + { + Provider: "okta", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + { + Provider: "gitlab", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + { + Provider: "generic_oauth", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + { + Provider: "google", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + { + Provider: "azuread", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + { + Provider: "grafana_com", + Settings: map[string]interface{}{ + "enabled": false, + }, + Source: models.System, + }, + }, + wantErr: false, + }, + { + name: "should return error if any of the fallback strategies was not found", + setup: func(env testEnv) { + env.store.ExpectedSSOSettings = []*models.SSOSetting{} + env.fallbackStrategy.ExpectedIsMatch = false + }, + identity: defaultIdentity, + want: nil, + wantErr: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + env := setupTestEnv(t) + if tc.setup != nil { + tc.setup(env) + } + + actual, err := env.service.List(context.Background(), tc.identity) + + if tc.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.ElementsMatch(t, tc.want, actual) + }) + } +} + +func setupTestEnv(t *testing.T) testEnv { + store := ssosettingstests.NewFakeStore() + fallbackStrategy := ssosettingstests.NewFakeFallbackStrategy() + + accessControl := acimpl.ProvideAccessControl(setting.NewCfg()) + svc := &SSOSettingsService{ + log: log.NewNopLogger(), + store: store, + ac: accessControl, + fbStrategies: []ssosettings.FallbackStrategy{fallbackStrategy}, + } + return testEnv{ + service: svc, + store: store, + ac: accessControl, + fallbackStrategy: fallbackStrategy, + } +} + +type testEnv struct { + service *SSOSettingsService + store *ssosettingstests.FakeStore + ac accesscontrol.AccessControl + fallbackStrategy *ssosettingstests.FakeFallbackStrategy +} diff --git a/pkg/services/ssosettings/ssosettingstests/fallback_strategy_fake.go b/pkg/services/ssosettings/ssosettingstests/fallback_strategy_fake.go new file mode 100644 index 00000000000..fe50a863115 --- /dev/null +++ b/pkg/services/ssosettings/ssosettingstests/fallback_strategy_fake.go @@ -0,0 +1,22 @@ +package ssosettingstests + +import context "context" + +type FakeFallbackStrategy struct { + ExpectedIsMatch bool + ExpectedConfig map[string]interface{} + + ExpectedError error +} + +func NewFakeFallbackStrategy() *FakeFallbackStrategy { + return &FakeFallbackStrategy{} +} + +func (f *FakeFallbackStrategy) IsMatch(provider string) bool { + return f.ExpectedIsMatch +} + +func (f *FakeFallbackStrategy) ParseConfigFromSystem(ctx context.Context) (map[string]interface{}, error) { + return f.ExpectedConfig, f.ExpectedError +} diff --git a/pkg/services/ssosettings/ssosettingstests/store_fake.go b/pkg/services/ssosettings/ssosettingstests/store_fake.go new file mode 100644 index 00000000000..9642c35e33c --- /dev/null +++ b/pkg/services/ssosettings/ssosettingstests/store_fake.go @@ -0,0 +1,40 @@ +package ssosettingstests + +import ( + context "context" + + "github.com/grafana/grafana/pkg/services/ssosettings" + models "github.com/grafana/grafana/pkg/services/ssosettings/models" +) + +var _ ssosettings.Store = (*FakeStore)(nil) + +type FakeStore struct { + ExpectedSSOSetting *models.SSOSetting + ExpectedSSOSettings []*models.SSOSetting + ExpectedError error +} + +func NewFakeStore() *FakeStore { + return &FakeStore{} +} + +func (f *FakeStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) { + return f.ExpectedSSOSetting, f.ExpectedError +} + +func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSetting, error) { + return f.ExpectedSSOSettings, f.ExpectedError +} + +func (f *FakeStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error { + return f.ExpectedError +} + +func (f *FakeStore) Patch(ctx context.Context, provider string, data map[string]interface{}) error { + return f.ExpectedError +} + +func (f *FakeStore) Delete(ctx context.Context, provider string) error { + return f.ExpectedError +} diff --git a/pkg/services/ssosettings/ssosettingstests/store_mock.go b/pkg/services/ssosettings/ssosettingstests/store_mock.go new file mode 100644 index 00000000000..9683c7b7d67 --- /dev/null +++ b/pkg/services/ssosettings/ssosettingstests/store_mock.go @@ -0,0 +1,124 @@ +// Code generated by mockery v2.27.1. DO NOT EDIT. + +package ssosettingstests + +import ( + context "context" + + models "github.com/grafana/grafana/pkg/services/ssosettings/models" + mock "github.com/stretchr/testify/mock" +) + +// MockStore is an autogenerated mock type for the Store type +type MockStore struct { + mock.Mock +} + +// Delete provides a mock function with given fields: ctx, provider +func (_m *MockStore) Delete(ctx context.Context, provider string) error { + ret := _m.Called(ctx, provider) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, provider) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Get provides a mock function with given fields: ctx, provider +func (_m *MockStore) Get(ctx context.Context, provider string) (*models.SSOSetting, error) { + ret := _m.Called(ctx, provider) + + var r0 *models.SSOSetting + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.SSOSetting, error)); ok { + return rf(ctx, provider) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *models.SSOSetting); ok { + r0 = rf(ctx, provider) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.SSOSetting) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, provider) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAll provides a mock function with given fields: ctx +func (_m *MockStore) GetAll(ctx context.Context) ([]*models.SSOSetting, error) { + ret := _m.Called(ctx) + + var r0 []*models.SSOSetting + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*models.SSOSetting, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*models.SSOSetting); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.SSOSetting) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Patch provides a mock function with given fields: ctx, provider, data +func (_m *MockStore) Patch(ctx context.Context, provider string, data map[string]interface{}) error { + ret := _m.Called(ctx, provider, data) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, map[string]interface{}) error); ok { + r0 = rf(ctx, provider, data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Upsert provides a mock function with given fields: ctx, provider, data +func (_m *MockStore) Upsert(ctx context.Context, provider string, data map[string]interface{}) error { + ret := _m.Called(ctx, provider, data) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, map[string]interface{}) error); ok { + r0 = rf(ctx, provider, data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type mockConstructorTestingTNewMockStore interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockStore(t mockConstructorTestingTNewMockStore) *MockStore { + mock := &MockStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/services/ssosettings/strategies/oauth_strategy.go b/pkg/services/ssosettings/strategies/oauth_strategy.go new file mode 100644 index 00000000000..114062d6e55 --- /dev/null +++ b/pkg/services/ssosettings/strategies/oauth_strategy.go @@ -0,0 +1,71 @@ +package strategies + +import ( + "context" + "regexp" + "strings" + + "github.com/grafana/grafana/pkg/services/ssosettings" + "github.com/grafana/grafana/pkg/setting" +) + +type OAuthStrategy struct { + provider string + cfg *setting.Cfg + supportedProvidersRegex *regexp.Regexp +} + +var _ ssosettings.FallbackStrategy = (*OAuthStrategy)(nil) + +func NewOAuthStrategy(cfg *setting.Cfg) *OAuthStrategy { + compiledRegex := regexp.MustCompile(`^` + strings.Join(ssosettings.AllOAuthProviders, "|") + `$`) + return &OAuthStrategy{ + cfg: cfg, + supportedProvidersRegex: compiledRegex, + } +} + +func (s *OAuthStrategy) IsMatch(provider string) bool { + return s.supportedProvidersRegex.MatchString(provider) +} + +func (s *OAuthStrategy) ParseConfigFromSystem(_ context.Context) (map[string]interface{}, error) { + section := s.cfg.SectionWithEnvOverrides("auth." + s.provider) + + result := map[string]interface{}{ + "client_id": section.Key("client_id").Value(), + "client_secret": section.Key("client_secret").Value(), + "scopes": section.Key("scopes").Value(), + "auth_url": section.Key("auth_url").Value(), + "token_url": section.Key("token_url").Value(), + "api_url": section.Key("api_url").Value(), + "teams_url": section.Key("teams_url").Value(), + "enabled": section.Key("enabled").MustBool(false), + "email_attribute_name": section.Key("email_attribute_name").Value(), + "email_attribute_path": section.Key("email_attribute_path").Value(), + "role_attribute_path": section.Key("role_attribute_path").Value(), + "role_attribute_strict": section.Key("role_attribute_strict").MustBool(false), + "groups_attribute_path": section.Key("groups_attribute_path").Value(), + "team_ids_attribute_path": section.Key("team_ids_attribute_path").Value(), + "allowed_domains": section.Key("allowed_domains").Value(), + "hosted_domain": section.Key("hosted_domain").Value(), + "allow_sign_up": section.Key("allow_sign_up").MustBool(true), + "name": section.Key("name").MustString("default name"), // TODO: change this default value + "icon": section.Key("icon").Value(), + "tls_client_cert": section.Key("tls_client_cert").Value(), + "tls_client_key": section.Key("tls_client_key").Value(), + "tls_client_ca": section.Key("tls_client_ca").Value(), + "tls_skip_verify_insecure": section.Key("tls_skip_verify_insecure").MustBool(false), + "use_pkce": section.Key("use_pkce").MustBool(true), + "use_refresh_token": section.Key("use_refresh_token").MustBool(false), + "allow_assign_grafana_admin": section.Key("allow_assign_grafana_admin").MustBool(false), + "auto_login": section.Key("auto_login").MustBool(false), + "allowed_groups": section.Key("allowed_groups").Value(), + } + + // when empty_scopes parameter exists and is true, overwrite scope with empty value + if section.Key("empty_scopes").MustBool(false) { + result["scopes"] = []string{} + } + return result, nil +}