From 92ca38bedf18d712dddddeffab061df68d8d3988 Mon Sep 17 00:00:00 2001 From: Katarina Yang <69819079+yangkb09@users.noreply.github.com> Date: Tue, 25 Jan 2022 14:30:08 -0500 Subject: [PATCH] Refactor: Change sqlstore.inTransaction to SQLStore.WithTransactionalDBSession in misc files (#43926) * Refactor: Change sqlstore.inTransaction to SQLStore.WithTransactionalDBSession in misc files * Refactor: Change .inTransaction in org.go file * Refactor: Update init() to proper SQLStore handlers * Refactor: Update funcs in tests to be sqlStore methods * Refactor: Update API funcs to receive HTTPServer * Fix: define methods on sqlstore * Adjust GetSignedInUser calls * Refactor: Add sqlStore to Service struct * Chore: Add back black spaces to remove file from PR Co-authored-by: Ida Furjesova --- pkg/api/api.go | 10 +++---- pkg/api/org.go | 29 +++++++++---------- pkg/api/org_users_test.go | 2 +- .../resourcepermissions/service.go | 2 +- pkg/services/sqlstore/alert_test.go | 2 +- pkg/services/sqlstore/dashboard.go | 6 ++-- .../sqlstore/dashboard_provisioning.go | 8 ++--- .../sqlstore/dashboard_provisioning_test.go | 4 +-- pkg/services/sqlstore/dashboard_test.go | 8 ++--- pkg/services/sqlstore/login_attempt.go | 14 ++++----- pkg/services/sqlstore/login_attempt_test.go | 17 ++++++----- pkg/services/sqlstore/org.go | 20 ++++++------- pkg/services/sqlstore/org_test.go | 10 +++---- pkg/services/sqlstore/sqlstore.go | 3 ++ pkg/services/sqlstore/team.go | 28 +++++++++--------- pkg/services/sqlstore/team_test.go | 2 +- pkg/services/sqlstore/user.go | 6 ++-- 17 files changed, 88 insertions(+), 83 deletions(-) diff --git a/pkg/api/api.go b/pkg/api/api.go index 1d566c94d45..4ae8e13bb13 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -206,8 +206,8 @@ func (hs *HTTPServer) registerRoutes() { // current org apiRoute.Group("/org", func(orgRoute routing.RouteRegister) { userIDScope := ac.Scope("users", "id", ac.Parameter(":userId")) - orgRoute.Put("/", authorize(reqOrgAdmin, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(UpdateCurrentOrg)) - orgRoute.Put("/address", authorize(reqOrgAdmin, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(UpdateCurrentOrgAddress)) + orgRoute.Put("/", authorize(reqOrgAdmin, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(hs.UpdateCurrentOrg)) + orgRoute.Put("/address", authorize(reqOrgAdmin, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(hs.UpdateCurrentOrgAddress)) orgRoute.Get("/users", authorize(reqOrgAdmin, ac.EvalPermission(ac.ActionOrgUsersRead)), routing.Wrap(hs.GetOrgUsersForCurrentOrg)) orgRoute.Get("/users/search", authorize(reqOrgAdmin, ac.EvalPermission(ac.ActionOrgUsersRead)), routing.Wrap(hs.SearchOrgUsersWithPaging)) orgRoute.Post("/users", authorize(reqOrgAdmin, ac.EvalPermission(ac.ActionOrgUsersAdd, ac.ScopeUsersAll)), quota("user"), routing.Wrap(hs.AddOrgUserToCurrentOrg)) @@ -239,9 +239,9 @@ func (hs *HTTPServer) registerRoutes() { apiRoute.Group("/orgs/:orgId", func(orgsRoute routing.RouteRegister) { userIDScope := ac.Scope("users", "id", ac.Parameter(":userId")) orgsRoute.Get("/", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsRead)), routing.Wrap(GetOrgByID)) - orgsRoute.Put("/", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(UpdateOrg)) - orgsRoute.Put("/address", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(UpdateOrgAddress)) - orgsRoute.Delete("/", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsDelete)), routing.Wrap(DeleteOrgByID)) + orgsRoute.Put("/", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(hs.UpdateOrg)) + orgsRoute.Put("/address", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(hs.UpdateOrgAddress)) + orgsRoute.Delete("/", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsDelete)), routing.Wrap(hs.DeleteOrgByID)) orgsRoute.Get("/users", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ac.ActionOrgUsersRead, ac.ScopeUsersAll)), routing.Wrap(hs.GetOrgUsers)) orgsRoute.Post("/users", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ac.ActionOrgUsersAdd, ac.ScopeUsersAll)), routing.Wrap(hs.AddOrgUser)) orgsRoute.Patch("/users/:userId", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ac.ActionOrgUsersRoleUpdate, userIDScope)), routing.Wrap(hs.UpdateOrgUser)) diff --git a/pkg/api/org.go b/pkg/api/org.go index d2f2d6e6e22..3ab227bb347 100644 --- a/pkg/api/org.go +++ b/pkg/api/org.go @@ -111,31 +111,30 @@ func (hs *HTTPServer) CreateOrg(c *models.ReqContext) response.Response { } // PUT /api/org -func UpdateCurrentOrg(c *models.ReqContext) response.Response { +func (hs *HTTPServer) UpdateCurrentOrg(c *models.ReqContext) response.Response { form := dtos.UpdateOrgForm{} if err := web.Bind(c.Req, &form); err != nil { return response.Error(http.StatusBadRequest, "bad request data", err) } - return updateOrgHelper(c.Req.Context(), form, c.OrgId) + return hs.updateOrgHelper(c.Req.Context(), form, c.OrgId) } // PUT /api/orgs/:orgId -func UpdateOrg(c *models.ReqContext) response.Response { +func (hs *HTTPServer) UpdateOrg(c *models.ReqContext) response.Response { form := dtos.UpdateOrgForm{} if err := web.Bind(c.Req, &form); err != nil { return response.Error(http.StatusBadRequest, "bad request data", err) } - orgId, err := strconv.ParseInt(web.Params(c.Req)[":orgId"], 10, 64) if err != nil { return response.Error(http.StatusBadRequest, "orgId is invalid", err) } - return updateOrgHelper(c.Req.Context(), form, orgId) + return hs.updateOrgHelper(c.Req.Context(), form, orgId) } -func updateOrgHelper(ctx context.Context, form dtos.UpdateOrgForm, orgID int64) response.Response { +func (hs *HTTPServer) updateOrgHelper(ctx context.Context, form dtos.UpdateOrgForm, orgID int64) response.Response { cmd := models.UpdateOrgCommand{Name: form.Name, OrgId: orgID} - if err := sqlstore.UpdateOrg(ctx, &cmd); err != nil { + if err := hs.SQLStore.UpdateOrg(ctx, &cmd); err != nil { if errors.Is(err, models.ErrOrgNameTaken) { return response.Error(400, "Organization name taken", err) } @@ -146,16 +145,16 @@ func updateOrgHelper(ctx context.Context, form dtos.UpdateOrgForm, orgID int64) } // PUT /api/org/address -func UpdateCurrentOrgAddress(c *models.ReqContext) response.Response { +func (hs *HTTPServer) UpdateCurrentOrgAddress(c *models.ReqContext) response.Response { form := dtos.UpdateOrgAddressForm{} if err := web.Bind(c.Req, &form); err != nil { return response.Error(http.StatusBadRequest, "bad request data", err) } - return updateOrgAddressHelper(c.Req.Context(), form, c.OrgId) + return hs.updateOrgAddressHelper(c.Req.Context(), form, c.OrgId) } // PUT /api/orgs/:orgId/address -func UpdateOrgAddress(c *models.ReqContext) response.Response { +func (hs *HTTPServer) UpdateOrgAddress(c *models.ReqContext) response.Response { form := dtos.UpdateOrgAddressForm{} if err := web.Bind(c.Req, &form); err != nil { return response.Error(http.StatusBadRequest, "bad request data", err) @@ -164,10 +163,10 @@ func UpdateOrgAddress(c *models.ReqContext) response.Response { if err != nil { return response.Error(http.StatusBadRequest, "orgId is invalid", err) } - return updateOrgAddressHelper(c.Req.Context(), form, orgId) + return hs.updateOrgAddressHelper(c.Req.Context(), form, orgId) } -func updateOrgAddressHelper(ctx context.Context, form dtos.UpdateOrgAddressForm, orgID int64) response.Response { +func (hs *HTTPServer) updateOrgAddressHelper(ctx context.Context, form dtos.UpdateOrgAddressForm, orgID int64) response.Response { cmd := models.UpdateOrgAddressCommand{ OrgId: orgID, Address: models.Address{ @@ -180,7 +179,7 @@ func updateOrgAddressHelper(ctx context.Context, form dtos.UpdateOrgAddressForm, }, } - if err := sqlstore.UpdateOrgAddress(ctx, &cmd); err != nil { + if err := hs.SQLStore.UpdateOrgAddress(ctx, &cmd); err != nil { return response.Error(500, "Failed to update org address", err) } @@ -188,7 +187,7 @@ func updateOrgAddressHelper(ctx context.Context, form dtos.UpdateOrgAddressForm, } // DELETE /api/orgs/:orgId -func DeleteOrgByID(c *models.ReqContext) response.Response { +func (hs *HTTPServer) DeleteOrgByID(c *models.ReqContext) response.Response { orgID, err := strconv.ParseInt(web.Params(c.Req)[":orgId"], 10, 64) if err != nil { return response.Error(http.StatusBadRequest, "orgId is invalid", err) @@ -198,7 +197,7 @@ func DeleteOrgByID(c *models.ReqContext) response.Response { return response.Error(400, "Can not delete org for current user", nil) } - if err := sqlstore.DeleteOrg(c.Req.Context(), &models.DeleteOrgCommand{Id: orgID}); err != nil { + if err := hs.SQLStore.DeleteOrg(c.Req.Context(), &models.DeleteOrgCommand{Id: orgID}); err != nil { if errors.Is(err, models.ErrOrgNotFound) { return response.Error(404, "Failed to delete organization. ID not found", nil) } diff --git a/pkg/api/org_users_test.go b/pkg/api/org_users_test.go index a208fec07e1..1330d8cb60d 100644 --- a/pkg/api/org_users_test.go +++ b/pkg/api/org_users_test.go @@ -690,7 +690,7 @@ func TestPatchOrgUsersAPIEndpoint_AccessControl(t *testing.T) { UserId: tc.targetUserId, OrgId: tc.targetOrg, } - err = sqlstore.GetSignedInUser(context.Background(), &getUserQuery) + err = sc.db.GetSignedInUser(context.Background(), &getUserQuery) require.NoError(t, err) assert.Equal(t, tc.expectedUserRole, getUserQuery.Result.OrgRole) } diff --git a/pkg/services/accesscontrol/resourcepermissions/service.go b/pkg/services/accesscontrol/resourcepermissions/service.go index 5ec3d53ee5d..0914ae89682 100644 --- a/pkg/services/accesscontrol/resourcepermissions/service.go +++ b/pkg/services/accesscontrol/resourcepermissions/service.go @@ -207,7 +207,7 @@ func (s *Service) validateResource(ctx context.Context, orgID int64, resourceID } func (s *Service) validateUser(ctx context.Context, orgID, userID int64) error { - if err := sqlstore.GetSignedInUser(ctx, &models.GetSignedInUserQuery{OrgId: orgID, UserId: userID}); err != nil { + if err := s.sqlStore.GetSignedInUser(ctx, &models.GetSignedInUserQuery{OrgId: orgID, UserId: userID}); err != nil { return err } return nil diff --git a/pkg/services/sqlstore/alert_test.go b/pkg/services/sqlstore/alert_test.go index 1f7d1fdb5e5..9430fa63819 100644 --- a/pkg/services/sqlstore/alert_test.go +++ b/pkg/services/sqlstore/alert_test.go @@ -245,7 +245,7 @@ func TestAlertingDataAccess(t *testing.T) { err := sqlStore.SaveAlerts(context.Background(), testDash.Id, items) require.Nil(t, err) - err = DeleteDashboard(context.Background(), &models.DeleteDashboardCommand{ + err = sqlStore.DeleteDashboard(context.Background(), &models.DeleteDashboardCommand{ OrgId: 1, Id: testDash.Id, }) diff --git a/pkg/services/sqlstore/dashboard.go b/pkg/services/sqlstore/dashboard.go index 46c362e8dab..b498482e434 100644 --- a/pkg/services/sqlstore/dashboard.go +++ b/pkg/services/sqlstore/dashboard.go @@ -29,7 +29,6 @@ var shadowSearchCounter = prometheus.NewCounterVec( func init() { bus.AddHandler("sql", GetDashboard) bus.AddHandler("sql", GetDashboards) - bus.AddHandler("sql", DeleteDashboard) bus.AddHandler("sql", GetDashboardTags) bus.AddHandler("sql", GetDashboardSlugById) bus.AddHandler("sql", GetDashboardsByPluginId) @@ -44,6 +43,7 @@ func init() { func (ss *SQLStore) addDashboardQueryAndCommandHandlers() { bus.AddHandler("sql", ss.GetDashboardUIDById) bus.AddHandler("sql", ss.SearchDashboards) + bus.AddHandler("sql", ss.DeleteDashboard) } var generateNewUid func() string = util.GenerateShortUID @@ -410,8 +410,8 @@ func GetDashboardTags(ctx context.Context, query *models.GetDashboardTagsQuery) return err } -func DeleteDashboard(ctx context.Context, cmd *models.DeleteDashboardCommand) error { - return inTransaction(func(sess *DBSession) error { +func (ss *SQLStore) DeleteDashboard(ctx context.Context, cmd *models.DeleteDashboardCommand) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { return deleteDashboard(cmd, sess) }) } diff --git a/pkg/services/sqlstore/dashboard_provisioning.go b/pkg/services/sqlstore/dashboard_provisioning.go index 048fd76476d..053f5ad74c8 100644 --- a/pkg/services/sqlstore/dashboard_provisioning.go +++ b/pkg/services/sqlstore/dashboard_provisioning.go @@ -8,9 +8,9 @@ import ( "github.com/grafana/grafana/pkg/models" ) -func init() { +func (ss *SQLStore) addDashboardProvisioningQueryAndCommandHandlers() { bus.AddHandler("sql", UnprovisionDashboard) - bus.AddHandler("sql", DeleteOrphanedProvisionedDashboards) + bus.AddHandler("sql", ss.DeleteOrphanedProvisionedDashboards) } type DashboardExtras struct { @@ -111,7 +111,7 @@ func UnprovisionDashboard(ctx context.Context, cmd *models.UnprovisionDashboardC return nil } -func DeleteOrphanedProvisionedDashboards(ctx context.Context, cmd *models.DeleteOrphanedProvisionedDashboardsCommand) error { +func (ss *SQLStore) DeleteOrphanedProvisionedDashboards(ctx context.Context, cmd *models.DeleteOrphanedProvisionedDashboardsCommand) error { var result []*models.DashboardProvisioning convertedReaderNames := make([]interface{}, len(cmd.ReaderNames)) @@ -125,7 +125,7 @@ func DeleteOrphanedProvisionedDashboards(ctx context.Context, cmd *models.Delete } for _, deleteDashCommand := range result { - err := DeleteDashboard(ctx, &models.DeleteDashboardCommand{Id: deleteDashCommand.DashboardId}) + err := ss.DeleteDashboard(ctx, &models.DeleteDashboardCommand{Id: deleteDashCommand.DashboardId}) if err != nil && !errors.Is(err, models.ErrDashboardNotFound) { return err } diff --git a/pkg/services/sqlstore/dashboard_provisioning_test.go b/pkg/services/sqlstore/dashboard_provisioning_test.go index b58b62e573b..191091e5285 100644 --- a/pkg/services/sqlstore/dashboard_provisioning_test.go +++ b/pkg/services/sqlstore/dashboard_provisioning_test.go @@ -80,7 +80,7 @@ func TestDashboardProvisioningTest(t *testing.T) { require.NotNil(t, query.Result) deleteCmd := &models.DeleteOrphanedProvisionedDashboardsCommand{ReaderNames: []string{"default"}} - require.Nil(t, DeleteOrphanedProvisionedDashboards(context.Background(), deleteCmd)) + require.Nil(t, sqlStore.DeleteOrphanedProvisionedDashboards(context.Background(), deleteCmd)) query = &models.GetDashboardsQuery{DashboardIds: []int64{dash.Id, anotherDash.Id}} err = GetDashboards(context.Background(), query) @@ -117,7 +117,7 @@ func TestDashboardProvisioningTest(t *testing.T) { OrgId: 1, } - require.Nil(t, DeleteDashboard(context.Background(), deleteCmd)) + require.Nil(t, sqlStore.DeleteDashboard(context.Background(), deleteCmd)) data, err := sqlStore.GetProvisionedDataByDashboardID(dash.Id) require.Nil(t, err) diff --git a/pkg/services/sqlstore/dashboard_test.go b/pkg/services/sqlstore/dashboard_test.go index 616af63c0d4..07cff5f5d8a 100644 --- a/pkg/services/sqlstore/dashboard_test.go +++ b/pkg/services/sqlstore/dashboard_test.go @@ -117,7 +117,7 @@ func TestDashboardDataAccess(t *testing.T) { setup() dash := insertTestDashboard(t, sqlStore, "delete me", 1, 0, false, "delete this") - err := DeleteDashboard(context.Background(), &models.DeleteDashboardCommand{ + err := sqlStore.DeleteDashboard(context.Background(), &models.DeleteDashboardCommand{ Id: dash.Id, OrgId: 1, }) @@ -214,21 +214,21 @@ func TestDashboardDataAccess(t *testing.T) { emptyFolder := insertTestDashboard(t, sqlStore, "2 test dash folder", 1, 0, true, "prod", "webapp") deleteCmd := &models.DeleteDashboardCommand{Id: emptyFolder.Id} - err := DeleteDashboard(context.Background(), deleteCmd) + err := sqlStore.DeleteDashboard(context.Background(), deleteCmd) require.NoError(t, err) }) t.Run("Should be not able to delete a dashboard if force delete rules is disabled", func(t *testing.T) { setup() deleteCmd := &models.DeleteDashboardCommand{Id: savedFolder.Id, ForceDeleteFolderRules: false} - err := DeleteDashboard(context.Background(), deleteCmd) + err := sqlStore.DeleteDashboard(context.Background(), deleteCmd) require.True(t, errors.Is(err, models.ErrFolderContainsAlertRules)) }) t.Run("Should be able to delete a dashboard folder and its children if force delete rules is enabled", func(t *testing.T) { setup() deleteCmd := &models.DeleteDashboardCommand{Id: savedFolder.Id, ForceDeleteFolderRules: true} - err := DeleteDashboard(context.Background(), deleteCmd) + err := sqlStore.DeleteDashboard(context.Background(), deleteCmd) require.NoError(t, err) query := search.FindPersistedDashboardsQuery{ diff --git a/pkg/services/sqlstore/login_attempt.go b/pkg/services/sqlstore/login_attempt.go index 58d0ddf38bf..cf2b78dac2e 100644 --- a/pkg/services/sqlstore/login_attempt.go +++ b/pkg/services/sqlstore/login_attempt.go @@ -11,14 +11,14 @@ import ( var getTimeNow = time.Now -func init() { - bus.AddHandler("sql", CreateLoginAttempt) - bus.AddHandler("sql", DeleteOldLoginAttempts) +func (ss *SQLStore) addLoginAttemptQueryAndCommandHandlers() { + bus.AddHandler("sql", ss.CreateLoginAttempt) + bus.AddHandler("sql", ss.DeleteOldLoginAttempts) bus.AddHandler("sql", GetUserLoginAttemptCount) } -func CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error { - return inTransaction(func(sess *DBSession) error { +func (ss *SQLStore) CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { loginAttempt := models.LoginAttempt{ Username: cmd.Username, IpAddress: cmd.IpAddress, @@ -35,8 +35,8 @@ func CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptComma }) } -func DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error { - return inTransaction(func(sess *DBSession) error { +func (ss *SQLStore) DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { var maxId int64 sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?" result, err := sess.Query(sql, cmd.OlderThan.Unix()) diff --git a/pkg/services/sqlstore/login_attempt_test.go b/pkg/services/sqlstore/login_attempt_test.go index dbe16866a38..5d8f0a4bd1f 100644 --- a/pkg/services/sqlstore/login_attempt_test.go +++ b/pkg/services/sqlstore/login_attempt_test.go @@ -20,24 +20,25 @@ func mockTime(mock time.Time) time.Time { func TestLoginAttempts(t *testing.T) { var beginningOfTime, timePlusOneMinute, timePlusTwoMinutes time.Time + var sqlStore *SQLStore user := "user" setup := func(t *testing.T) { - InitTestDB(t) + sqlStore = InitTestDB(t) beginningOfTime = mockTime(time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local)) - err := CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{ + err := sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{ Username: user, IpAddress: "192.168.0.1", }) require.Nil(t, err) timePlusOneMinute = mockTime(beginningOfTime.Add(time.Minute * 1)) - err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{ + err = sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{ Username: user, IpAddress: "192.168.0.1", }) require.Nil(t, err) timePlusTwoMinutes = mockTime(beginningOfTime.Add(time.Minute * 2)) - err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{ + err = sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{ Username: user, IpAddress: "192.168.0.1", }) @@ -93,7 +94,7 @@ func TestLoginAttempts(t *testing.T) { cmd := models.DeleteOldLoginAttemptsCommand{ OlderThan: beginningOfTime, } - err := DeleteOldLoginAttempts(context.Background(), &cmd) + err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd) require.Nil(t, err) require.Equal(t, int64(0), cmd.DeletedRows) @@ -104,7 +105,7 @@ func TestLoginAttempts(t *testing.T) { cmd := models.DeleteOldLoginAttemptsCommand{ OlderThan: timePlusOneMinute, } - err := DeleteOldLoginAttempts(context.Background(), &cmd) + err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd) require.Nil(t, err) require.Equal(t, int64(1), cmd.DeletedRows) @@ -115,7 +116,7 @@ func TestLoginAttempts(t *testing.T) { cmd := models.DeleteOldLoginAttemptsCommand{ OlderThan: timePlusTwoMinutes, } - err := DeleteOldLoginAttempts(context.Background(), &cmd) + err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd) require.Nil(t, err) require.Equal(t, int64(2), cmd.DeletedRows) @@ -126,7 +127,7 @@ func TestLoginAttempts(t *testing.T) { cmd := models.DeleteOldLoginAttemptsCommand{ OlderThan: timePlusTwoMinutes.Add(time.Second * 1), } - err := DeleteOldLoginAttempts(context.Background(), &cmd) + err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd) require.Nil(t, err) require.Equal(t, int64(3), cmd.DeletedRows) diff --git a/pkg/services/sqlstore/org.go b/pkg/services/sqlstore/org.go index ae011aec032..cf90e4b3b6b 100644 --- a/pkg/services/sqlstore/org.go +++ b/pkg/services/sqlstore/org.go @@ -15,14 +15,14 @@ import ( // MainOrgName is the name of the main organization. const MainOrgName = "Main Org." -func init() { +func (ss *SQLStore) addOrgQueryAndCommandHandlers() { bus.AddHandler("sql", GetOrgById) bus.AddHandler("sql", CreateOrg) - bus.AddHandler("sql", UpdateOrg) - bus.AddHandler("sql", UpdateOrgAddress) + bus.AddHandler("sql", ss.UpdateOrg) + bus.AddHandler("sql", ss.UpdateOrgAddress) bus.AddHandler("sql", GetOrgByName) bus.AddHandler("sql", SearchOrgs) - bus.AddHandler("sql", DeleteOrg) + bus.AddHandler("sql", ss.DeleteOrg) } func SearchOrgs(ctx context.Context, query *models.SearchOrgsQuery) error { @@ -164,8 +164,8 @@ func CreateOrg(ctx context.Context, cmd *models.CreateOrgCommand) error { return nil } -func UpdateOrg(ctx context.Context, cmd *models.UpdateOrgCommand) error { - return inTransaction(func(sess *DBSession) error { +func (ss *SQLStore) UpdateOrg(ctx context.Context, cmd *models.UpdateOrgCommand) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { if isNameTaken, err := isOrgNameTaken(cmd.Name, cmd.OrgId, sess); err != nil { return err } else if isNameTaken { @@ -197,8 +197,8 @@ func UpdateOrg(ctx context.Context, cmd *models.UpdateOrgCommand) error { }) } -func UpdateOrgAddress(ctx context.Context, cmd *models.UpdateOrgAddressCommand) error { - return inTransaction(func(sess *DBSession) error { +func (ss *SQLStore) UpdateOrgAddress(ctx context.Context, cmd *models.UpdateOrgAddressCommand) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { org := models.Org{ Address1: cmd.Address1, Address2: cmd.Address2, @@ -224,8 +224,8 @@ func UpdateOrgAddress(ctx context.Context, cmd *models.UpdateOrgAddressCommand) }) } -func DeleteOrg(ctx context.Context, cmd *models.DeleteOrgCommand) error { - return inTransaction(func(sess *DBSession) error { +func (ss *SQLStore) DeleteOrg(ctx context.Context, cmd *models.DeleteOrgCommand) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { if res, err := sess.Query("SELECT 1 from org WHERE id=?", cmd.Id); err != nil { return err } else if len(res) != 1 { diff --git a/pkg/services/sqlstore/org_test.go b/pkg/services/sqlstore/org_test.go index e58cc4f6b52..0cda5495fcb 100644 --- a/pkg/services/sqlstore/org_test.go +++ b/pkg/services/sqlstore/org_test.go @@ -195,7 +195,7 @@ func TestAccountDataAccess(t *testing.T) { t.Run("Can get logged in user projection", func(t *testing.T) { query := models.GetSignedInUserQuery{UserId: ac2.Id} - err := GetSignedInUser(context.Background(), &query) + err := sqlStore.GetSignedInUser(context.Background(), &query) require.NoError(t, err) require.Equal(t, query.Result.Email, "ac2@test.com") @@ -256,7 +256,7 @@ func TestAccountDataAccess(t *testing.T) { t.Run("SignedInUserQuery with a different org", func(t *testing.T) { query := models.GetSignedInUserQuery{UserId: ac2.Id} - err := GetSignedInUser(context.Background(), &query) + err := sqlStore.GetSignedInUser(context.Background(), &query) require.NoError(t, err) require.Equal(t, query.Result.OrgId, ac1.OrgId) @@ -273,7 +273,7 @@ func TestAccountDataAccess(t *testing.T) { require.NoError(t, err) query := models.GetSignedInUserQuery{UserId: ac2.Id} - err = GetSignedInUser(context.Background(), &query) + err = sqlStore.GetSignedInUser(context.Background(), &query) require.NoError(t, err) require.Equal(t, query.Result.OrgId, ac2.OrgId) @@ -282,7 +282,7 @@ func TestAccountDataAccess(t *testing.T) { t.Run("Removing user from org should delete user completely if in no other org", func(t *testing.T) { // make sure ac2 has no org - err := DeleteOrg(context.Background(), &models.DeleteOrgCommand{Id: ac2.OrgId}) + err := sqlStore.DeleteOrg(context.Background(), &models.DeleteOrgCommand{Id: ac2.OrgId}) require.NoError(t, err) // remove ac2 user from ac1 org @@ -291,7 +291,7 @@ func TestAccountDataAccess(t *testing.T) { require.NoError(t, err) require.True(t, remCmd.UserWasDeleted) - err = GetSignedInUser(context.Background(), &models.GetSignedInUserQuery{UserId: ac2.Id}) + err = sqlStore.GetSignedInUser(context.Background(), &models.GetSignedInUserQuery{UserId: ac2.Id}) require.Equal(t, err, models.ErrUserNotFound) }) diff --git a/pkg/services/sqlstore/sqlstore.go b/pkg/services/sqlstore/sqlstore.go index 2b73bcb6296..1b41ce72f99 100644 --- a/pkg/services/sqlstore/sqlstore.go +++ b/pkg/services/sqlstore/sqlstore.go @@ -127,7 +127,10 @@ func newSQLStore(cfg *setting.Cfg, cacheService *localcache.CacheService, bus bu ss.addDashboardVersionQueryAndCommandHandlers() ss.addAPIKeysQueryAndCommandHandlers() ss.addPlaylistQueryAndCommandHandlers() + ss.addLoginAttemptQueryAndCommandHandlers() ss.addTeamQueryAndCommandHandlers() + ss.addDashboardProvisioningQueryAndCommandHandlers() + ss.addOrgQueryAndCommandHandlers() // if err := ss.Reset(); err != nil { // return nil, err diff --git a/pkg/services/sqlstore/team.go b/pkg/services/sqlstore/team.go index 91c9656467c..a1fd714d8fd 100644 --- a/pkg/services/sqlstore/team.go +++ b/pkg/services/sqlstore/team.go @@ -16,7 +16,7 @@ func (ss *SQLStore) addTeamQueryAndCommandHandlers() { bus.AddHandler("sql", ss.DeleteTeam) bus.AddHandler("sql", ss.SearchTeams) bus.AddHandler("sql", ss.GetTeamById) - bus.AddHandler("sql", GetTeamsByUser) + bus.AddHandler("sql", ss.GetTeamsByUser) bus.AddHandler("sql", ss.UpdateTeamMember) bus.AddHandler("sql", ss.RemoveTeamMember) @@ -106,7 +106,7 @@ func (ss *SQLStore) CreateTeam(name, email string, orgID int64) (models.Team, er } func (ss *SQLStore) UpdateTeam(ctx context.Context, cmd *models.UpdateTeamCommand) error { - return inTransaction(func(sess *DBSession) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { if isNameTaken, err := isTeamNameTaken(cmd.OrgId, cmd.Name, cmd.Id, sess); err != nil { return err } else if isNameTaken { @@ -137,7 +137,7 @@ func (ss *SQLStore) UpdateTeam(ctx context.Context, cmd *models.UpdateTeamComman // DeleteTeam will delete a team, its member and any permissions connected to the team func (ss *SQLStore) DeleteTeam(ctx context.Context, cmd *models.DeleteTeamCommand) error { - return inTransaction(func(sess *DBSession) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { if _, err := teamExists(cmd.OrgId, cmd.Id, sess); err != nil { return err } @@ -274,17 +274,19 @@ func (ss *SQLStore) GetTeamById(ctx context.Context, query *models.GetTeamByIdQu } // GetTeamsByUser is used by the Guardian when checking a users' permissions -func GetTeamsByUser(ctx context.Context, query *models.GetTeamsByUserQuery) error { - query.Result = make([]*models.TeamDTO, 0) +func (ss *SQLStore) GetTeamsByUser(ctx context.Context, query *models.GetTeamsByUserQuery) error { + return ss.WithDbSession(ctx, func(sess *DBSession) error { + query.Result = make([]*models.TeamDTO, 0) - var sql bytes.Buffer + var sql bytes.Buffer - sql.WriteString(getTeamSelectSQLBase([]string{})) - sql.WriteString(` INNER JOIN team_member on team.id = team_member.team_id`) - sql.WriteString(` WHERE team.org_id = ? and team_member.user_id = ?`) + sql.WriteString(getTeamSelectSQLBase([]string{})) + sql.WriteString(` INNER JOIN team_member on team.id = team_member.team_id`) + sql.WriteString(` WHERE team.org_id = ? and team_member.user_id = ?`) - err := x.SQL(sql.String(), query.OrgId, query.UserId).Find(&query.Result) - return err + err := sess.SQL(sql.String(), query.OrgId, query.UserId).Find(&query.Result) + return err + }) } // AddTeamMember adds a user to a team @@ -333,7 +335,7 @@ func getTeamMember(sess *DBSession, orgId int64, teamId int64, userId int64) (mo // UpdateTeamMember updates a team member func (ss *SQLStore) UpdateTeamMember(ctx context.Context, cmd *models.UpdateTeamMemberCommand) error { - return inTransaction(func(sess *DBSession) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { member, err := getTeamMember(sess, cmd.OrgId, cmd.TeamId, cmd.UserId) if err != nil { return err @@ -359,7 +361,7 @@ func (ss *SQLStore) UpdateTeamMember(ctx context.Context, cmd *models.UpdateTeam // RemoveTeamMember removes a member from a team func (ss *SQLStore) RemoveTeamMember(ctx context.Context, cmd *models.RemoveTeamMemberCommand) error { - return inTransaction(func(sess *DBSession) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { if _, err := teamExists(cmd.OrgId, cmd.TeamId, sess); err != nil { return err } diff --git a/pkg/services/sqlstore/team_test.go b/pkg/services/sqlstore/team_test.go index ec788677d24..f555202c92e 100644 --- a/pkg/services/sqlstore/team_test.go +++ b/pkg/services/sqlstore/team_test.go @@ -209,7 +209,7 @@ func TestTeamCommandsAndQueries(t *testing.T) { require.NoError(t, err) query := &models.GetTeamsByUserQuery{OrgId: testOrgID, UserId: userIds[0]} - err = GetTeamsByUser(context.Background(), query) + err = sqlStore.GetTeamsByUser(context.Background(), query) require.NoError(t, err) require.Equal(t, len(query.Result), 1) require.Equal(t, query.Result[0].Name, "group2 name") diff --git a/pkg/services/sqlstore/user.go b/pkg/services/sqlstore/user.go index a9b9bf78897..5f3d7f35a0b 100644 --- a/pkg/services/sqlstore/user.go +++ b/pkg/services/sqlstore/user.go @@ -546,7 +546,7 @@ func (ss *SQLStore) GetSignedInUserWithCacheCtx(ctx context.Context, query *mode return nil } - err := GetSignedInUser(ctx, query) + err := ss.GetSignedInUser(ctx, query) if err != nil { return err } @@ -556,7 +556,7 @@ func (ss *SQLStore) GetSignedInUserWithCacheCtx(ctx context.Context, query *mode return nil } -func GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error { +func (ss *SQLStore) GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error { orgId := "u.org_id" if query.OrgId > 0 { orgId = strconv.FormatInt(query.OrgId, 10) @@ -603,7 +603,7 @@ func GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) er } getTeamsByUserQuery := &models.GetTeamsByUserQuery{OrgId: user.OrgId, UserId: user.UserId} - err = GetTeamsByUser(ctx, getTeamsByUserQuery) + err = ss.GetTeamsByUser(ctx, getTeamsByUserQuery) if err != nil { return err }