ResourceServer: Resource store sql backend (#90170)

This commit is contained in:
Georges Chaudy
2024-07-18 17:03:18 +02:00
committed by GitHub
parent bb40fb342a
commit 08c611c68b
71 changed files with 2871 additions and 35 deletions
+59
View File
@@ -0,0 +1,59 @@
package dbimpl
import (
"context"
"database/sql"
"fmt"
resourcedb "github.com/grafana/grafana/pkg/storage/unified/sql/db"
)
func NewDB(d *sql.DB, driverName string) resourcedb.DB {
return sqldb{
DB: d,
driverName: driverName,
}
}
type sqldb struct {
*sql.DB
driverName string
}
func (d sqldb) DriverName() string {
return d.driverName
}
func (d sqldb) BeginTx(ctx context.Context, opts *sql.TxOptions) (resourcedb.Tx, error) {
t, err := d.DB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return tx{
Tx: t,
}, nil
}
func (d sqldb) WithTx(ctx context.Context, opts *sql.TxOptions, f resourcedb.TxFunc) error {
t, err := d.BeginTx(ctx, opts)
if err != nil {
return fmt.Errorf("begin tx: %w", err)
}
if err := f(ctx, t); err != nil {
if rollbackErr := t.Rollback(); rollbackErr != nil {
return fmt.Errorf("tx err: %w; rollback err: %w", err, rollbackErr)
}
return fmt.Errorf("tx err: %w", err)
}
if err = t.Commit(); err != nil {
return fmt.Errorf("commit err: %w", err)
}
return nil
}
type tx struct {
*sql.Tx
}
@@ -0,0 +1,105 @@
package dbimpl
import (
"cmp"
"fmt"
"strings"
"time"
"github.com/go-sql-driver/mysql"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/services/store/entity/db"
)
func getEngineMySQL(getter *sectionGetter, _ tracing.Tracer) (*xorm.Engine, error) {
config := mysql.NewConfig()
config.User = getter.String("db_user")
config.Passwd = getter.String("db_pass")
config.Net = "tcp"
config.Addr = getter.String("db_host")
config.DBName = getter.String("db_name")
config.Params = map[string]string{
// See: https://dev.mysql.com/doc/refman/en/sql-mode.html
"@@SESSION.sql_mode": "ANSI",
}
config.Collation = "utf8mb4_unicode_ci"
config.Loc = time.UTC
config.AllowNativePasswords = true
config.ClientFoundRows = true
// TODO: do we want to support these?
// config.ServerPubKey = getter.String("db_server_pub_key")
// config.TLSConfig = getter.String("db_tls_config_name")
if err := getter.Err(); err != nil {
return nil, fmt.Errorf("config error: %w", err)
}
if strings.HasPrefix(config.Addr, "/") {
config.Net = "unix"
}
// FIXME: get rid of xorm
engine, err := xorm.NewEngine(db.DriverMySQL, config.FormatDSN())
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
engine.SetMaxOpenConns(0)
engine.SetMaxIdleConns(2)
engine.SetConnMaxLifetime(4 * time.Hour)
return engine, nil
}
func getEnginePostgres(getter *sectionGetter, _ tracing.Tracer) (*xorm.Engine, error) {
dsnKV := map[string]string{
"user": getter.String("db_user"),
"password": getter.String("db_pass"),
"dbname": getter.String("db_name"),
"sslmode": cmp.Or(getter.String("db_sslmode"), "disable"),
}
// TODO: probably interesting:
// "passfile", "statement_timeout", "lock_timeout", "connect_timeout"
// TODO: for CockroachDB, we probably need to use the following:
// dsnKV["options"] = "-c enable_experimental_alter_column_type_general=true"
// Or otherwise specify it as:
// dsnKV["enable_experimental_alter_column_type_general"] = "true"
// TODO: do we want to support these options in the DSN as well?
// "sslkey", "sslcert", "sslrootcert", "sslpassword", "sslsni", "krbspn",
// "krbsrvname", "target_session_attrs", "service", "servicefile"
// More on Postgres connection string parameters:
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
hostport := getter.String("db_host")
if err := getter.Err(); err != nil {
return nil, fmt.Errorf("config error: %w", err)
}
host, port, err := splitHostPortDefault(hostport, "127.0.0.1", "5432")
if err != nil {
return nil, fmt.Errorf("invalid db_host: %w", err)
}
dsnKV["host"] = host
dsnKV["port"] = port
dsn, err := MakeDSN(dsnKV)
if err != nil {
return nil, fmt.Errorf("error building DSN: %w", err)
}
// FIXME: get rid of xorm
engine, err := xorm.NewEngine(db.DriverPostgres, dsn)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
return engine, nil
}
@@ -0,0 +1,92 @@
package dbimpl
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetEngineMySQLFromConfig(t *testing.T) {
t.Parallel()
t.Run("happy path", func(t *testing.T) {
t.Parallel()
getter := newTestSectionGetter(map[string]string{
"db_type": "mysql",
"db_host": "/var/run/mysql.socket",
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
})
engine, err := getEngineMySQL(getter, nil)
assert.NotNil(t, engine)
assert.NoError(t, err)
})
t.Run("invalid string", func(t *testing.T) {
t.Parallel()
getter := newTestSectionGetter(map[string]string{
"db_type": "mysql",
"db_host": "/var/run/mysql.socket",
"db_name": string(invalidUTF8ByteSequence),
"db_user": "user",
"db_password": "password",
})
engine, err := getEngineMySQL(getter, nil)
assert.Nil(t, engine)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidUTF8Sequence)
})
}
func TestGetEnginePostgresFromConfig(t *testing.T) {
t.Parallel()
t.Run("happy path", func(t *testing.T) {
t.Parallel()
getter := newTestSectionGetter(map[string]string{
"db_type": "mysql",
"db_host": "localhost",
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
})
engine, err := getEnginePostgres(getter, nil)
assert.NotNil(t, engine)
assert.NoError(t, err)
})
t.Run("invalid string", func(t *testing.T) {
t.Parallel()
getter := newTestSectionGetter(map[string]string{
"db_type": "mysql",
"db_host": string(invalidUTF8ByteSequence),
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
})
engine, err := getEnginePostgres(getter, nil)
assert.Nil(t, engine)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidUTF8Sequence)
})
t.Run("invalid hostport", func(t *testing.T) {
t.Parallel()
getter := newTestSectionGetter(map[string]string{
"db_type": "mysql",
"db_host": "1:1:1",
"db_name": "grafana",
"db_user": "user",
"db_password": "password",
})
engine, err := getEnginePostgres(getter, nil)
assert.Nil(t, engine)
assert.Error(t, err)
})
}
@@ -0,0 +1,154 @@
package dbimpl
import (
"context"
"errors"
"testing"
"time"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
resourcedb "github.com/grafana/grafana/pkg/storage/unified/sql/db"
)
func newCtx(t *testing.T) context.Context {
t.Helper()
d, ok := t.Deadline()
if !ok {
// provide a default timeout for tests
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
return ctx
}
ctx, cancel := context.WithDeadline(context.Background(), d)
t.Cleanup(cancel)
return ctx
}
var errTest = errors.New("because of reasons")
const driverName = "sqlmock"
func TestDB_BeginTx(t *testing.T) {
t.Parallel()
t.Run("happy path", func(t *testing.T) {
t.Parallel()
sqldb, mock, err := sqlmock.New()
require.NoError(t, err)
db := NewDB(sqldb, driverName)
require.Equal(t, driverName, db.DriverName())
mock.ExpectBegin()
tx, err := db.BeginTx(newCtx(t), nil)
require.NoError(t, err)
require.NotNil(t, tx)
})
t.Run("fail begin", func(t *testing.T) {
t.Parallel()
sqldb, mock, err := sqlmock.New()
require.NoError(t, err)
db := NewDB(sqldb, "sqlmock")
mock.ExpectBegin().WillReturnError(errTest)
tx, err := db.BeginTx(newCtx(t), nil)
require.Nil(t, tx)
require.Error(t, err)
require.ErrorIs(t, err, errTest)
})
}
func TestDB_WithTx(t *testing.T) {
t.Parallel()
newTxFunc := func(err error) resourcedb.TxFunc {
return func(context.Context, resourcedb.Tx) error {
return err
}
}
t.Run("happy path", func(t *testing.T) {
t.Parallel()
sqldb, mock, err := sqlmock.New()
require.NoError(t, err)
db := NewDB(sqldb, "sqlmock")
mock.ExpectBegin()
mock.ExpectCommit()
err = db.WithTx(newCtx(t), nil, newTxFunc(nil))
require.NoError(t, err)
})
t.Run("fail begin", func(t *testing.T) {
t.Parallel()
sqldb, mock, err := sqlmock.New()
require.NoError(t, err)
db := NewDB(sqldb, "sqlmock")
mock.ExpectBegin().WillReturnError(errTest)
err = db.WithTx(newCtx(t), nil, newTxFunc(nil))
require.Error(t, err)
require.ErrorIs(t, err, errTest)
})
t.Run("fail tx", func(t *testing.T) {
t.Parallel()
sqldb, mock, err := sqlmock.New()
require.NoError(t, err)
db := NewDB(sqldb, "sqlmock")
mock.ExpectBegin()
mock.ExpectRollback()
err = db.WithTx(newCtx(t), nil, newTxFunc(errTest))
require.Error(t, err)
require.ErrorIs(t, err, errTest)
})
t.Run("fail tx; fail rollback", func(t *testing.T) {
t.Parallel()
sqldb, mock, err := sqlmock.New()
require.NoError(t, err)
db := NewDB(sqldb, "sqlmock")
errTest2 := errors.New("yet another err")
mock.ExpectBegin()
mock.ExpectRollback().WillReturnError(errTest)
err = db.WithTx(newCtx(t), nil, newTxFunc(errTest2))
require.Error(t, err)
require.ErrorIs(t, err, errTest)
require.ErrorIs(t, err, errTest2)
})
t.Run("fail commit", func(t *testing.T) {
t.Parallel()
sqldb, mock, err := sqlmock.New()
require.NoError(t, err)
db := NewDB(sqldb, "sqlmock")
mock.ExpectBegin()
mock.ExpectCommit().WillReturnError(errTest)
err = db.WithTx(newCtx(t), nil, newTxFunc(nil))
require.Error(t, err)
require.ErrorIs(t, err, errTest)
})
}
+166
View File
@@ -0,0 +1,166 @@
package dbimpl
import (
"fmt"
"sync"
"github.com/dlmiddlecote/sqlstats"
"github.com/jmoiron/sqlx"
"github.com/prometheus/client_golang/prometheus"
"xorm.io/xorm"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/sqlstore/session"
"github.com/grafana/grafana/pkg/setting"
resourcedb "github.com/grafana/grafana/pkg/storage/unified/sql/db"
"github.com/grafana/grafana/pkg/storage/unified/sql/db/migrations"
)
var _ resourcedb.ResourceDBInterface = (*ResourceDB)(nil)
func ProvideResourceDB(db db.DB, cfg *setting.Cfg, features featuremgmt.FeatureToggles, tracer tracing.Tracer) (*ResourceDB, error) {
return &ResourceDB{
db: db,
cfg: cfg,
features: features,
log: log.New("entity-db"),
tracer: tracer,
}, nil
}
type ResourceDB struct {
once sync.Once
onceErr error
db db.DB
features featuremgmt.FeatureToggles
engine *xorm.Engine
cfg *setting.Cfg
log log.Logger
tracer tracing.Tracer
}
func (db *ResourceDB) Init() error {
db.once.Do(func() {
db.onceErr = db.init()
})
return db.onceErr
}
func (db *ResourceDB) GetEngine() (*xorm.Engine, error) {
if err := db.Init(); err != nil {
return nil, err
}
return db.engine, db.onceErr
}
func (db *ResourceDB) init() error {
if db.engine != nil {
return nil
}
var engine *xorm.Engine
var err error
// TODO: This should be renamed resource_api
getter := &sectionGetter{
DynamicSection: db.cfg.SectionWithEnvOverrides("resource_api"),
}
dbType := getter.Key("db_type").MustString("")
// if explicit connection settings are provided, use them
if dbType != "" {
if dbType == "postgres" {
engine, err = getEnginePostgres(getter, db.tracer)
if err != nil {
return err
}
// FIXME: this config option is cockroachdb-specific, it's not supported by postgres
// FIXME: this only sets this option for the session that we get
// from the pool right now. A *sql.DB is a pool of connections,
// there is no guarantee that the session where this is run will be
// the same where we need to change the type of a column
_, err = engine.Exec("SET SESSION enable_experimental_alter_column_type_general=true")
if err != nil {
db.log.Error("error connecting to postgres", "msg", err.Error())
// FIXME: return nil, err
}
} else if dbType == "mysql" {
engine, err = getEngineMySQL(getter, db.tracer)
if err != nil {
return err
}
if err = engine.Ping(); err != nil {
return err
}
} else {
// TODO: sqlite support
return fmt.Errorf("invalid db type specified: %s", dbType)
}
// register sql stat metrics
if err := prometheus.Register(sqlstats.NewStatsCollector("unified_storage", engine.DB().DB)); err != nil {
db.log.Warn("Failed to register unified storage sql stats collector", "error", err)
}
// configure sql logging
debugSQL := getter.Key("log_queries").MustBool(false)
if !debugSQL {
engine.SetLogger(&xorm.DiscardLogger{})
} else {
// add stack to database calls to be able to see what repository initiated queries. Top 7 items from the stack as they are likely in the xorm library.
// engine.SetLogger(sqlstore.NewXormLogger(log.LvlInfo, log.WithSuffix(log.New("sqlstore.xorm"), log.CallerContextKey, log.StackCaller(log.DefaultCallerDepth))))
engine.ShowSQL(true)
engine.ShowExecTime(true)
}
// otherwise, try to use the grafana db connection
} else {
if db.db == nil {
return fmt.Errorf("no db connection provided")
}
engine = db.db.GetEngine()
}
db.engine = engine
if err := migrations.MigrateResourceStore(engine, db.cfg, db.features); err != nil {
db.engine = nil
return fmt.Errorf("run migrations: %w", err)
}
return nil
}
func (db *ResourceDB) GetSession() (*session.SessionDB, error) {
engine, err := db.GetEngine()
if err != nil {
return nil, err
}
return session.GetSession(sqlx.NewDb(engine.DB().DB, engine.DriverName())), nil
}
func (db *ResourceDB) GetCfg() *setting.Cfg {
return db.cfg
}
func (db *ResourceDB) GetDB() (resourcedb.DB, error) {
engine, err := db.GetEngine()
if err != nil {
return nil, err
}
ret := NewDB(engine.DB().DB, engine.Dialect().DriverName())
return ret, nil
}
+111
View File
@@ -0,0 +1,111 @@
package dbimpl
import (
"cmp"
"errors"
"fmt"
"net"
"sort"
"strings"
"unicode/utf8"
"github.com/grafana/grafana/pkg/setting"
)
var (
ErrInvalidUTF8Sequence = errors.New("invalid UTF-8 sequence")
)
type sectionGetter struct {
*setting.DynamicSection
err error
}
func (g *sectionGetter) Err() error {
return g.err
}
func (g *sectionGetter) String(key string) string {
v := g.DynamicSection.Key(key).MustString("")
if !utf8.ValidString(v) {
g.err = fmt.Errorf("value for key %q: %w", key, ErrInvalidUTF8Sequence)
return ""
}
return v
}
// MakeDSN creates a DSN from the given key/value pair. It validates the strings
// form valid UTF-8 sequences and escapes values if needed.
func MakeDSN(m map[string]string) (string, error) {
b := new(strings.Builder)
ks := keys(m)
sort.Strings(ks) // provide deterministic behaviour
for _, k := range ks {
v := m[k]
if !utf8.ValidString(v) {
return "", fmt.Errorf("value for DSN key %q: %w", k,
ErrInvalidUTF8Sequence)
}
if v == "" {
continue
}
if b.Len() > 0 {
_ = b.WriteByte(' ')
}
_, _ = b.WriteString(k)
_ = b.WriteByte('=')
writeDSNValue(b, v)
}
return b.String(), nil
}
func keys(m map[string]string) []string {
ret := make([]string, 0, len(m))
for k := range m {
ret = append(ret, k)
}
return ret
}
func writeDSNValue(b *strings.Builder, v string) {
numq := strings.Count(v, `'`)
numb := strings.Count(v, `\`)
if numq+numb == 0 && v != "" {
b.WriteString(v)
return
}
b.Grow(2 + numq + numb + len(v))
_ = b.WriteByte('\'')
for _, r := range v {
if r == '\\' || r == '\'' {
_ = b.WriteByte('\\')
}
_, _ = b.WriteRune(r)
}
_ = b.WriteByte('\'')
}
// splitHostPortDefault is similar to net.SplitHostPort, but will also accept a
// specification with no port and apply the default port instead. It also
// applies the given defaults if the results are empty strings.
func splitHostPortDefault(hostport, defaultHost, defaultPort string) (string, string, error) {
host, port, err := net.SplitHostPort(hostport)
if err != nil {
// try appending the port
host, port, err = net.SplitHostPort(hostport + ":" + defaultPort)
if err != nil {
return "", "", fmt.Errorf("invalid hostport: %q", hostport)
}
}
host = cmp.Or(host, defaultHost)
port = cmp.Or(port, defaultPort)
return host, port, nil
}
@@ -0,0 +1,108 @@
package dbimpl
import (
"fmt"
"testing"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
)
var invalidUTF8ByteSequence = []byte{0xff, 0xfe, 0xfd}
func setSectionKeyValues(section *setting.DynamicSection, m map[string]string) {
for k, v := range m {
section.Key(k).SetValue(v)
}
}
func newTestSectionGetter(m map[string]string) *sectionGetter {
section := setting.NewCfg().SectionWithEnvOverrides("entity_api")
setSectionKeyValues(section, m)
return &sectionGetter{
DynamicSection: section,
}
}
func TestSectionGetter(t *testing.T) {
t.Parallel()
var (
key = "the key"
val = string(invalidUTF8ByteSequence)
)
g := newTestSectionGetter(map[string]string{
key: val,
})
v := g.String("whatever")
require.Empty(t, v)
require.NoError(t, g.Err())
v = g.String(key)
require.Empty(t, v)
require.Error(t, g.Err())
require.ErrorIs(t, g.Err(), ErrInvalidUTF8Sequence)
}
func TestMakeDSN(t *testing.T) {
t.Parallel()
s, err := MakeDSN(map[string]string{
"db_name": string(invalidUTF8ByteSequence),
})
require.Empty(t, s)
require.Error(t, err)
require.ErrorIs(t, err, ErrInvalidUTF8Sequence)
s, err = MakeDSN(map[string]string{
"skip": "",
"user": `shou'ld esc\ape`,
"pass": "noescape",
})
require.NoError(t, err)
require.Equal(t, `pass=noescape user='shou\'ld esc\\ape'`, s)
}
func TestSplitHostPort(t *testing.T) {
t.Parallel()
testCases := []struct {
hostport string
defaultHost string
defaultPort string
fails bool
host string
port string
}{
{hostport: "192.168.0.140:456", defaultHost: "", defaultPort: "", host: "192.168.0.140", port: "456"},
{hostport: "192.168.0.140", defaultHost: "", defaultPort: "123", host: "192.168.0.140", port: "123"},
{hostport: "[::1]:456", defaultHost: "", defaultPort: "", host: "::1", port: "456"},
{hostport: "[::1]", defaultHost: "", defaultPort: "123", host: "::1", port: "123"},
{hostport: ":456", defaultHost: "1.2.3.4", defaultPort: "", host: "1.2.3.4", port: "456"},
{hostport: "xyz.rds.amazonaws.com", defaultHost: "", defaultPort: "123", host: "xyz.rds.amazonaws.com", port: "123"},
{hostport: "xyz.rds.amazonaws.com:123", defaultHost: "", defaultPort: "", host: "xyz.rds.amazonaws.com", port: "123"},
{hostport: "", defaultHost: "localhost", defaultPort: "1433", host: "localhost", port: "1433"},
{hostport: "1:1:1", fails: true},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("test index #%d", i), func(t *testing.T) {
t.Parallel()
host, port, err := splitHostPortDefault(tc.hostport, tc.defaultHost, tc.defaultPort)
if tc.fails {
require.Error(t, err)
require.Empty(t, host)
require.Empty(t, port)
} else {
require.NoError(t, err)
require.Equal(t, tc.host, host)
require.Equal(t, tc.port, port)
}
})
}
}