SQL Expressions: Re-implement feature using go-mysql-server (#99521)

* Under feature flag `sqlExpressions` and is experimental
* Excluded from arm32
* Will not work with the Query Service yet
* Does not have limits in place yet
* Does not working with alerting yet
* Currently requires "prepare time series" Transform for time series viz
 
---------

Co-authored-by: Sam Jewell <sam.jewell@grafana.com>
This commit is contained in:
Kyle Brandt
2025-02-06 07:27:28 -05:00
committed by GitHub
parent 4e6bdce41c
commit d64f41afdc
33 changed files with 1969 additions and 405 deletions
+50 -11
View File
@@ -1,22 +1,61 @@
//go:build !arm
package sql
import (
"errors"
"context"
sqle "github.com/dolthub/go-mysql-server"
mysql "github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/grafana/grafana-plugin-sdk-go/data"
)
type DB struct {
}
// DB is a database that can execute SQL queries against a set of Frames.
type DB struct{}
func (db *DB) RunCommands(commands []string) (string, error) {
return "", errors.New("not implemented")
}
// QueryFrames runs the sql query query against a database created from frames, and returns the frame.
// The RefID of each frame becomes a table in the database.
// It is expected that there is only one frame per RefID.
// The name becomes the name and RefID of the returned frame.
func (db *DB) QueryFrames(ctx context.Context, name string, query string, frames []*data.Frame) (*data.Frame, error) {
// We are parsing twice due to TablesList, but don't care fow now. We can save the parsed query and reuse it later if we want.
if allow, err := AllowQuery(query); err != nil || !allow {
if err != nil {
return nil, err
}
return nil, err
}
func (db *DB) QueryFramesInto(name string, query string, frames []*data.Frame, f *data.Frame) error {
return errors.New("not implemented")
}
pro := NewFramesDBProvider(frames)
session := mysql.NewBaseSession()
mCtx := mysql.NewContext(ctx, mysql.WithSession(session))
func NewInMemoryDB() *DB {
return &DB{}
// Select the database in the context
mCtx.SetCurrentDatabase(dbName)
// Empty dir does not disable secure_file_priv
//ctx.SetSessionVariable(ctx, "secure_file_priv", "")
// TODO: Check if it's wise to reuse the existing provider, rather than creating a new one
a := analyzer.NewDefault(pro)
engine := sqle.New(a, &sqle.Config{
IsReadOnly: true,
})
schema, iter, _, err := engine.Query(mCtx, query)
if err != nil {
return nil, err
}
f, err := convertToDataFrame(mCtx, iter, schema)
if err != nil {
return nil, err
}
f.Name = name
f.RefID = name
return f, nil
}
+187
View File
@@ -0,0 +1,187 @@
//go:build !arm
package sql
import (
"context"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/stretchr/testify/require"
)
func TestQueryFrames(t *testing.T) {
db := DB{}
tests := []struct {
name string
query string
input_frames []*data.Frame
expected *data.Frame
}{
{
name: "valid query with no input frames, one row one column",
query: `SELECT '1' AS 'n';`,
input_frames: []*data.Frame{},
expected: data.NewFrame(
"sqlExpressionRefId",
data.NewField("n", nil, []string{"1"}),
),
},
{
name: "valid query with no input frames, one row two columns",
query: `SELECT 'sam' AS 'name', 40 AS 'age';`,
input_frames: []*data.Frame{},
expected: data.NewFrame(
"sqlExpressionRefId",
data.NewField("name", nil, []string{"sam"}),
data.NewField("age", nil, []int8{40}),
),
},
{
// TODO: Also ORDER BY to ensure the order is preserved
name: "query all rows from single input frame",
query: `SELECT * FROM inputFrameRefId LIMIT 1;`,
input_frames: []*data.Frame{
setRefID(data.NewFrame(
"",
//nolint:misspell
data.NewField("OSS Projects with Typos", nil, []string{"Garfana", "Pormetheus"}),
), "inputFrameRefId"),
},
expected: data.NewFrame(
"sqlExpressionRefId",
data.NewField("OSS Projects with Typos", nil, []string{"Garfana"}),
),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
frame, err := db.QueryFrames(context.Background(), "sqlExpressionRefId", tt.query, tt.input_frames)
require.NoError(t, err)
require.NotNil(t, frame.Fields)
require.Equal(t, tt.expected.Name, frame.RefID)
require.Equal(t, len(tt.expected.Fields), len(frame.Fields))
for i := range tt.expected.Fields {
require.Equal(t, tt.expected.Fields[i].Name, frame.Fields[i].Name)
require.Equal(t, tt.expected.Fields[i].At(0), frame.Fields[i].At(0))
}
})
}
}
func TestQueryFramesInOut(t *testing.T) {
frameA := &data.Frame{
RefID: "a",
Name: "a",
Fields: []*data.Field{
data.NewField("time", nil, []time.Time{time.Now(), time.Now()}),
data.NewField("time_nullable", nil, []*time.Time{p(time.Now()), nil}),
data.NewField("string", nil, []string{"cat", "dog"}),
data.NewField("null_nullable", nil, []*string{p("cat"), nil}),
data.NewField("float64", nil, []float64{1, 3}),
data.NewField("float64_nullable", nil, []*float64{p(2.0), nil}),
data.NewField("int64", nil, []int64{1, 3}),
data.NewField("int64_nullable", nil, []*int64{p(int64(2)), nil}),
data.NewField("bool", nil, []bool{true, false}),
data.NewField("bool_nullable", nil, []*bool{p(true), nil}),
},
}
db := DB{}
qry := `SELECT * from a`
resultFrame, err := db.QueryFrames(context.Background(), "a", qry, []*data.Frame{frameA})
require.NoError(t, err)
if diff := cmp.Diff(frameA, resultFrame, data.FrameTestCompareOptions()...); diff != "" {
require.FailNowf(t, "Result mismatch (-want +got):%s\n", diff)
}
}
func TestQueryFramesNumericSelect(t *testing.T) {
expectedFrame := &data.Frame{
RefID: "a",
Name: "a",
Fields: []*data.Field{
data.NewField("decimal", nil, []float64{2.35}),
data.NewField("tinySigned", nil, []int8{-128}),
data.NewField("smallSigned", nil, []int16{-32768}),
data.NewField("mediumSigned", nil, []int32{-8388608}),
data.NewField("intSigned", nil, []int32{-2147483648}),
data.NewField("bigSigned", nil, []int64{-9223372036854775808}),
data.NewField("tinyUnsigned", nil, []uint8{255}),
data.NewField("smallUnsigned", nil, []uint16{65535}),
data.NewField("mediumUnsigned", nil, []int32{16777215}),
data.NewField("intUnsigned", nil, []uint32{4294967295}),
data.NewField("bigUnsigned", nil, []uint64{18446744073709551615}),
},
}
db := DB{}
qry := `SELECT 2.35 AS 'decimal',
-128 AS 'tinySigned',
-32768 AS 'smallSigned',
-8388608 AS 'mediumSigned',
-2147483648 AS 'intSigned',
-9223372036854775808 AS 'bigSigned',
255 AS 'tinyUnsigned',
65535 AS 'smallUnsigned',
16777215 AS 'mediumUnsigned',
4294967295 AS 'intUnsigned',
18446744073709551615 AS 'bigUnsigned'`
resultFrame, err := db.QueryFrames(context.Background(), "a", qry, []*data.Frame{})
require.NoError(t, err)
if diff := cmp.Diff(expectedFrame, resultFrame, data.FrameTestCompareOptions()...); diff != "" {
require.FailNowf(t, "Result mismatch (-want +got):%s\n", diff)
}
}
func TestQueryFramesDateTimeSelect(t *testing.T) {
t.Skip("need a fix in go-mysql-server, and then handle the datetime strings (or figure out why strings and not time.Time)")
expectedFrame := &data.Frame{
RefID: "a",
Name: "a",
Fields: []*data.Field{
data.NewField("ts", nil, []time.Time{}),
},
}
db := DB{}
// It doesn't like the T in the time string
qry := `SELECT str_to_date('2025-02-03T03:00:00','%Y-%m-%dT%H:%i:%s') as ts`
// This comes back as a string, which needs to be dealt with?
//qry := `SELECT str_to_date('2025-02-03-03:00:00','%Y-%m-%d-%H:%i:%s') as ts`
// This is a datetime(6), need to deal with that as well
//qry := `SELECT current_timestamp() as ts`
f, err := db.QueryFrames(context.Background(), "b", qry, []*data.Frame{})
require.NoError(t, err)
if diff := cmp.Diff(expectedFrame, f, data.FrameTestCompareOptions()...); diff != "" {
require.FailNowf(t, "Result mismatch (-want +got):%s\n", diff)
}
}
// p is a utility for pointers from constants
func p[T any](v T) *T {
return &v
}
func setRefID(f *data.Frame, refID string) *data.Frame {
f.RefID = refID
return f
}
+18
View File
@@ -0,0 +1,18 @@
//go:build arm
package sql
import (
"context"
"fmt"
"github.com/grafana/grafana-plugin-sdk-go/data"
)
type DB struct{}
// Stub out the QueryFrames method for ARM builds
// See github.com/dolthub/go-mysql-server/issues/2837
func (db *DB) QueryFrames(_ context.Context, _, _ string, _ []*data.Frame) (*data.Frame, error) {
return nil, fmt.Errorf("sql expressions not supported in arm")
}
+65
View File
@@ -0,0 +1,65 @@
//go:build !arm
package sql
import (
mysql "github.com/dolthub/go-mysql-server/sql"
"github.com/grafana/grafana-plugin-sdk-go/data"
)
var dbName = "frames"
// FramesDBProvider is a go-mysql-server DatabaseProvider that provides access to a set of Frames.
type FramesDBProvider struct {
db mysql.Database
}
func (p *FramesDBProvider) Database(_ *mysql.Context, _ string) (mysql.Database, error) {
return p.db, nil
}
func (p *FramesDBProvider) HasDatabase(_ *mysql.Context, _ string) bool {
return true
}
func (p *FramesDBProvider) AllDatabases(_ *mysql.Context) []mysql.Database {
return []mysql.Database{p.db}
}
// NewFramesDBProvider creates a new FramesDBProvider with the given set of Frames.
func NewFramesDBProvider(frames data.Frames) mysql.DatabaseProvider {
fMap := make(map[string]mysql.Table, len(frames))
for _, frame := range frames {
fMap[frame.RefID] = &FrameTable{Frame: frame}
}
return &FramesDBProvider{
db: &framesDB{
frames: fMap,
},
}
}
// framesDB is a go-mysql-server Database that provides access to a set of Frames.
type framesDB struct {
frames map[string]mysql.Table
}
func (db *framesDB) GetTableInsensitive(_ *mysql.Context, tblName string) (mysql.Table, bool, error) {
tbl, ok := mysql.GetTableInsensitive(tblName, db.frames)
if !ok {
return nil, false, nil
}
return tbl, ok, nil
}
func (db *framesDB) GetTableNames(_ *mysql.Context) ([]string, error) {
s := make([]string, 0, len(db.frames))
for k := range db.frames {
s = append(s, k)
}
return s, nil
}
func (db *framesDB) Name() string {
return dbName
}
+474
View File
@@ -0,0 +1,474 @@
//go:build !arm
package sql
import (
"errors"
"fmt"
"io"
"time"
mysql "github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/shopspring/decimal"
)
// TODO: Should this accept a row limit and converters, like sqlutil.FrameFromRows?
func convertToDataFrame(ctx *mysql.Context, iter mysql.RowIter, schema mysql.Schema) (*data.Frame, error) {
f := &data.Frame{}
// Create fields based on the schema
for _, col := range schema {
fT, err := MySQLColToFieldType(col)
if err != nil {
return nil, err
}
field := data.NewFieldFromFieldType(fT, 0)
field.Name = col.Name
f.Fields = append(f.Fields, field)
}
// Iterate through the rows and append data to fields
for {
row, err := iter.Next(ctx)
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, fmt.Errorf("error reading row: %v", err)
}
for i, val := range row {
v, err := fieldValFromRowVal(f.Fields[i].Type(), val)
if err != nil {
return nil, fmt.Errorf("unexpected type for column %s: %w", schema[i].Name, err)
}
f.Fields[i].Append(v)
}
}
return f, nil
}
// MySQLColToFieldType converts a MySQL column to a data.FieldType
func MySQLColToFieldType(col *mysql.Column) (data.FieldType, error) {
var fT data.FieldType
switch col.Type {
case types.Int8:
fT = data.FieldTypeInt8
case types.Uint8:
fT = data.FieldTypeUint8
case types.Int16:
fT = data.FieldTypeInt16
case types.Uint16:
fT = data.FieldTypeUint16
case types.Int32:
fT = data.FieldTypeInt32
case types.Uint32:
fT = data.FieldTypeUint32
case types.Int64:
fT = data.FieldTypeInt64
case types.Uint64:
fT = data.FieldTypeUint64
case types.Float64:
fT = data.FieldTypeFloat64
// StringType represents all string types, including VARCHAR and BLOB.
case types.Text, types.LongText:
fT = data.FieldTypeString
case types.Timestamp:
fT = data.FieldTypeTime
case types.Datetime:
fT = data.FieldTypeTime
case types.Boolean:
fT = data.FieldTypeBool
default:
if types.IsDecimal(col.Type) {
fT = data.FieldTypeFloat64
} else {
return fT, fmt.Errorf("unsupported type for column %s of type %v", col.Name, col.Type)
}
}
if col.Nullable {
fT = fT.NullableType()
}
return fT, nil
}
// Helper function to convert data.FieldType to types.Type
func convertDataType(fieldType data.FieldType) mysql.Type {
switch fieldType {
case data.FieldTypeInt8, data.FieldTypeNullableInt8:
return types.Int8
case data.FieldTypeUint8, data.FieldTypeNullableUint8:
return types.Uint8
case data.FieldTypeInt16, data.FieldTypeNullableInt16:
return types.Int16
case data.FieldTypeUint16, data.FieldTypeNullableUint16:
return types.Uint16
case data.FieldTypeInt32, data.FieldTypeNullableInt32:
return types.Int32
case data.FieldTypeUint32, data.FieldTypeNullableUint32:
return types.Uint32
case data.FieldTypeInt64, data.FieldTypeNullableInt64:
return types.Int64
case data.FieldTypeUint64, data.FieldTypeNullableUint64:
return types.Uint64
case data.FieldTypeFloat32, data.FieldTypeNullableFloat32:
return types.Float32
case data.FieldTypeFloat64, data.FieldTypeNullableFloat64:
return types.Float64
case data.FieldTypeString, data.FieldTypeNullableString:
return types.Text
case data.FieldTypeBool, data.FieldTypeNullableBool:
return types.Boolean
case data.FieldTypeTime, data.FieldTypeNullableTime:
return types.Timestamp
default:
fmt.Printf("------- Unsupported field type: %v", fieldType)
return types.JSON
}
}
// fieldValFromRowVal converts a go-mysql-server row value to a data.field value
//
//nolint:gocyclo
func fieldValFromRowVal(fieldType data.FieldType, val interface{}) (interface{}, error) {
// the input val may be nil, it also may not be a pointer even if the fieldtype is a nullable pointer type
if val == nil {
return nil, nil
}
switch fieldType {
// ----------------------------
// Int8 / Nullable Int8
// ----------------------------
case data.FieldTypeInt8:
v, ok := val.(int8)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int8", val, val)
}
return v, nil
case data.FieldTypeNullableInt8:
vP, ok := val.(*int8)
if ok {
return vP, nil
}
v, ok := val.(int8)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int8 or *int8", val, val)
// ----------------------------
// Uint8 / Nullable Uint8
// ----------------------------
case data.FieldTypeUint8:
v, ok := val.(uint8)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint8", val, val)
}
return v, nil
case data.FieldTypeNullableUint8:
vP, ok := val.(*uint8)
if ok {
return vP, nil
}
v, ok := val.(uint8)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint8 or *uint8", val, val)
// ----------------------------
// Int16 / Nullable Int16
// ----------------------------
case data.FieldTypeInt16:
v, ok := val.(int16)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int16", val, val)
}
return v, nil
case data.FieldTypeNullableInt16:
vP, ok := val.(*int16)
if ok {
return vP, nil
}
v, ok := val.(int16)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int16 or *int16", val, val)
// ----------------------------
// Uint16 / Nullable Uint16
// ----------------------------
case data.FieldTypeUint16:
v, ok := val.(uint16)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint16", val, val)
}
return v, nil
case data.FieldTypeNullableUint16:
vP, ok := val.(*uint16)
if ok {
return vP, nil
}
v, ok := val.(uint16)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint16 or *uint16", val, val)
// ----------------------------
// Int32 / Nullable Int32
// ----------------------------
case data.FieldTypeInt32:
v, ok := val.(int32)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int32", val, val)
}
return v, nil
case data.FieldTypeNullableInt32:
vP, ok := val.(*int32)
if ok {
return vP, nil
}
v, ok := val.(int32)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int32 or *int32", val, val)
// ----------------------------
// Uint32 / Nullable Uint32
// ----------------------------
case data.FieldTypeUint32:
v, ok := val.(uint32)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint32", val, val)
}
return v, nil
case data.FieldTypeNullableUint32:
vP, ok := val.(*uint32)
if ok {
return vP, nil
}
v, ok := val.(uint32)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint32 or *uint32", val, val)
// ----------------------------
// Int64 / Nullable Int64
// ----------------------------
case data.FieldTypeInt64:
v, ok := val.(int64)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int64", val, val)
}
return v, nil
case data.FieldTypeNullableInt64:
vP, ok := val.(*int64)
if ok {
return vP, nil
}
v, ok := val.(int64)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int64 or *int64", val, val)
// ----------------------------
// Uint64 / Nullable Uint64
// ----------------------------
case data.FieldTypeUint64:
v, ok := val.(uint64)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint64", val, val)
}
return v, nil
case data.FieldTypeNullableUint64:
vP, ok := val.(*uint64)
if ok {
return vP, nil
}
v, ok := val.(uint64)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint64 or *uint64", val, val)
// ----------------------------
// Float64 / Nullable Float64
// ----------------------------
case data.FieldTypeFloat64:
// Accept float64 or decimal.Decimal, convert decimal.Decimal -> float64
if v, ok := val.(float64); ok {
return v, nil
}
if d, ok := val.(decimal.Decimal); ok {
return d.InexactFloat64(), nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected float64 or decimal.Decimal", val, val)
case data.FieldTypeNullableFloat64:
// Possibly already *float64
if vP, ok := val.(*float64); ok {
return vP, nil
}
// Possibly float64
if v, ok := val.(float64); ok {
return &v, nil
}
// Possibly decimal.Decimal
if d, ok := val.(decimal.Decimal); ok {
f := d.InexactFloat64()
return &f, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected float64, *float64, or decimal.Decimal", val, val)
// ----------------------------
// Time / Nullable Time
// ----------------------------
case data.FieldTypeTime:
v, ok := val.(time.Time)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected time.Time", val, val)
}
return v, nil
case data.FieldTypeNullableTime:
vP, ok := val.(*time.Time)
if ok {
return vP, nil
}
v, ok := val.(time.Time)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected time.Time or *time.Time", val, val)
// ----------------------------
// String / Nullable String
// ----------------------------
case data.FieldTypeString:
v, ok := val.(string)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected string", val, val)
}
return v, nil
case data.FieldTypeNullableString:
vP, ok := val.(*string)
if ok {
return vP, nil
}
v, ok := val.(string)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected string or *string", val, val)
// ----------------------------
// Bool / Nullable Bool
// ----------------------------
case data.FieldTypeBool:
v, ok := val.(bool)
if !ok {
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected bool", val, val)
}
return v, nil
case data.FieldTypeNullableBool:
vP, ok := val.(*bool)
if ok {
return vP, nil
}
v, ok := val.(bool)
if ok {
return &v, nil
}
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected bool or *bool", val, val)
// ----------------------------
// Fallback / Unsupported
// ----------------------------
default:
return nil, fmt.Errorf("unsupported field type %s for val %v", fieldType, val)
}
}
// Is the field nilAt the index. Can panic if out of range.
// TODO: Maybe this should be a method on data.Field?
func nilAt(field data.Field, at int) bool {
if !field.Nullable() {
return false
}
switch field.Type() {
case data.FieldTypeNullableInt8:
v := field.At(at).(*int8)
return v == nil
case data.FieldTypeNullableUint8:
v := field.At(at).(*uint8)
return v == nil
case data.FieldTypeNullableInt16:
v := field.At(at).(*int16)
return v == nil
case data.FieldTypeNullableUint16:
v := field.At(at).(*uint16)
return v == nil
case data.FieldTypeNullableInt32:
v := field.At(at).(*int32)
return v == nil
case data.FieldTypeNullableUint32:
v := field.At(at).(*uint32)
return v == nil
case data.FieldTypeNullableInt64:
v := field.At(at).(*int64)
return v == nil
case data.FieldTypeNullableUint64:
v := field.At(at).(*uint64)
return v == nil
case data.FieldTypeNullableFloat64:
v := field.At(at).(*float64)
return v == nil
case data.FieldTypeNullableString:
v := field.At(at).(*string)
return v == nil
case data.FieldTypeNullableTime:
v := field.At(at).(*time.Time)
return v == nil
case data.FieldTypeNullableBool:
v := field.At(at).(*bool)
return v == nil
default:
// Either it's not a nullable type or it's unsupported
return false
}
}
+126
View File
@@ -0,0 +1,126 @@
//go:build !arm
package sql
import (
"io"
"strings"
mysql "github.com/dolthub/go-mysql-server/sql"
"github.com/grafana/grafana-plugin-sdk-go/data"
)
// FrameTable fulfills the mysql.Table interface for a data.Frame.
type FrameTable struct {
Frame *data.Frame
schema mysql.Schema
}
// Name implements the sql.Nameable interface
func (ft *FrameTable) Name() string {
return ft.Frame.RefID
}
// String implements the fmt.Stringer interface
func (ft *FrameTable) String() string {
return ft.Name()
}
func schemaFromFrame(frame *data.Frame) mysql.Schema {
schema := make(mysql.Schema, len(frame.Fields))
for i, field := range frame.Fields {
schema[i] = &mysql.Column{
Name: field.Name,
Type: convertDataType(field.Type()),
Nullable: field.Type().Nullable(),
Source: strings.ToLower(frame.RefID),
}
}
return schema
}
// Schema implements the mysql.Table interface
func (ft *FrameTable) Schema() mysql.Schema {
if ft.schema == nil {
ft.schema = schemaFromFrame(ft.Frame)
}
return ft.schema
}
// Collation implements the mysql.Table interface
func (ft *FrameTable) Collation() mysql.CollationID {
return mysql.Collation_Unspecified
}
// Partitions implements the mysql.Table interface
func (ft *FrameTable) Partitions(ctx *mysql.Context) (mysql.PartitionIter, error) {
return &noopPartitionIter{}, nil
}
// PartitionRows implements the mysql.Table interface
func (ft *FrameTable) PartitionRows(ctx *mysql.Context, _ mysql.Partition) (mysql.RowIter, error) {
return &rowIter{ft: ft, row: 0}, nil
}
type rowIter struct {
ft *FrameTable
row int
}
func (ri *rowIter) Next(_ *mysql.Context) (mysql.Row, error) {
// We assume each field in the Frame has the same number of rows.
numRows := 0
if len(ri.ft.Frame.Fields) > 0 {
numRows = ri.ft.Frame.Fields[0].Len()
}
// If we've already exhausted all rows, return EOF
if ri.row >= numRows {
return nil, io.EOF
}
// Construct a Row (which is []interface{} under the hood) by pulling
// the value from each column at the current row index.
row := make(mysql.Row, len(ri.ft.Frame.Fields))
for colIndex, field := range ri.ft.Frame.Fields {
if nilAt(*field, ri.row) {
continue
}
row[colIndex], _ = field.ConcreteAt(ri.row)
}
ri.row++
return row, nil
}
// Close implements the mysql.RowIter interface.
// In this no-op example, there isn't anything to do here.
func (ri *rowIter) Close(*mysql.Context) error {
return nil
}
type noopPartitionIter struct {
done bool
}
func (i *noopPartitionIter) Next(*mysql.Context) (mysql.Partition, error) {
if !i.done {
i.done = true
return noopParition, nil
}
return nil, io.EOF
}
func (i *noopPartitionIter) Close(*mysql.Context) error {
return nil
}
var noopParition = partition(nil)
type partition []byte
func (p partition) Key() []byte {
return p
}
+37 -64
View File
@@ -1,89 +1,62 @@
package sql
import (
"encoding/json"
"fmt"
"sort"
"strings"
"github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/jeremywohl/flatten"
)
const (
TABLE_NAME = "table_name"
ERROR = ".error"
ERROR_MESSAGE = ".error_message"
)
var logger = log.New("sql_expr")
// TablesList returns a list of tables for the sql statement
func TablesList(rawSQL string) ([]string, error) {
db := NewInMemoryDB()
rawSQL = strings.Replace(rawSQL, "'", "''", -1)
cmd := fmt.Sprintf("SELECT json_serialize_sql('%s')", rawSQL)
ret, err := db.RunCommands([]string{cmd})
stmt, err := sqlparser.Parse(rawSQL)
if err != nil {
logger.Error("error serializing sql", "error", err.Error(), "sql", rawSQL, "cmd", cmd)
return nil, fmt.Errorf("error serializing sql: %s", err.Error())
logger.Error("error parsing sql: %s", err.Error(), "sql", rawSQL)
return nil, fmt.Errorf("error parsing sql: %s", err.Error())
}
ast := []map[string]any{}
err = json.Unmarshal([]byte(ret), &ast)
if err != nil {
logger.Error("error converting json sql to ast", "error", err.Error(), "ret", ret)
return nil, fmt.Errorf("error converting json to ast: %s", err.Error())
}
tables := make(map[string]struct{})
return tablesFromAST(ast)
}
// tablesFromAST returns a list of tables from the ast
func tablesFromAST(ast []map[string]any) ([]string, error) {
flat, err := flatten.Flatten(ast[0], "", flatten.DotStyle)
if err != nil {
logger.Error("error flattening ast", "error", err.Error(), "ast", ast)
return nil, fmt.Errorf("error flattening ast: %s", err.Error())
}
tables := []string{}
for k, v := range flat {
if strings.HasSuffix(k, ERROR) {
v, ok := v.(bool)
if ok && v {
logger.Error("error in sql", "error", k)
return nil, astError(k, flat)
walkSubtree := func(node sqlparser.SQLNode) error {
err = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch v := node.(type) {
case *sqlparser.AliasedTableExpr:
if tableName, ok := v.Expr.(sqlparser.TableName); ok {
tables[tableName.Name.String()] = struct{}{}
}
case *sqlparser.TableName:
tables[v.Name.String()] = struct{}{}
}
return true, nil
}, node)
if err != nil {
logger.Error("error walking sql", "error", err, "node", node)
return fmt.Errorf("failed to parse SQL expression: %w", err)
}
if strings.Contains(k, TABLE_NAME) {
table, ok := v.(string)
if ok && !existsInList(table, tables) {
tables = append(tables, v.(string))
}
return nil
}
if err := walkSubtree(stmt); err != nil {
return nil, err
}
result := make([]string, 0, len(tables))
for table := range tables {
// Remove 'dual' table if it exists
// This is a special table in MySQL that always returns a single row with a single column
// See: https://dev.mysql.com/doc/refman/5.7/en/select.html#:~:text=You%20are%20permitted%20to%20specify%20DUAL%20as%20a%20dummy%20table%20name%20in%20situations%20where%20no%20tables%20are%20referenced
if table != "dual" {
result = append(result, table)
}
}
sort.Strings(tables)
sort.Strings(result)
logger.Debug("tables found in sql", "tables", tables)
return tables, nil
}
func astError(k string, flat map[string]any) error {
key := strings.Replace(k, ERROR, "", 1)
message, ok := flat[key+ERROR_MESSAGE]
if !ok {
message = "unknown error in sql"
}
return fmt.Errorf("error in sql: %s", message)
}
func existsInList(table string, list []string) bool {
for _, t := range list {
if t == table {
return true
}
}
return false
return result, nil
}
+136
View File
@@ -0,0 +1,136 @@
package sql
import (
"fmt"
"strings"
"github.com/dolthub/vitess/go/vt/sqlparser"
)
// AllowQuery parses the query and checks it against an allow list of allowed SQL nodes
// and functions.
func AllowQuery(rawSQL string) (bool, error) {
s, err := sqlparser.Parse(rawSQL)
if err != nil {
return false, fmt.Errorf("error parsing sql: %s", err.Error())
}
walkSubtree := func(node sqlparser.SQLNode) error {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
if !allowedNode(node) {
if fT, ok := node.(*sqlparser.FuncExpr); ok {
return false, fmt.Errorf("blocked function %s - not supported in queries", fT.Name)
}
return false, fmt.Errorf("blocked node %T - not supported in queries", node)
}
return true, nil
}, node)
if err != nil {
return fmt.Errorf("failed to parse SQL expression: %w", err)
}
return nil
}
if err := walkSubtree(s); err != nil {
return false, err
}
return true, nil
}
// nolint:gocyclo,nakedret
func allowedNode(node sqlparser.SQLNode) (b bool) {
b = true // so don't have to return true in every case but default
switch v := node.(type) {
case *sqlparser.FuncExpr:
return allowedFunction(v)
case *sqlparser.AsOf:
return
case *sqlparser.AliasedExpr, *sqlparser.AliasedTableExpr:
return
case *sqlparser.BinaryExpr:
return
case sqlparser.ColIdent, *sqlparser.ColName, sqlparser.Columns:
return
case sqlparser.Comments: // TODO: understand why some are pointer vs not
return
case *sqlparser.CommonTableExpr:
return
case *sqlparser.ComparisonExpr:
return
case *sqlparser.ConvertExpr:
return
case sqlparser.GroupBy:
return
case *sqlparser.IndexHints:
return
case *sqlparser.Into:
return
case *sqlparser.JoinTableExpr, sqlparser.JoinCondition:
return
case *sqlparser.Select, sqlparser.SelectExprs:
return
case *sqlparser.StarExpr:
return
case *sqlparser.SQLVal:
return
case *sqlparser.Limit:
return
case *sqlparser.Order, sqlparser.OrderBy:
return
case *sqlparser.Over:
return
case *sqlparser.Subquery:
return
case sqlparser.TableName, sqlparser.TableExprs, sqlparser.TableIdent:
return
case *sqlparser.With:
return
case *sqlparser.Where:
return
default:
return false
}
}
// nolint:gocyclo,nakedret
func allowedFunction(f *sqlparser.FuncExpr) (b bool) {
b = true // so don't have to return true in every case but default
switch strings.ToLower(f.Name.String()) {
case "sum", "avg", "count", "min", "max":
return
case "coalesce":
return
default:
return false
}
}
+80
View File
@@ -0,0 +1,80 @@
package sql
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAllowQuery(t *testing.T) {
testCases := []struct {
name string
q string
err error
}{
{
name: "a big catch all for now",
q: example_metrics_query,
err: nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := AllowQuery(tc.q)
if tc.err != nil {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}
var example_metrics_query = `WITH
metrics_this_month AS (
SELECT
Month,
namespace,
sum(BillableSeries) AS billable_series
FROM metrics
WHERE
Month = "2024-11"
GROUP BY
Month,
namespace
ORDER BY billable_series DESC
),
total_metrics AS (
SELECT SUM(billable_series) AS metrics_billable_series_total
FROM metrics_this_month
),
total_traces AS (
-- "usage" is a reserved keyword in MySQL. Quote it with backticks.
SELECT SUM(value) AS traces_usage_total
FROM traces
),
usage_by_team AS (
SELECT
COALESCE(teams.team, 'unaccounted') AS team,
1 + 0 AS team_count,
-- Metrics
SUM(COALESCE(metrics_this_month.billable_series, 0)) AS metrics_billable_series,
-- Traces
SUM(COALESCE(traces.value, 0)) AS traces_usage
-- FROM teams
-- FULL OUTER JOIN metrics_this_month
FROM metrics_this_month
FULL OUTER JOIN teams
ON teams.namespace = metrics_this_month.namespace
FULL OUTER JOIN traces
ON teams.namespace = traces.namespace
GROUP BY
-- COALESCE(teams.team, 'unaccounted')
teams.team
ORDER BY metrics_billable_series DESC
)
SELECT *
FROM usage_by_team
CROSS JOIN total_metrics
CROSS JOIN total_traces`
+124 -207
View File
@@ -3,214 +3,131 @@ package sql
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParse(t *testing.T) {
t.Skip()
sql := "select * from foo"
tables, err := TablesList((sql))
assert.Nil(t, err)
func TestTablesList(t *testing.T) {
tests := []struct {
name string
sql string
expected []string
expectError bool
}{
{
name: "simple select",
sql: "select * from foo",
expected: []string{"foo"},
},
{
name: "select with comma",
sql: "select * from foo,bar",
expected: []string{"bar", "foo"},
},
{
name: "select with multiple commas",
sql: "select * from foo,bar,baz",
expected: []string{"bar", "baz", "foo"},
},
{
name: "no table",
sql: "select 1 as 'n'",
expected: []string{},
},
{
name: "json array",
sql: "SELECT JSON_ARRAY(1, 2, 3) AS array_value",
expected: []string{},
},
{
name: "json extract",
sql: "SELECT JSON_EXTRACT(JSON_ARRAY(1, 2, 3), '$[0]') AS first_element;",
expected: []string{},
},
{
name: "json int array",
sql: "SELECT JSON_ARRAY(3, 2, 1) AS int_array;",
expected: []string{},
},
{
name: "subquery",
sql: "select * from (select * from people limit 1) AS subquery",
expected: []string{"people"},
},
{
name: "join",
sql: `select * from A
JOIN B ON A.name = B.name
LIMIT 10`,
expected: []string{"A", "B"},
},
{
name: "right join",
sql: `select * from A
RIGHT JOIN B ON A.name = B.name
LIMIT 10`,
expected: []string{"A", "B"},
},
{
name: "alias with join",
sql: `select * from A as X
RIGHT JOIN B ON A.name = X.name
LIMIT 10`,
expected: []string{"A", "B"},
},
{
name: "alias",
sql: "select * from A as X LIMIT 10",
expected: []string{"A"},
},
{
name: "error case",
sql: "select * from zzz aaa zzz",
expectError: true,
},
{
name: "parens",
sql: `SELECT t1.Col1,
t2.Col1,
t3.Col1
FROM table1 AS t1
LEFT JOIN (
table2 AS t2
INNER JOIN table3 AS t3 ON t3.Col1 = t2.Col1
) ON t2.Col1 = t1.Col1;`,
expected: []string{"table1", "table2", "table3"},
},
{
name: "with clause",
sql: `WITH top_products AS (
SELECT * FROM products
ORDER BY price DESC
LIMIT 5
)
SELECT name, price
FROM top_products;`,
expected: []string{"products", "top_products"},
},
{
name: "with quote",
sql: "select *,'junk' from foo",
expected: []string{"foo"},
},
{
name: "with quote 2",
sql: "SELECT json_serialize_sql('SELECT 1')",
expected: []string{},
},
}
assert.Equal(t, "foo", tables[0])
}
func TestParseWithComma(t *testing.T) {
t.Skip()
sql := "select * from foo,bar"
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, "bar", tables[0])
assert.Equal(t, "foo", tables[1])
}
func TestParseWithCommas(t *testing.T) {
t.Skip()
sql := "select * from foo,bar,baz"
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, "bar", tables[0])
assert.Equal(t, "baz", tables[1])
assert.Equal(t, "foo", tables[2])
}
func TestArray(t *testing.T) {
t.Skip()
sql := "SELECT array_value(1, 2, 3)"
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 0, len(tables))
}
func TestArray2(t *testing.T) {
t.Skip()
sql := "SELECT array_value(1, 2, 3)[2]"
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 0, len(tables))
}
func TestXxx(t *testing.T) {
t.Skip()
sql := "SELECT [3, 2, 1]::INT[3];"
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 0, len(tables))
}
func TestParseSubquery(t *testing.T) {
t.Skip()
sql := "select * from (select * from people limit 1)"
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 1, len(tables))
assert.Equal(t, "people", tables[0])
}
func TestJoin(t *testing.T) {
t.Skip()
sql := `select * from A
JOIN B ON A.name = B.name
LIMIT 10`
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 2, len(tables))
assert.Equal(t, "A", tables[0])
assert.Equal(t, "B", tables[1])
}
func TestRightJoin(t *testing.T) {
t.Skip()
sql := `select * from A
RIGHT JOIN B ON A.name = B.name
LIMIT 10`
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 2, len(tables))
assert.Equal(t, "A", tables[0])
assert.Equal(t, "B", tables[1])
}
func TestAliasWithJoin(t *testing.T) {
t.Skip()
sql := `select * from A as X
RIGHT JOIN B ON A.name = X.name
LIMIT 10`
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 2, len(tables))
assert.Equal(t, "A", tables[0])
assert.Equal(t, "B", tables[1])
}
func TestAlias(t *testing.T) {
t.Skip()
sql := `select * from A as X LIMIT 10`
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 1, len(tables))
assert.Equal(t, "A", tables[0])
}
func TestError(t *testing.T) {
t.Skip()
sql := `select * from zzz aaa zzz`
_, err := TablesList((sql))
assert.NotNil(t, err)
}
func TestParens(t *testing.T) {
t.Skip()
sql := `SELECT t1.Col1,
t2.Col1,
t3.Col1
FROM table1 AS t1
LEFT JOIN (
table2 AS t2
INNER JOIN table3 AS t3 ON t3.Col1 = t2.Col1
) ON t2.Col1 = t1.Col1;`
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 3, len(tables))
assert.Equal(t, "table1", tables[0])
assert.Equal(t, "table2", tables[1])
assert.Equal(t, "table3", tables[2])
}
func TestWith(t *testing.T) {
t.Skip()
sql := `WITH
current_month AS (
select
distinct "Month(ISO)" as mth
from A
ORDER BY mth DESC
LIMIT 1
),
last_month_bill AS (
select
CAST (
sum(
CAST(BillableSeries AS INTEGER)
) AS INTEGER
) AS BillableSeries,
"Month(ISO)",
label_namespace
-- , B.activeseries_count
from A
JOIN current_month
ON current_month.mth = A."Month(ISO)"
JOIN B
ON B.namespace = A.label_namespace
GROUP BY
label_namespace,
"Month(ISO)"
ORDER BY BillableSeries DESC
)
SELECT
last_month_bill.*,
BEE.activeseries_count
FROM last_month_bill
JOIN BEE
ON BEE.namespace = last_month_bill.label_namespace`
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 5, len(tables))
assert.Equal(t, "A", tables[0])
assert.Equal(t, "B", tables[1])
assert.Equal(t, "BEE", tables[2])
}
func TestWithQuote(t *testing.T) {
t.Skip()
sql := "select *,'junk' from foo"
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, "foo", tables[0])
}
func TestWithQuote2(t *testing.T) {
t.Skip()
sql := "SELECT json_serialize_sql('SELECT 1')"
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 0, len(tables))
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tables, err := TablesList(tc.sql)
if tc.expectError {
require.NotNil(t, err, "expected error for SQL: %s", tc.sql)
} else {
require.Nil(t, err, "unexpected error for SQL: %s", tc.sql)
require.Equal(t, tc.expected, tables, "mismatched tables for SQL: %s", tc.sql)
}
})
}
}