ResourceServer: Resource store sql backend (#90170)
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
package dbimpl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
resourcedb "github.com/grafana/grafana/pkg/storage/unified/sql/db"
|
||||
)
|
||||
|
||||
func NewDB(d *sql.DB, driverName string) resourcedb.DB {
|
||||
return sqldb{
|
||||
DB: d,
|
||||
driverName: driverName,
|
||||
}
|
||||
}
|
||||
|
||||
type sqldb struct {
|
||||
*sql.DB
|
||||
driverName string
|
||||
}
|
||||
|
||||
func (d sqldb) DriverName() string {
|
||||
return d.driverName
|
||||
}
|
||||
|
||||
func (d sqldb) BeginTx(ctx context.Context, opts *sql.TxOptions) (resourcedb.Tx, error) {
|
||||
t, err := d.DB.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tx{
|
||||
Tx: t,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d sqldb) WithTx(ctx context.Context, opts *sql.TxOptions, f resourcedb.TxFunc) error {
|
||||
t, err := d.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
|
||||
if err := f(ctx, t); err != nil {
|
||||
if rollbackErr := t.Rollback(); rollbackErr != nil {
|
||||
return fmt.Errorf("tx err: %w; rollback err: %w", err, rollbackErr)
|
||||
}
|
||||
return fmt.Errorf("tx err: %w", err)
|
||||
}
|
||||
|
||||
if err = t.Commit(); err != nil {
|
||||
return fmt.Errorf("commit err: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type tx struct {
|
||||
*sql.Tx
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package dbimpl
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"xorm.io/xorm"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/services/store/entity/db"
|
||||
)
|
||||
|
||||
func getEngineMySQL(getter *sectionGetter, _ tracing.Tracer) (*xorm.Engine, error) {
|
||||
config := mysql.NewConfig()
|
||||
config.User = getter.String("db_user")
|
||||
config.Passwd = getter.String("db_pass")
|
||||
config.Net = "tcp"
|
||||
config.Addr = getter.String("db_host")
|
||||
config.DBName = getter.String("db_name")
|
||||
config.Params = map[string]string{
|
||||
// See: https://dev.mysql.com/doc/refman/en/sql-mode.html
|
||||
"@@SESSION.sql_mode": "ANSI",
|
||||
}
|
||||
config.Collation = "utf8mb4_unicode_ci"
|
||||
config.Loc = time.UTC
|
||||
config.AllowNativePasswords = true
|
||||
config.ClientFoundRows = true
|
||||
|
||||
// TODO: do we want to support these?
|
||||
// config.ServerPubKey = getter.String("db_server_pub_key")
|
||||
// config.TLSConfig = getter.String("db_tls_config_name")
|
||||
|
||||
if err := getter.Err(); err != nil {
|
||||
return nil, fmt.Errorf("config error: %w", err)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(config.Addr, "/") {
|
||||
config.Net = "unix"
|
||||
}
|
||||
|
||||
// FIXME: get rid of xorm
|
||||
engine, err := xorm.NewEngine(db.DriverMySQL, config.FormatDSN())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open database: %w", err)
|
||||
}
|
||||
|
||||
engine.SetMaxOpenConns(0)
|
||||
engine.SetMaxIdleConns(2)
|
||||
engine.SetConnMaxLifetime(4 * time.Hour)
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
func getEnginePostgres(getter *sectionGetter, _ tracing.Tracer) (*xorm.Engine, error) {
|
||||
dsnKV := map[string]string{
|
||||
"user": getter.String("db_user"),
|
||||
"password": getter.String("db_pass"),
|
||||
"dbname": getter.String("db_name"),
|
||||
"sslmode": cmp.Or(getter.String("db_sslmode"), "disable"),
|
||||
}
|
||||
|
||||
// TODO: probably interesting:
|
||||
// "passfile", "statement_timeout", "lock_timeout", "connect_timeout"
|
||||
|
||||
// TODO: for CockroachDB, we probably need to use the following:
|
||||
// dsnKV["options"] = "-c enable_experimental_alter_column_type_general=true"
|
||||
// Or otherwise specify it as:
|
||||
// dsnKV["enable_experimental_alter_column_type_general"] = "true"
|
||||
|
||||
// TODO: do we want to support these options in the DSN as well?
|
||||
// "sslkey", "sslcert", "sslrootcert", "sslpassword", "sslsni", "krbspn",
|
||||
// "krbsrvname", "target_session_attrs", "service", "servicefile"
|
||||
|
||||
// More on Postgres connection string parameters:
|
||||
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
|
||||
|
||||
hostport := getter.String("db_host")
|
||||
|
||||
if err := getter.Err(); err != nil {
|
||||
return nil, fmt.Errorf("config error: %w", err)
|
||||
}
|
||||
|
||||
host, port, err := splitHostPortDefault(hostport, "127.0.0.1", "5432")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid db_host: %w", err)
|
||||
}
|
||||
dsnKV["host"] = host
|
||||
dsnKV["port"] = port
|
||||
|
||||
dsn, err := MakeDSN(dsnKV)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error building DSN: %w", err)
|
||||
}
|
||||
|
||||
// FIXME: get rid of xorm
|
||||
engine, err := xorm.NewEngine(db.DriverPostgres, dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open database: %w", err)
|
||||
}
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package dbimpl
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetEngineMySQLFromConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
getter := newTestSectionGetter(map[string]string{
|
||||
"db_type": "mysql",
|
||||
"db_host": "/var/run/mysql.socket",
|
||||
"db_name": "grafana",
|
||||
"db_user": "user",
|
||||
"db_password": "password",
|
||||
})
|
||||
engine, err := getEngineMySQL(getter, nil)
|
||||
assert.NotNil(t, engine)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
getter := newTestSectionGetter(map[string]string{
|
||||
"db_type": "mysql",
|
||||
"db_host": "/var/run/mysql.socket",
|
||||
"db_name": string(invalidUTF8ByteSequence),
|
||||
"db_user": "user",
|
||||
"db_password": "password",
|
||||
})
|
||||
engine, err := getEngineMySQL(getter, nil)
|
||||
assert.Nil(t, engine)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrInvalidUTF8Sequence)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetEnginePostgresFromConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
getter := newTestSectionGetter(map[string]string{
|
||||
"db_type": "mysql",
|
||||
"db_host": "localhost",
|
||||
"db_name": "grafana",
|
||||
"db_user": "user",
|
||||
"db_password": "password",
|
||||
})
|
||||
engine, err := getEnginePostgres(getter, nil)
|
||||
|
||||
assert.NotNil(t, engine)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
getter := newTestSectionGetter(map[string]string{
|
||||
"db_type": "mysql",
|
||||
"db_host": string(invalidUTF8ByteSequence),
|
||||
"db_name": "grafana",
|
||||
"db_user": "user",
|
||||
"db_password": "password",
|
||||
})
|
||||
engine, err := getEnginePostgres(getter, nil)
|
||||
|
||||
assert.Nil(t, engine)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrInvalidUTF8Sequence)
|
||||
})
|
||||
|
||||
t.Run("invalid hostport", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
getter := newTestSectionGetter(map[string]string{
|
||||
"db_type": "mysql",
|
||||
"db_host": "1:1:1",
|
||||
"db_name": "grafana",
|
||||
"db_user": "user",
|
||||
"db_password": "password",
|
||||
})
|
||||
engine, err := getEnginePostgres(getter, nil)
|
||||
|
||||
assert.Nil(t, engine)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
package dbimpl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
resourcedb "github.com/grafana/grafana/pkg/storage/unified/sql/db"
|
||||
)
|
||||
|
||||
func newCtx(t *testing.T) context.Context {
|
||||
t.Helper()
|
||||
|
||||
d, ok := t.Deadline()
|
||||
if !ok {
|
||||
// provide a default timeout for tests
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
var errTest = errors.New("because of reasons")
|
||||
|
||||
const driverName = "sqlmock"
|
||||
|
||||
func TestDB_BeginTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqldb, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
db := NewDB(sqldb, driverName)
|
||||
require.Equal(t, driverName, db.DriverName())
|
||||
|
||||
mock.ExpectBegin()
|
||||
tx, err := db.BeginTx(newCtx(t), nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tx)
|
||||
})
|
||||
|
||||
t.Run("fail begin", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqldb, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
db := NewDB(sqldb, "sqlmock")
|
||||
|
||||
mock.ExpectBegin().WillReturnError(errTest)
|
||||
tx, err := db.BeginTx(newCtx(t), nil)
|
||||
|
||||
require.Nil(t, tx)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, errTest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDB_WithTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newTxFunc := func(err error) resourcedb.TxFunc {
|
||||
return func(context.Context, resourcedb.Tx) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqldb, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
db := NewDB(sqldb, "sqlmock")
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectCommit()
|
||||
err = db.WithTx(newCtx(t), nil, newTxFunc(nil))
|
||||
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("fail begin", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqldb, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
db := NewDB(sqldb, "sqlmock")
|
||||
|
||||
mock.ExpectBegin().WillReturnError(errTest)
|
||||
err = db.WithTx(newCtx(t), nil, newTxFunc(nil))
|
||||
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, errTest)
|
||||
})
|
||||
|
||||
t.Run("fail tx", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqldb, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
db := NewDB(sqldb, "sqlmock")
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectRollback()
|
||||
err = db.WithTx(newCtx(t), nil, newTxFunc(errTest))
|
||||
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, errTest)
|
||||
})
|
||||
|
||||
t.Run("fail tx; fail rollback", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqldb, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
db := NewDB(sqldb, "sqlmock")
|
||||
errTest2 := errors.New("yet another err")
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectRollback().WillReturnError(errTest)
|
||||
err = db.WithTx(newCtx(t), nil, newTxFunc(errTest2))
|
||||
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, errTest)
|
||||
require.ErrorIs(t, err, errTest2)
|
||||
})
|
||||
|
||||
t.Run("fail commit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqldb, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
db := NewDB(sqldb, "sqlmock")
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectCommit().WillReturnError(errTest)
|
||||
err = db.WithTx(newCtx(t), nil, newTxFunc(nil))
|
||||
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, errTest)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package dbimpl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/dlmiddlecote/sqlstats"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"xorm.io/xorm"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/session"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
resourcedb "github.com/grafana/grafana/pkg/storage/unified/sql/db"
|
||||
"github.com/grafana/grafana/pkg/storage/unified/sql/db/migrations"
|
||||
)
|
||||
|
||||
var _ resourcedb.ResourceDBInterface = (*ResourceDB)(nil)
|
||||
|
||||
func ProvideResourceDB(db db.DB, cfg *setting.Cfg, features featuremgmt.FeatureToggles, tracer tracing.Tracer) (*ResourceDB, error) {
|
||||
return &ResourceDB{
|
||||
db: db,
|
||||
cfg: cfg,
|
||||
features: features,
|
||||
log: log.New("entity-db"),
|
||||
tracer: tracer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ResourceDB struct {
|
||||
once sync.Once
|
||||
onceErr error
|
||||
|
||||
db db.DB
|
||||
features featuremgmt.FeatureToggles
|
||||
engine *xorm.Engine
|
||||
cfg *setting.Cfg
|
||||
log log.Logger
|
||||
tracer tracing.Tracer
|
||||
}
|
||||
|
||||
func (db *ResourceDB) Init() error {
|
||||
db.once.Do(func() {
|
||||
db.onceErr = db.init()
|
||||
})
|
||||
|
||||
return db.onceErr
|
||||
}
|
||||
|
||||
func (db *ResourceDB) GetEngine() (*xorm.Engine, error) {
|
||||
if err := db.Init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.engine, db.onceErr
|
||||
}
|
||||
|
||||
func (db *ResourceDB) init() error {
|
||||
if db.engine != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var engine *xorm.Engine
|
||||
var err error
|
||||
|
||||
// TODO: This should be renamed resource_api
|
||||
getter := §ionGetter{
|
||||
DynamicSection: db.cfg.SectionWithEnvOverrides("resource_api"),
|
||||
}
|
||||
|
||||
dbType := getter.Key("db_type").MustString("")
|
||||
|
||||
// if explicit connection settings are provided, use them
|
||||
if dbType != "" {
|
||||
if dbType == "postgres" {
|
||||
engine, err = getEnginePostgres(getter, db.tracer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// FIXME: this config option is cockroachdb-specific, it's not supported by postgres
|
||||
// FIXME: this only sets this option for the session that we get
|
||||
// from the pool right now. A *sql.DB is a pool of connections,
|
||||
// there is no guarantee that the session where this is run will be
|
||||
// the same where we need to change the type of a column
|
||||
_, err = engine.Exec("SET SESSION enable_experimental_alter_column_type_general=true")
|
||||
if err != nil {
|
||||
db.log.Error("error connecting to postgres", "msg", err.Error())
|
||||
// FIXME: return nil, err
|
||||
}
|
||||
} else if dbType == "mysql" {
|
||||
engine, err = getEngineMySQL(getter, db.tracer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = engine.Ping(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// TODO: sqlite support
|
||||
return fmt.Errorf("invalid db type specified: %s", dbType)
|
||||
}
|
||||
|
||||
// register sql stat metrics
|
||||
if err := prometheus.Register(sqlstats.NewStatsCollector("unified_storage", engine.DB().DB)); err != nil {
|
||||
db.log.Warn("Failed to register unified storage sql stats collector", "error", err)
|
||||
}
|
||||
|
||||
// configure sql logging
|
||||
debugSQL := getter.Key("log_queries").MustBool(false)
|
||||
if !debugSQL {
|
||||
engine.SetLogger(&xorm.DiscardLogger{})
|
||||
} else {
|
||||
// add stack to database calls to be able to see what repository initiated queries. Top 7 items from the stack as they are likely in the xorm library.
|
||||
// engine.SetLogger(sqlstore.NewXormLogger(log.LvlInfo, log.WithSuffix(log.New("sqlstore.xorm"), log.CallerContextKey, log.StackCaller(log.DefaultCallerDepth))))
|
||||
engine.ShowSQL(true)
|
||||
engine.ShowExecTime(true)
|
||||
}
|
||||
|
||||
// otherwise, try to use the grafana db connection
|
||||
} else {
|
||||
if db.db == nil {
|
||||
return fmt.Errorf("no db connection provided")
|
||||
}
|
||||
|
||||
engine = db.db.GetEngine()
|
||||
}
|
||||
|
||||
db.engine = engine
|
||||
|
||||
if err := migrations.MigrateResourceStore(engine, db.cfg, db.features); err != nil {
|
||||
db.engine = nil
|
||||
return fmt.Errorf("run migrations: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *ResourceDB) GetSession() (*session.SessionDB, error) {
|
||||
engine, err := db.GetEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session.GetSession(sqlx.NewDb(engine.DB().DB, engine.DriverName())), nil
|
||||
}
|
||||
|
||||
func (db *ResourceDB) GetCfg() *setting.Cfg {
|
||||
return db.cfg
|
||||
}
|
||||
|
||||
func (db *ResourceDB) GetDB() (resourcedb.DB, error) {
|
||||
engine, err := db.GetEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := NewDB(engine.DB().DB, engine.Dialect().DriverName())
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package dbimpl
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidUTF8Sequence = errors.New("invalid UTF-8 sequence")
|
||||
)
|
||||
|
||||
type sectionGetter struct {
|
||||
*setting.DynamicSection
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *sectionGetter) Err() error {
|
||||
return g.err
|
||||
}
|
||||
|
||||
func (g *sectionGetter) String(key string) string {
|
||||
v := g.DynamicSection.Key(key).MustString("")
|
||||
if !utf8.ValidString(v) {
|
||||
g.err = fmt.Errorf("value for key %q: %w", key, ErrInvalidUTF8Sequence)
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// MakeDSN creates a DSN from the given key/value pair. It validates the strings
|
||||
// form valid UTF-8 sequences and escapes values if needed.
|
||||
func MakeDSN(m map[string]string) (string, error) {
|
||||
b := new(strings.Builder)
|
||||
|
||||
ks := keys(m)
|
||||
sort.Strings(ks) // provide deterministic behaviour
|
||||
for _, k := range ks {
|
||||
v := m[k]
|
||||
if !utf8.ValidString(v) {
|
||||
return "", fmt.Errorf("value for DSN key %q: %w", k,
|
||||
ErrInvalidUTF8Sequence)
|
||||
}
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if b.Len() > 0 {
|
||||
_ = b.WriteByte(' ')
|
||||
}
|
||||
_, _ = b.WriteString(k)
|
||||
_ = b.WriteByte('=')
|
||||
writeDSNValue(b, v)
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func keys(m map[string]string) []string {
|
||||
ret := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
ret = append(ret, k)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func writeDSNValue(b *strings.Builder, v string) {
|
||||
numq := strings.Count(v, `'`)
|
||||
numb := strings.Count(v, `\`)
|
||||
if numq+numb == 0 && v != "" {
|
||||
b.WriteString(v)
|
||||
|
||||
return
|
||||
}
|
||||
b.Grow(2 + numq + numb + len(v))
|
||||
|
||||
_ = b.WriteByte('\'')
|
||||
for _, r := range v {
|
||||
if r == '\\' || r == '\'' {
|
||||
_ = b.WriteByte('\\')
|
||||
}
|
||||
_, _ = b.WriteRune(r)
|
||||
}
|
||||
_ = b.WriteByte('\'')
|
||||
}
|
||||
|
||||
// splitHostPortDefault is similar to net.SplitHostPort, but will also accept a
|
||||
// specification with no port and apply the default port instead. It also
|
||||
// applies the given defaults if the results are empty strings.
|
||||
func splitHostPortDefault(hostport, defaultHost, defaultPort string) (string, string, error) {
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
// try appending the port
|
||||
host, port, err = net.SplitHostPort(hostport + ":" + defaultPort)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("invalid hostport: %q", hostport)
|
||||
}
|
||||
}
|
||||
host = cmp.Or(host, defaultHost)
|
||||
port = cmp.Or(port, defaultPort)
|
||||
|
||||
return host, port, nil
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package dbimpl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var invalidUTF8ByteSequence = []byte{0xff, 0xfe, 0xfd}
|
||||
|
||||
func setSectionKeyValues(section *setting.DynamicSection, m map[string]string) {
|
||||
for k, v := range m {
|
||||
section.Key(k).SetValue(v)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestSectionGetter(m map[string]string) *sectionGetter {
|
||||
section := setting.NewCfg().SectionWithEnvOverrides("entity_api")
|
||||
setSectionKeyValues(section, m)
|
||||
|
||||
return §ionGetter{
|
||||
DynamicSection: section,
|
||||
}
|
||||
}
|
||||
|
||||
func TestSectionGetter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
key = "the key"
|
||||
val = string(invalidUTF8ByteSequence)
|
||||
)
|
||||
|
||||
g := newTestSectionGetter(map[string]string{
|
||||
key: val,
|
||||
})
|
||||
|
||||
v := g.String("whatever")
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, g.Err())
|
||||
|
||||
v = g.String(key)
|
||||
require.Empty(t, v)
|
||||
require.Error(t, g.Err())
|
||||
require.ErrorIs(t, g.Err(), ErrInvalidUTF8Sequence)
|
||||
}
|
||||
|
||||
func TestMakeDSN(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s, err := MakeDSN(map[string]string{
|
||||
"db_name": string(invalidUTF8ByteSequence),
|
||||
})
|
||||
require.Empty(t, s)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrInvalidUTF8Sequence)
|
||||
|
||||
s, err = MakeDSN(map[string]string{
|
||||
"skip": "",
|
||||
"user": `shou'ld esc\ape`,
|
||||
"pass": "noescape",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `pass=noescape user='shou\'ld esc\\ape'`, s)
|
||||
}
|
||||
|
||||
func TestSplitHostPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
hostport string
|
||||
defaultHost string
|
||||
defaultPort string
|
||||
fails bool
|
||||
|
||||
host string
|
||||
port string
|
||||
}{
|
||||
{hostport: "192.168.0.140:456", defaultHost: "", defaultPort: "", host: "192.168.0.140", port: "456"},
|
||||
{hostport: "192.168.0.140", defaultHost: "", defaultPort: "123", host: "192.168.0.140", port: "123"},
|
||||
{hostport: "[::1]:456", defaultHost: "", defaultPort: "", host: "::1", port: "456"},
|
||||
{hostport: "[::1]", defaultHost: "", defaultPort: "123", host: "::1", port: "123"},
|
||||
{hostport: ":456", defaultHost: "1.2.3.4", defaultPort: "", host: "1.2.3.4", port: "456"},
|
||||
{hostport: "xyz.rds.amazonaws.com", defaultHost: "", defaultPort: "123", host: "xyz.rds.amazonaws.com", port: "123"},
|
||||
{hostport: "xyz.rds.amazonaws.com:123", defaultHost: "", defaultPort: "", host: "xyz.rds.amazonaws.com", port: "123"},
|
||||
{hostport: "", defaultHost: "localhost", defaultPort: "1433", host: "localhost", port: "1433"},
|
||||
{hostport: "1:1:1", fails: true},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test index #%d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
host, port, err := splitHostPortDefault(tc.hostport, tc.defaultHost, tc.defaultPort)
|
||||
if tc.fails {
|
||||
require.Error(t, err)
|
||||
require.Empty(t, host)
|
||||
require.Empty(t, port)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.host, host)
|
||||
require.Equal(t, tc.port, port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user