diff --git a/pkg/api/admin_users_test.go b/pkg/api/admin_users_test.go index 2d95eee079c..7b2bdbb37c6 100644 --- a/pkg/api/admin_users_test.go +++ b/pkg/api/admin_users_test.go @@ -42,6 +42,14 @@ func (m *mockAuthInfoService) GetAuthInfo(ctx context.Context, query *models.Get return m.ExpectedError } +func (m *mockAuthInfoService) SetAuthInfo(ctx context.Context, query *models.SetAuthInfoCommand) error { + return m.ExpectedError +} + +func (m *mockAuthInfoService) UpdateAuthInfo(ctx context.Context, query *models.UpdateAuthInfoCommand) error { + return m.ExpectedError +} + func TestAdminAPIEndpoint(t *testing.T) { const role = models.ROLE_ADMIN diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index 8060092a4a6..82d1e76fbbb 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -12,7 +12,6 @@ import ( "net/url" "github.com/grafana/grafana/pkg/api/response" - "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/metrics" "github.com/grafana/grafana/pkg/login" @@ -235,7 +234,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) response.Response { } loginInfo.ExternalUser = *buildExternalUserInfo(token, userInfo, name) - loginInfo.User, err = syncUser(ctx, &loginInfo.ExternalUser, connect) + loginInfo.User, err = hs.SyncUser(ctx, &loginInfo.ExternalUser, connect) if err != nil { hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err) return nil @@ -300,8 +299,8 @@ func buildExternalUserInfo(token *oauth2.Token, userInfo *social.BasicUserInfo, return extUser } -// syncUser syncs a Grafana user profile with the corresponding OAuth profile. -func syncUser( +// SyncUser syncs a Grafana user profile with the corresponding OAuth profile. +func (hs *HTTPServer) SyncUser( ctx *models.ReqContext, extUser *models.ExternalUserInfo, connect social.SocialConnector, @@ -313,7 +312,8 @@ func syncUser( ExternalUser: extUser, SignupAllowed: connect.IsSignupAllowed(), } - if err := bus.Dispatch(ctx.Req.Context(), cmd); err != nil { + + if err := hs.Login.UpsertUser(ctx.Req.Context(), cmd); err != nil { return nil, err } diff --git a/pkg/services/login/authinfo.go b/pkg/services/login/authinfo.go index 447ca6f6afc..7ada13996e7 100644 --- a/pkg/services/login/authinfo.go +++ b/pkg/services/login/authinfo.go @@ -9,4 +9,6 @@ import ( type AuthInfoService interface { LookupAndUpdate(ctx context.Context, query *models.GetUserByAuthInfoQuery) (*models.User, error) GetAuthInfo(ctx context.Context, query *models.GetAuthInfoQuery) error + SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error + UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error } diff --git a/pkg/services/login/authinfoservice/database/database.go b/pkg/services/login/authinfoservice/database/database.go index 777f9a7f8ab..5fd0079b1fe 100644 --- a/pkg/services/login/authinfoservice/database/database.go +++ b/pkg/services/login/authinfoservice/database/database.go @@ -42,7 +42,7 @@ func (s *AuthInfoStore) registerBusHandlers() { func (s *AuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error { userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail} - err := s.bus.Dispatch(ctx, &userQuery) + err := s.sqlStore.GetUserByLogin(ctx, &userQuery) if err != nil { return err } diff --git a/pkg/services/login/authinfoservice/service.go b/pkg/services/login/authinfoservice/service.go index b72c257d434..5bcbda734f1 100644 --- a/pkg/services/login/authinfoservice/service.go +++ b/pkg/services/login/authinfoservice/service.go @@ -183,3 +183,7 @@ func (s *Implementation) GetAuthInfo(ctx context.Context, query *models.GetAuthI func (s *Implementation) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error { return s.authInfoStore.UpdateAuthInfo(ctx, cmd) } + +func (s *Implementation) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error { + return s.authInfoStore.SetAuthInfo(ctx, cmd) +} diff --git a/pkg/services/login/loginservice/loginservice.go b/pkg/services/login/loginservice/loginservice.go index 6c088feb139..e567014a410 100644 --- a/pkg/services/login/loginservice/loginservice.go +++ b/pkg/services/login/loginservice/loginservice.go @@ -28,7 +28,7 @@ func ProvideService(sqlStore *sqlstore.SQLStore, bus bus.Bus, quotaService *quot } type Implementation struct { - SQLStore *sqlstore.SQLStore + SQLStore sqlstore.Store Bus bus.Bus AuthInfoService login.AuthInfoService QuotaService *quota.QuotaService @@ -81,21 +81,21 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *models.UpsertUser AuthId: extUser.AuthId, OAuthToken: extUser.OAuthToken, } - if err := ls.Bus.Dispatch(ctx, cmd2); err != nil { + if err := ls.AuthInfoService.SetAuthInfo(ctx, cmd2); err != nil { return err } } } else { cmd.Result = user - err = updateUser(ctx, cmd.Result, extUser) + err = ls.updateUser(ctx, cmd.Result, extUser) if err != nil { return err } // Always persist the latest token at log-in if extUser.AuthModule != "" && extUser.OAuthToken != nil { - err = updateUserAuth(ctx, cmd.Result, extUser) + err = ls.updateUserAuth(ctx, cmd.Result, extUser) if err != nil { return err } @@ -103,13 +103,13 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *models.UpsertUser if extUser.AuthModule == models.AuthModuleLDAP && user.IsDisabled { // Re-enable user when it found in LDAP - if err := ls.Bus.Dispatch(ctx, &models.DisableUserCommand{UserId: cmd.Result.Id, IsDisabled: false}); err != nil { + if err := ls.SQLStore.DisableUser(ctx, &models.DisableUserCommand{UserId: cmd.Result.Id, IsDisabled: false}); err != nil { return err } } } - if err := syncOrgRoles(ctx, cmd.Result, extUser); err != nil { + if err := ls.syncOrgRoles(ctx, cmd.Result, extUser); err != nil { return err } @@ -146,7 +146,7 @@ func (ls *Implementation) createUser(extUser *models.ExternalUserInfo) (*models. return ls.CreateUser(cmd) } -func updateUser(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error { +func (ls *Implementation) updateUser(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error { // sync user info updateCmd := &models.UpdateUserCommand{ UserId: user.Id, @@ -176,10 +176,10 @@ func updateUser(ctx context.Context, user *models.User, extUser *models.External } logger.Debug("Syncing user info", "id", user.Id, "update", updateCmd) - return bus.Dispatch(ctx, updateCmd) + return ls.SQLStore.UpdateUser(ctx, updateCmd) } -func updateUserAuth(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error { +func (ls *Implementation) updateUserAuth(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error { updateCmd := &models.UpdateAuthInfoCommand{ AuthModule: extUser.AuthModule, AuthId: extUser.AuthId, @@ -188,10 +188,10 @@ func updateUserAuth(ctx context.Context, user *models.User, extUser *models.Exte } logger.Debug("Updating user_auth info", "user_id", user.Id) - return bus.Dispatch(ctx, updateCmd) + return ls.AuthInfoService.UpdateAuthInfo(ctx, updateCmd) } -func syncOrgRoles(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error { +func (ls *Implementation) syncOrgRoles(ctx context.Context, user *models.User, extUser *models.ExternalUserInfo) error { logger.Debug("Syncing organization roles", "id", user.Id, "extOrgRoles", extUser.OrgRoles) // don't sync org roles if none is specified @@ -201,7 +201,7 @@ func syncOrgRoles(ctx context.Context, user *models.User, extUser *models.Extern } orgsQuery := &models.GetUserOrgListQuery{UserId: user.Id} - if err := bus.Dispatch(ctx, orgsQuery); err != nil { + if err := ls.SQLStore.GetUserOrgList(ctx, orgsQuery); err != nil { return err } @@ -218,7 +218,7 @@ func syncOrgRoles(ctx context.Context, user *models.User, extUser *models.Extern } else if extRole != org.Role { // update role cmd := &models.UpdateOrgUserCommand{OrgId: org.OrgId, UserId: user.Id, Role: extRole} - if err := bus.Dispatch(ctx, cmd); err != nil { + if err := ls.SQLStore.UpdateOrgUser(ctx, cmd); err != nil { return err } } @@ -232,7 +232,7 @@ func syncOrgRoles(ctx context.Context, user *models.User, extUser *models.Extern // add role cmd := &models.AddOrgUserCommand{UserId: user.Id, Role: orgRole, OrgId: orgId} - err := bus.Dispatch(ctx, cmd) + err := ls.SQLStore.AddOrgUser(ctx, cmd) if err != nil && !errors.Is(err, models.ErrOrgNotFound) { return err } @@ -243,7 +243,7 @@ func syncOrgRoles(ctx context.Context, user *models.User, extUser *models.Extern logger.Debug("Removing user's organization membership as part of syncing with OAuth login", "userId", user.Id, "orgId", orgId) cmd := &models.RemoveOrgUserCommand{OrgId: orgId, UserId: user.Id} - if err := bus.Dispatch(ctx, cmd); err != nil { + if err := ls.SQLStore.RemoveOrgUser(ctx, cmd); err != nil { if errors.Is(err, models.ErrLastOrgAdmin) { logger.Error(err.Error(), "userId", cmd.UserId, "orgId", cmd.OrgId) continue @@ -260,7 +260,7 @@ func syncOrgRoles(ctx context.Context, user *models.User, extUser *models.Extern break } - return bus.Dispatch(ctx, &models.SetUsingOrgCommand{ + return ls.SQLStore.SetUsingOrg(ctx, &models.SetUsingOrgCommand{ UserId: user.Id, OrgId: user.OrgId, }) diff --git a/pkg/services/login/loginservice/loginservice_test.go b/pkg/services/login/loginservice/loginservice_test.go index 480c6376fc8..aef102c6193 100644 --- a/pkg/services/login/loginservice/loginservice_test.go +++ b/pkg/services/login/loginservice/loginservice_test.go @@ -11,6 +11,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log/level" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/quota" + "github.com/grafana/grafana/pkg/services/sqlstore/mockstore" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,29 +19,21 @@ import ( func Test_syncOrgRoles_doesNotBreakWhenTryingToRemoveLastOrgAdmin(t *testing.T) { user := createSimpleUser() externalUser := createSimpleExternalUser() - remResp := createResponseWithOneErrLastOrgAdminItem() + authInfoMock := &authInfoServiceMock{} - bus.ClearBusHandlers() - defer bus.ClearBusHandlers() - bus.AddHandler("test", func(ctx context.Context, q *models.GetUserOrgListQuery) error { - q.Result = createUserOrgDTO() + store := &mockstore.SQLStoreMock{ + ExpectedUserOrgList: createUserOrgDTO(), + ExpectedOrgListResponse: createResponseWithOneErrLastOrgAdminItem(), + } - return nil - }) + login := Implementation{ + Bus: bus.New(), + QuotaService: "a.QuotaService{}, + AuthInfoService: authInfoMock, + SQLStore: store, + } - bus.AddHandler("test", func(ctx context.Context, cmd *models.RemoveOrgUserCommand) error { - testData := remResp[0] - remResp = remResp[1:] - - require.Equal(t, testData.orgId, cmd.OrgId) - return testData.response - }) - bus.AddHandler("test", func(ctx context.Context, cmd *models.SetUsingOrgCommand) error { - return nil - }) - - err := syncOrgRoles(context.Background(), &user, &externalUser) - require.Empty(t, remResp) + err := login.syncOrgRoles(context.Background(), &user, &externalUser) require.NoError(t, err) } @@ -50,27 +43,22 @@ func Test_syncOrgRoles_whenTryingToRemoveLastOrgLogsError(t *testing.T) { user := createSimpleUser() externalUser := createSimpleExternalUser() - remResp := createResponseWithOneErrLastOrgAdminItem() - bus.ClearBusHandlers() - defer bus.ClearBusHandlers() - bus.AddHandler("test", func(ctx context.Context, q *models.GetUserOrgListQuery) error { - q.Result = createUserOrgDTO() - return nil - }) + authInfoMock := &authInfoServiceMock{} - bus.AddHandler("test", func(ctx context.Context, cmd *models.RemoveOrgUserCommand) error { - testData := remResp[0] - remResp = remResp[1:] + store := &mockstore.SQLStoreMock{ + ExpectedUserOrgList: createUserOrgDTO(), + ExpectedOrgListResponse: createResponseWithOneErrLastOrgAdminItem(), + } - require.Equal(t, testData.orgId, cmd.OrgId) - return testData.response - }) - bus.AddHandler("test", func(ctx context.Context, cmd *models.SetUsingOrgCommand) error { - return nil - }) + login := Implementation{ + Bus: bus.New(), + QuotaService: "a.QuotaService{}, + AuthInfoService: authInfoMock, + SQLStore: store, + } - err := syncOrgRoles(context.Background(), &user, &externalUser) + err := login.syncOrgRoles(context.Background(), &user, &externalUser) require.NoError(t, err) assert.Contains(t, buf.String(), models.ErrLastOrgAdmin.Error()) } @@ -88,11 +76,18 @@ func (a *authInfoServiceMock) GetAuthInfo(ctx context.Context, query *models.Get return nil } +func (a *authInfoServiceMock) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error { + return nil +} + +func (a *authInfoServiceMock) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error { + return nil +} + func Test_teamSync(t *testing.T) { - b := bus.New() authInfoMock := &authInfoServiceMock{} login := Implementation{ - Bus: b, + Bus: bus.New(), QuotaService: "a.QuotaService{}, AuthInfoService: authInfoMock, } @@ -181,21 +176,15 @@ func createSimpleExternalUser() models.ExternalUserInfo { return externalUser } -func createResponseWithOneErrLastOrgAdminItem() []struct { - orgId int64 - response error -} { - remResp := []struct { - orgId int64 - response error - }{ +func createResponseWithOneErrLastOrgAdminItem() mockstore.OrgListResponse { + remResp := mockstore.OrgListResponse{ { - orgId: 10, - response: models.ErrLastOrgAdmin, + OrgId: 10, + Response: models.ErrLastOrgAdmin, }, { - orgId: 11, - response: nil, + OrgId: 11, + Response: nil, }, } return remResp diff --git a/pkg/services/sqlstore/mockstore/mockstore.go b/pkg/services/sqlstore/mockstore/mockstore.go index e26542633d9..80a50b293ec 100644 --- a/pkg/services/sqlstore/mockstore/mockstore.go +++ b/pkg/services/sqlstore/mockstore/mockstore.go @@ -8,10 +8,13 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore" ) +type OrgListResponse []struct { + OrgId int64 + Response error +} type SQLStoreMock struct { - LastGetAlertsQuery *models.GetAlertsQuery - LatestUserId int64 - + LastGetAlertsQuery *models.GetAlertsQuery + LatestUserId int64 ExpectedUser *models.User ExpectedDatasource *models.DataSource ExpectedAlert *models.Alert @@ -20,8 +23,9 @@ type SQLStoreMock struct { ExpectedDashboards []*models.Dashboard ExpectedDashboardVersion *models.DashboardVersion ExpectedDashboardAclInfoList []*models.DashboardAclInfoDTO - - ExpectedError error + ExpectedUserOrgList []*models.UserOrgDTO + ExpectedOrgListResponse OrgListResponse + ExpectedError error } func NewSQLStoreMock() *SQLStoreMock { @@ -150,6 +154,7 @@ func (m *SQLStoreMock) GetUserProfile(ctx context.Context, query *models.GetUser } func (m *SQLStoreMock) GetUserOrgList(ctx context.Context, query *models.GetUserOrgListQuery) error { + query.Result = m.ExpectedUserOrgList return m.ExpectedError } @@ -424,7 +429,9 @@ func (m *SQLStoreMock) SearchOrgUsers(ctx context.Context, query *models.SearchO } func (m *SQLStoreMock) RemoveOrgUser(ctx context.Context, cmd *models.RemoveOrgUserCommand) error { - return m.ExpectedError + testData := m.ExpectedOrgListResponse[0] + m.ExpectedOrgListResponse = m.ExpectedOrgListResponse[1:] + return testData.Response } func (m *SQLStoreMock) SaveDashboard(cmd models.SaveDashboardCommand) (*models.Dashboard, error) {