From 44282134da2afc5047e62bf6c5158fcee7df5446 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Wed, 16 Oct 2024 13:40:35 -0300 Subject: [PATCH] Unistore Chore: Make it easier to implement DB interface (#94680) make it easier to implement DB interface --- pkg/storage/unified/sql/db/dbimpl/db.go | 51 ++----- pkg/storage/unified/sql/db/dbimpl/db_test.go | 150 ++----------------- pkg/storage/unified/sql/db/service.go | 67 +++++++++ pkg/storage/unified/sql/db/service_test.go | 97 ++++++++++++ 4 files changed, 185 insertions(+), 180 deletions(-) create mode 100644 pkg/storage/unified/sql/db/service_test.go diff --git a/pkg/storage/unified/sql/db/dbimpl/db.go b/pkg/storage/unified/sql/db/dbimpl/db.go index c7f5f616e60..c3ca9f4137c 100644 --- a/pkg/storage/unified/sql/db/dbimpl/db.go +++ b/pkg/storage/unified/sql/db/dbimpl/db.go @@ -3,62 +3,31 @@ package dbimpl import ( "context" "database/sql" - "fmt" - "strings" "github.com/grafana/grafana/pkg/storage/unified/sql/db" ) +// NewDB converts a *sql.DB to a db.DB. func NewDB(d *sql.DB, driverName string) db.DB { - // remove the suffix from the instrumented driver created by the older - // Grafana code - driverName = strings.TrimSuffix(driverName, "WithHooks") - - return sqldb{ + ret := sqlDB{ DB: d, driverName: driverName, } + ret.WithTxFunc = db.NewWithTxFunc(ret.BeginTx) + + return ret } -type sqldb struct { +type sqlDB struct { *sql.DB + db.WithTxFunc driverName string } -func (d sqldb) DriverName() string { +func (d sqlDB) DriverName() string { return d.driverName } -func (d sqldb) BeginTx(ctx context.Context, opts *sql.TxOptions) (db.Tx, error) { - t, err := d.DB.BeginTx(ctx, opts) - if err != nil { - return nil, err - } - return tx{ - Tx: t, - }, nil -} - -func (d sqldb) WithTx(ctx context.Context, opts *sql.TxOptions, f db.TxFunc) error { - t, err := d.BeginTx(ctx, opts) - if err != nil { - return fmt.Errorf("begin tx: %w", err) - } - - if err := f(ctx, t); err != nil { - if rollbackErr := t.Rollback(); rollbackErr != nil { - return fmt.Errorf("tx err: %w; rollback err: %w", err, rollbackErr) - } - return fmt.Errorf("tx err: %w", err) - } - - if err = t.Commit(); err != nil { - return fmt.Errorf("commit err: %w", err) - } - - return nil -} - -type tx struct { - *sql.Tx +func (d sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (db.Tx, error) { + return d.DB.BeginTx(ctx, opts) } diff --git a/pkg/storage/unified/sql/db/dbimpl/db_test.go b/pkg/storage/unified/sql/db/dbimpl/db_test.go index 0e85ebeb277..a9d550fb3b8 100644 --- a/pkg/storage/unified/sql/db/dbimpl/db_test.go +++ b/pkg/storage/unified/sql/db/dbimpl/db_test.go @@ -1,154 +1,26 @@ package dbimpl import ( - "context" - "errors" + "database/sql" "testing" - "time" - sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" - "github.com/grafana/grafana/pkg/storage/unified/sql/db" + "github.com/grafana/grafana/pkg/util/testutil" ) -func newCtx(t *testing.T) context.Context { - t.Helper() - - d, ok := t.Deadline() - if !ok { - // provide a default timeout for tests - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - t.Cleanup(cancel) - - return ctx - } - - ctx, cancel := context.WithDeadline(context.Background(), d) - t.Cleanup(cancel) - - return ctx -} - -var errTest = errors.New("because of reasons") - -const driverName = "sqlmock" - func TestDB_BeginTx(t *testing.T) { t.Parallel() + ctx := testutil.NewDefaultTestContext(t) - t.Run("happy path", func(t *testing.T) { - t.Parallel() + sqlDB, err := sql.Open(driverWithIsolationLevelName, "") + require.NoError(t, err) + require.NotNil(t, sqlDB) - sqldb, mock, err := sqlmock.New() - require.NoError(t, err) - db := NewDB(sqldb, driverName) - require.Equal(t, driverName, db.DriverName()) + d := NewDB(sqlDB, driverWithIsolationLevelName) + require.Equal(t, driverWithIsolationLevelName, d.DriverName()) - mock.ExpectBegin() - tx, err := db.BeginTx(newCtx(t), nil) - - require.NoError(t, err) - require.NotNil(t, tx) - }) - - t.Run("fail begin", func(t *testing.T) { - t.Parallel() - - sqldb, mock, err := sqlmock.New() - require.NoError(t, err) - db := NewDB(sqldb, "sqlmock") - - mock.ExpectBegin().WillReturnError(errTest) - tx, err := db.BeginTx(newCtx(t), nil) - - require.Nil(t, tx) - require.Error(t, err) - require.ErrorIs(t, err, errTest) - }) -} - -func TestDB_WithTx(t *testing.T) { - t.Parallel() - - newTxFunc := func(err error) db.TxFunc { - return func(context.Context, db.Tx) error { - return err - } - } - - t.Run("happy path", func(t *testing.T) { - t.Parallel() - - sqldb, mock, err := sqlmock.New() - require.NoError(t, err) - db := NewDB(sqldb, "sqlmock") - - mock.ExpectBegin() - mock.ExpectCommit() - err = db.WithTx(newCtx(t), nil, newTxFunc(nil)) - - require.NoError(t, err) - }) - - t.Run("fail begin", func(t *testing.T) { - t.Parallel() - - sqldb, mock, err := sqlmock.New() - require.NoError(t, err) - db := NewDB(sqldb, "sqlmock") - - mock.ExpectBegin().WillReturnError(errTest) - err = db.WithTx(newCtx(t), nil, newTxFunc(nil)) - - require.Error(t, err) - require.ErrorIs(t, err, errTest) - }) - - t.Run("fail tx", func(t *testing.T) { - t.Parallel() - - sqldb, mock, err := sqlmock.New() - require.NoError(t, err) - db := NewDB(sqldb, "sqlmock") - - mock.ExpectBegin() - mock.ExpectRollback() - err = db.WithTx(newCtx(t), nil, newTxFunc(errTest)) - - require.Error(t, err) - require.ErrorIs(t, err, errTest) - }) - - t.Run("fail tx; fail rollback", func(t *testing.T) { - t.Parallel() - - sqldb, mock, err := sqlmock.New() - require.NoError(t, err) - db := NewDB(sqldb, "sqlmock") - errTest2 := errors.New("yet another err") - - mock.ExpectBegin() - mock.ExpectRollback().WillReturnError(errTest) - err = db.WithTx(newCtx(t), nil, newTxFunc(errTest2)) - - require.Error(t, err) - require.ErrorIs(t, err, errTest) - require.ErrorIs(t, err, errTest2) - }) - - t.Run("fail commit", func(t *testing.T) { - t.Parallel() - - sqldb, mock, err := sqlmock.New() - require.NoError(t, err) - db := NewDB(sqldb, "sqlmock") - - mock.ExpectBegin() - mock.ExpectCommit().WillReturnError(errTest) - err = db.WithTx(newCtx(t), nil, newTxFunc(nil)) - - require.Error(t, err) - require.ErrorIs(t, err, errTest) - }) + tx, err := d.BeginTx(ctx, nil) + require.NoError(t, err) + require.NotNil(t, tx) } diff --git a/pkg/storage/unified/sql/db/service.go b/pkg/storage/unified/sql/db/service.go index 09f4c664b67..e22f265820a 100755 --- a/pkg/storage/unified/sql/db/service.go +++ b/pkg/storage/unified/sql/db/service.go @@ -3,6 +3,7 @@ package db import ( "context" "database/sql" + "fmt" ) //go:generate mockery --with-expecter --name DB @@ -62,3 +63,69 @@ type ContextExecer interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row } + +// WithTxFunc is an adapter to be able to provide the DB.WithTx method as an +// embedded function. +type WithTxFunc func(context.Context, *sql.TxOptions, TxFunc) error + +// WithTx implements the DB.WithTx method. +func (x WithTxFunc) WithTx(ctx context.Context, opts *sql.TxOptions, f TxFunc) error { + return x(ctx, opts, f) +} + +// BeginTxFunc is the signature of the DB.BeginTx method. +type BeginTxFunc = func(context.Context, *sql.TxOptions) (Tx, error) + +// NewWithTxFunc provides implementations of DB an easy way to provide the +// DB.WithTx method. +// Example usage: +// +// type myDB struct { +// db.WithTxFunc // embedded so that `WithTx` is already provided +// // other members... +// } +// +// func NewMyDB(/* options */) (db.DB, error) { +// ret := new(myDB) +// ret.WithTxFunc = db.NewWithTxFunc(ret.BeginTx) +// // other initialization code ... +// return ret, nil +// } +func NewWithTxFunc(x BeginTxFunc) WithTxFunc { + return WithTxFunc( + func(ctx context.Context, opts *sql.TxOptions, f TxFunc) error { + t, err := x(ctx, opts) + if err != nil { + return fmt.Errorf(oneErrFmt, beginStr, err) + } + + if err := f(ctx, t); err != nil { + if rollbackErr := t.Rollback(); rollbackErr != nil { + return fmt.Errorf(twoErrFmt, txOpStr, err, rollbackStr, + rollbackErr) + } + return fmt.Errorf(oneErrFmt, txOpStr, err) + } + + if err = t.Commit(); err != nil { + return fmt.Errorf(oneErrFmt, commitStr, err) + } + + return nil + }, + ) +} + +// Constants that allow testing that the correct scenario was hit. +const ( + oneErrFmt = "%s: %w" + twoErrFmt = oneErrFmt + "; " + oneErrFmt + + // keep the following ones in sync with the matching ones in + // `service_test.go`. + + txOpStr = "transactional operation" + beginStr = "begin" + commitStr = "commit" + rollbackStr = "rollback" +) diff --git a/pkg/storage/unified/sql/db/service_test.go b/pkg/storage/unified/sql/db/service_test.go new file mode 100644 index 00000000000..3dafddf833d --- /dev/null +++ b/pkg/storage/unified/sql/db/service_test.go @@ -0,0 +1,97 @@ +package db_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/grafana/grafana/pkg/storage/unified/sql/db" + "github.com/grafana/grafana/pkg/storage/unified/sql/db/mocks" + "github.com/grafana/grafana/pkg/util/testutil" +) + +var errTest = errors.New("you shall not pass") + +// Copy-paste of the constants used in `service.go`, since we need to use a +// separate package to avoid circular dependencies so we cannot import them. +// Keep these ones and the ones in `service.go` in sync. +const ( + txOpStr = "transactional operation" + beginStr = "begin" + commitStr = "commit" + rollbackStr = "rollback" +) + +func TestNewWithTxFunc(t *testing.T) { + t.Parallel() + + execTest := func(t *testing.T, d db.DB, txErr error) error { + ctx := testutil.NewDefaultTestContext(t) + return db.NewWithTxFunc(d.BeginTx).WithTx(ctx, nil, + func(context.Context, db.Tx) error { + return txErr + }) + } + + t.Run("happy path", func(t *testing.T) { + t.Parallel() + mDB, mTx := mocks.NewDB(t), mocks.NewTx(t) + + mDB.EXPECT().BeginTx(mock.Anything, mock.Anything).Return(mTx, nil) + mTx.EXPECT().Commit().Return(nil) + + err := execTest(t, mDB, nil) + require.NoError(t, err) + }) + + t.Run("failed begin", func(t *testing.T) { + t.Parallel() + mDB := mocks.NewDB(t) + + mDB.EXPECT().BeginTx(mock.Anything, mock.Anything).Return(nil, errTest) + + err := execTest(t, mDB, nil) + require.Error(t, err) + require.ErrorContains(t, err, beginStr) + }) + + t.Run("fail tx", func(t *testing.T) { + t.Parallel() + mDB, mTx := mocks.NewDB(t), mocks.NewTx(t) + + mDB.EXPECT().BeginTx(mock.Anything, mock.Anything).Return(mTx, nil) + mTx.EXPECT().Rollback().Return(nil) + + err := execTest(t, mDB, errTest) + require.Error(t, err) + require.ErrorContains(t, err, txOpStr) + }) + + t.Run("fail tx; fail rollback", func(t *testing.T) { + t.Parallel() + mDB, mTx := mocks.NewDB(t), mocks.NewTx(t) + + mDB.EXPECT().BeginTx(mock.Anything, mock.Anything).Return(mTx, nil) + mTx.EXPECT().Rollback().Return(errTest) + + err := execTest(t, mDB, errTest) + require.Error(t, err) + require.ErrorContains(t, err, txOpStr) + require.ErrorContains(t, err, rollbackStr) + }) + + t.Run("fail commit", func(t *testing.T) { + t.Parallel() + mDB, mTx := mocks.NewDB(t), mocks.NewTx(t) + + mDB.EXPECT().BeginTx(mock.Anything, mock.Anything).Return(mTx, nil) + mTx.EXPECT().Commit().Return(errTest) + + err := execTest(t, mDB, nil) + require.Error(t, err) + require.ErrorContains(t, err, commitStr) + }) +}