From e076c74869cd602109bd0b4dcac246a7426d5d02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20=C5=A0tibran=C3=BD?= Date: Thu, 3 Jul 2025 10:38:12 +0200 Subject: [PATCH] sqltemplate, dbimpl: Remove single-method function types (#107525) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove dbProviderFunc function. This removes one extra indirection that made the code bit more difficult to navigate. * Remove indirection function types implementing single-method interfaces. This streamlines the code and makes it bit easier to navigate. * Update pkg/storage/unified/sql/sqltemplate/dialect_mysql.go Co-authored-by: Mustafa Sencer Özcan <32759850+mustafasencer@users.noreply.github.com> --------- Co-authored-by: Mustafa Sencer Özcan <32759850+mustafasencer@users.noreply.github.com> --- pkg/storage/unified/sql/db/dbimpl/dbimpl.go | 29 ++++++------ .../unified/sql/sqltemplate/args_test.go | 2 +- .../unified/sql/sqltemplate/dialect.go | 29 +----------- .../unified/sql/sqltemplate/dialect_mysql.go | 35 +++++++------- .../sql/sqltemplate/dialect_postgresql.go | 29 ++++++------ .../unified/sql/sqltemplate/dialect_sqlite.go | 26 +++++++---- .../unified/sql/sqltemplate/dialect_test.go | 46 ++----------------- 7 files changed, 69 insertions(+), 127 deletions(-) diff --git a/pkg/storage/unified/sql/db/dbimpl/dbimpl.go b/pkg/storage/unified/sql/db/dbimpl/dbimpl.go index 231e9525a6c..4db1ca56a85 100644 --- a/pkg/storage/unified/sql/db/dbimpl/dbimpl.go +++ b/pkg/storage/unified/sql/db/dbimpl/dbimpl.go @@ -43,21 +43,7 @@ func ProvideResourceDB(grafanaDB infraDB.DB, cfg *setting.Cfg, tracer trace.Trac if err != nil { return nil, fmt.Errorf("provide Resource DB: %w", err) } - var once sync.Once - var resourceDB db.DB - - return dbProviderFunc(func(ctx context.Context) (db.DB, error) { - once.Do(func() { - resourceDB, err = p.init(ctx) - }) - return resourceDB, err - }), nil -} - -type dbProviderFunc func(context.Context) (db.DB, error) - -func (f dbProviderFunc) Init(ctx context.Context) (db.DB, error) { - return f(ctx) + return p, nil } type resourceDBProvider struct { @@ -68,6 +54,10 @@ type resourceDBProvider struct { tracer trace.Tracer registerMetrics bool logQueries bool + + once sync.Once + resourceDB db.DB + initErr error } func newResourceDBProvider(grafanaDB infraDB.DB, cfg *setting.Cfg, tracer trace.Tracer) (p *resourceDBProvider, err error) { @@ -124,7 +114,14 @@ func newResourceDBProvider(grafanaDB infraDB.DB, cfg *setting.Cfg, tracer trace. } } -func (p *resourceDBProvider) init(ctx context.Context) (db.DB, error) { +func (p *resourceDBProvider) Init(ctx context.Context) (db.DB, error) { + p.once.Do(func() { + p.resourceDB, p.initErr = p.initDB(ctx) + }) + return p.resourceDB, p.initErr +} + +func (p *resourceDBProvider) initDB(ctx context.Context) (db.DB, error) { p.log.Info("Initializing Resource DB", "db_type", p.engine.Dialect().DriverName(), diff --git a/pkg/storage/unified/sql/sqltemplate/args_test.go b/pkg/storage/unified/sql/sqltemplate/args_test.go index 732b2c2915f..23c17f8dc74 100644 --- a/pkg/storage/unified/sql/sqltemplate/args_test.go +++ b/pkg/storage/unified/sql/sqltemplate/args_test.go @@ -71,7 +71,7 @@ func TestArg_ArgList(t *testing.T) { } var a args - a.d = argFmtSQL92 + a.d = MySQL for i, tc := range testCases { a.Reset() diff --git a/pkg/storage/unified/sql/sqltemplate/dialect.go b/pkg/storage/unified/sql/sqltemplate/dialect.go index 918545fdb25..2f2e74cf557 100644 --- a/pkg/storage/unified/sql/sqltemplate/dialect.go +++ b/pkg/storage/unified/sql/sqltemplate/dialect.go @@ -3,7 +3,6 @@ package sqltemplate import ( "bytes" "errors" - "strconv" "strings" ) @@ -92,7 +91,6 @@ func ParseRowLockingClause(s ...string) (RowLockingClause, error) { return opt, nil } -// Row-locking clause options. const ( SelectForShare RowLockingClause = "SHARE" SelectForShareNoWait RowLockingClause = "SHARE NOWAIT" @@ -129,9 +127,6 @@ var rowLockingClauseAll = rowLockingClauseMap{ SelectForUpdateSkipLocked: SelectForUpdateSkipLocked, } -// standardIdent provides standard SQL escaping of identifiers. -type standardIdent struct{} - func escapeIdentity(s string, quote rune, clean func(string) string) (string, error) { if s == "" { return "", ErrEmptyIdent @@ -154,31 +149,11 @@ func escapeIdentity(s string, quote rune, clean func(string) string) (string, er return buffer.String(), nil } -func (standardIdent) Ident(s string) (string, error) { +// standardIdent provides standard SQL escaping of identifiers. +func standardIdent(s string) (string, error) { return escapeIdentity(s, '"', func(s string) string { // not sure we should support escaping quotes in table/column names, // but it is valid so we will support it for now return strings.ReplaceAll(s, `"`, `""`) }) } - -type argPlaceholderFunc func(int) string - -func (f argPlaceholderFunc) ArgPlaceholder(argNum int) string { - return f(argNum) -} - -var ( - argFmtSQL92 = argPlaceholderFunc(func(int) string { - return "?" - }) - argFmtPositional = argPlaceholderFunc(func(argNum int) string { - return "$" + strconv.Itoa(argNum) - }) -) - -type name string - -func (n name) DialectName() string { - return string(n) -} diff --git a/pkg/storage/unified/sql/sqltemplate/dialect_mysql.go b/pkg/storage/unified/sql/sqltemplate/dialect_mysql.go index 14fd18e0456..d5c3aba0f41 100644 --- a/pkg/storage/unified/sql/sqltemplate/dialect_mysql.go +++ b/pkg/storage/unified/sql/sqltemplate/dialect_mysql.go @@ -6,26 +6,17 @@ import ( // MySQL is the default implementation of Dialect for the MySQL DMBS, // currently supporting MySQL-8.x. -var MySQL = mysql{ - rowLockingClauseMap: rowLockingClauseAll, - argPlaceholderFunc: argFmtSQL92, - name: "mysql", +var MySQL = mysql{} + +type mysql struct{} + +func (m mysql) DialectName() string { + return "mysql" } -var _ Dialect = MySQL - -type mysql struct { - backtickIdent - rowLockingClauseMap - argPlaceholderFunc - name -} - -// MySQL always supports backticks for identifiers -// https://dev.mysql.com/doc/refman/8.4/en/identifiers.html -type backtickIdent struct{} - -func (backtickIdent) Ident(s string) (string, error) { +func (m mysql) Ident(s string) (string, error) { + // MySQL always supports backticks for identifiers + // https://dev.mysql.com/doc/refman/8.4/en/identifiers.html if strings.ContainsRune(s, '`') { return "", ErrInvalidIdentInput } @@ -34,6 +25,14 @@ func (backtickIdent) Ident(s string) (string, error) { }) } +func (m mysql) ArgPlaceholder(argNum int) string { + return "?" +} + +func (m mysql) SelectFor(s ...string) (string, error) { + return rowLockingClauseAll.SelectFor(s...) +} + func (mysql) CurrentEpoch() string { return "CAST(FLOOR(UNIX_TIMESTAMP(NOW(6)) * 1000000) AS SIGNED)" } diff --git a/pkg/storage/unified/sql/sqltemplate/dialect_postgresql.go b/pkg/storage/unified/sql/sqltemplate/dialect_postgresql.go index a0dd76010eb..bdab9beea75 100644 --- a/pkg/storage/unified/sql/sqltemplate/dialect_postgresql.go +++ b/pkg/storage/unified/sql/sqltemplate/dialect_postgresql.go @@ -2,28 +2,29 @@ package sqltemplate import ( "errors" + "fmt" "strings" ) // PostgreSQL is an implementation of Dialect for the PostgreSQL DMBS. -var PostgreSQL = postgresql{ - rowLockingClauseMap: rowLockingClauseAll, - argPlaceholderFunc: argFmtPositional, - name: "postgres", -} +var PostgreSQL = postgresql{} -var _ Dialect = PostgreSQL - -// PostgreSQL-specific errors. var ( ErrPostgreSQLUnsupportedIdent = errors.New("identifiers in PostgreSQL cannot contain the character with code zero") ) -type postgresql struct { - standardIdent - rowLockingClauseMap - argPlaceholderFunc - name +type postgresql struct{} + +func (p postgresql) DialectName() string { + return "postgres" +} + +func (p postgresql) ArgPlaceholder(argNum int) string { + return fmt.Sprintf("$%d", argNum) +} + +func (p postgresql) SelectFor(s ...string) (string, error) { + return rowLockingClauseAll.SelectFor(s...) } func (p postgresql) Ident(s string) (string, error) { @@ -33,7 +34,7 @@ func (p postgresql) Ident(s string) (string, error) { return "", ErrPostgreSQLUnsupportedIdent } - return p.standardIdent.Ident(s) + return standardIdent(s) } func (postgresql) CurrentEpoch() string { diff --git a/pkg/storage/unified/sql/sqltemplate/dialect_sqlite.go b/pkg/storage/unified/sql/sqltemplate/dialect_sqlite.go index f84ea78c477..457b5f101a9 100644 --- a/pkg/storage/unified/sql/sqltemplate/dialect_sqlite.go +++ b/pkg/storage/unified/sql/sqltemplate/dialect_sqlite.go @@ -1,20 +1,26 @@ package sqltemplate // SQLite is an implementation of Dialect for the SQLite DMBS. -var SQLite = sqlite{ - argPlaceholderFunc: argFmtSQL92, - name: "sqlite", +var SQLite = sqlite{} + +type sqlite struct{} + +func (s sqlite) DialectName() string { + return "sqlite" } -var _ Dialect = SQLite - -type sqlite struct { +func (s sqlite) Ident(i string) (string, error) { // See: // https://www.sqlite.org/lang_keywords.html - standardIdent - rowLockingClauseMap - argPlaceholderFunc - name + return standardIdent(i) +} + +func (s sqlite) ArgPlaceholder(argNum int) string { + return "?" +} + +func (s sqlite) SelectFor(s2 ...string) (string, error) { + return rowLockingClauseMap(nil).SelectFor(s2...) } func (sqlite) CurrentEpoch() string { diff --git a/pkg/storage/unified/sql/sqltemplate/dialect_test.go b/pkg/storage/unified/sql/sqltemplate/dialect_test.go index 65ba2bd9c28..9c914892818 100644 --- a/pkg/storage/unified/sql/sqltemplate/dialect_test.go +++ b/pkg/storage/unified/sql/sqltemplate/dialect_test.go @@ -6,6 +6,10 @@ import ( "testing" ) +var _ Dialect = MySQL +var _ Dialect = SQLite +var _ Dialect = PostgreSQL + func TestSelectForOption_Valid(t *testing.T) { t.Parallel() @@ -133,7 +137,7 @@ func TestStandardIdent_Ident(t *testing.T) { } for i, tc := range testCases { - gotOutput, gotErr := standardIdent{}.Ident(tc.input) + gotOutput, gotErr := standardIdent(tc.input) if !errors.Is(gotErr, tc.err) { t.Fatalf("unexpected error %v in test case %d", gotErr, i) } @@ -142,43 +146,3 @@ func TestStandardIdent_Ident(t *testing.T) { } } } - -func TestArgPlaceholderFunc(t *testing.T) { - t.Parallel() - - testCases := []struct { - input int - valuePositional string - }{ - { - input: 1, - valuePositional: "$1", - }, - { - input: 16, - valuePositional: "$16", - }, - } - - for i, tc := range testCases { - got := argFmtSQL92(tc.input) - if got != "?" { - t.Fatalf("[argFmtSQL92] unexpected value %q in test case %d", got, i) - } - - got = argFmtPositional(tc.input) - if got != tc.valuePositional { - t.Fatalf("[argFmtPositional] unexpected value %q in test case %d", got, i) - } - } -} - -func TestName_Name(t *testing.T) { - t.Parallel() - - const v = "some dialect name" - n := name(v) - if n.DialectName() != v { - t.Fatalf("unexpected dialect name %q", n.DialectName()) - } -}