SQL Expressions: Re-implement feature using go-mysql-server (#99521)
* Under feature flag `sqlExpressions` and is experimental * Excluded from arm32 * Will not work with the Query Service yet * Does not have limits in place yet * Does not working with alerting yet * Currently requires "prepare time series" Transform for time series viz --------- Co-authored-by: Sam Jewell <sam.jewell@grafana.com>
This commit is contained in:
+50
-11
@@ -1,22 +1,61 @@
|
||||
//go:build !arm
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
|
||||
sqle "github.com/dolthub/go-mysql-server"
|
||||
mysql "github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql/analyzer"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
}
|
||||
// DB is a database that can execute SQL queries against a set of Frames.
|
||||
type DB struct{}
|
||||
|
||||
func (db *DB) RunCommands(commands []string) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
// QueryFrames runs the sql query query against a database created from frames, and returns the frame.
|
||||
// The RefID of each frame becomes a table in the database.
|
||||
// It is expected that there is only one frame per RefID.
|
||||
// The name becomes the name and RefID of the returned frame.
|
||||
func (db *DB) QueryFrames(ctx context.Context, name string, query string, frames []*data.Frame) (*data.Frame, error) {
|
||||
// We are parsing twice due to TablesList, but don't care fow now. We can save the parsed query and reuse it later if we want.
|
||||
if allow, err := AllowQuery(query); err != nil || !allow {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (db *DB) QueryFramesInto(name string, query string, frames []*data.Frame, f *data.Frame) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
pro := NewFramesDBProvider(frames)
|
||||
session := mysql.NewBaseSession()
|
||||
mCtx := mysql.NewContext(ctx, mysql.WithSession(session))
|
||||
|
||||
func NewInMemoryDB() *DB {
|
||||
return &DB{}
|
||||
// Select the database in the context
|
||||
mCtx.SetCurrentDatabase(dbName)
|
||||
|
||||
// Empty dir does not disable secure_file_priv
|
||||
//ctx.SetSessionVariable(ctx, "secure_file_priv", "")
|
||||
|
||||
// TODO: Check if it's wise to reuse the existing provider, rather than creating a new one
|
||||
a := analyzer.NewDefault(pro)
|
||||
|
||||
engine := sqle.New(a, &sqle.Config{
|
||||
IsReadOnly: true,
|
||||
})
|
||||
|
||||
schema, iter, _, err := engine.Query(mCtx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := convertToDataFrame(mCtx, iter, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.Name = name
|
||||
f.RefID = name
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
//go:build !arm
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestQueryFrames(t *testing.T) {
|
||||
db := DB{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
input_frames []*data.Frame
|
||||
expected *data.Frame
|
||||
}{
|
||||
{
|
||||
name: "valid query with no input frames, one row one column",
|
||||
query: `SELECT '1' AS 'n';`,
|
||||
input_frames: []*data.Frame{},
|
||||
expected: data.NewFrame(
|
||||
"sqlExpressionRefId",
|
||||
data.NewField("n", nil, []string{"1"}),
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "valid query with no input frames, one row two columns",
|
||||
query: `SELECT 'sam' AS 'name', 40 AS 'age';`,
|
||||
input_frames: []*data.Frame{},
|
||||
expected: data.NewFrame(
|
||||
"sqlExpressionRefId",
|
||||
data.NewField("name", nil, []string{"sam"}),
|
||||
data.NewField("age", nil, []int8{40}),
|
||||
),
|
||||
},
|
||||
{
|
||||
// TODO: Also ORDER BY to ensure the order is preserved
|
||||
name: "query all rows from single input frame",
|
||||
query: `SELECT * FROM inputFrameRefId LIMIT 1;`,
|
||||
input_frames: []*data.Frame{
|
||||
setRefID(data.NewFrame(
|
||||
"",
|
||||
//nolint:misspell
|
||||
data.NewField("OSS Projects with Typos", nil, []string{"Garfana", "Pormetheus"}),
|
||||
), "inputFrameRefId"),
|
||||
},
|
||||
expected: data.NewFrame(
|
||||
"sqlExpressionRefId",
|
||||
data.NewField("OSS Projects with Typos", nil, []string{"Garfana"}),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
frame, err := db.QueryFrames(context.Background(), "sqlExpressionRefId", tt.query, tt.input_frames)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, frame.Fields)
|
||||
|
||||
require.Equal(t, tt.expected.Name, frame.RefID)
|
||||
require.Equal(t, len(tt.expected.Fields), len(frame.Fields))
|
||||
for i := range tt.expected.Fields {
|
||||
require.Equal(t, tt.expected.Fields[i].Name, frame.Fields[i].Name)
|
||||
require.Equal(t, tt.expected.Fields[i].At(0), frame.Fields[i].At(0))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryFramesInOut(t *testing.T) {
|
||||
frameA := &data.Frame{
|
||||
RefID: "a",
|
||||
Name: "a",
|
||||
Fields: []*data.Field{
|
||||
data.NewField("time", nil, []time.Time{time.Now(), time.Now()}),
|
||||
data.NewField("time_nullable", nil, []*time.Time{p(time.Now()), nil}),
|
||||
|
||||
data.NewField("string", nil, []string{"cat", "dog"}),
|
||||
data.NewField("null_nullable", nil, []*string{p("cat"), nil}),
|
||||
|
||||
data.NewField("float64", nil, []float64{1, 3}),
|
||||
data.NewField("float64_nullable", nil, []*float64{p(2.0), nil}),
|
||||
|
||||
data.NewField("int64", nil, []int64{1, 3}),
|
||||
data.NewField("int64_nullable", nil, []*int64{p(int64(2)), nil}),
|
||||
|
||||
data.NewField("bool", nil, []bool{true, false}),
|
||||
data.NewField("bool_nullable", nil, []*bool{p(true), nil}),
|
||||
},
|
||||
}
|
||||
|
||||
db := DB{}
|
||||
qry := `SELECT * from a`
|
||||
|
||||
resultFrame, err := db.QueryFrames(context.Background(), "a", qry, []*data.Frame{frameA})
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(frameA, resultFrame, data.FrameTestCompareOptions()...); diff != "" {
|
||||
require.FailNowf(t, "Result mismatch (-want +got):%s\n", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryFramesNumericSelect(t *testing.T) {
|
||||
expectedFrame := &data.Frame{
|
||||
RefID: "a",
|
||||
Name: "a",
|
||||
Fields: []*data.Field{
|
||||
data.NewField("decimal", nil, []float64{2.35}),
|
||||
data.NewField("tinySigned", nil, []int8{-128}),
|
||||
data.NewField("smallSigned", nil, []int16{-32768}),
|
||||
data.NewField("mediumSigned", nil, []int32{-8388608}),
|
||||
data.NewField("intSigned", nil, []int32{-2147483648}),
|
||||
data.NewField("bigSigned", nil, []int64{-9223372036854775808}),
|
||||
data.NewField("tinyUnsigned", nil, []uint8{255}),
|
||||
data.NewField("smallUnsigned", nil, []uint16{65535}),
|
||||
data.NewField("mediumUnsigned", nil, []int32{16777215}),
|
||||
data.NewField("intUnsigned", nil, []uint32{4294967295}),
|
||||
data.NewField("bigUnsigned", nil, []uint64{18446744073709551615}),
|
||||
},
|
||||
}
|
||||
|
||||
db := DB{}
|
||||
qry := `SELECT 2.35 AS 'decimal',
|
||||
-128 AS 'tinySigned',
|
||||
-32768 AS 'smallSigned',
|
||||
-8388608 AS 'mediumSigned',
|
||||
-2147483648 AS 'intSigned',
|
||||
-9223372036854775808 AS 'bigSigned',
|
||||
255 AS 'tinyUnsigned',
|
||||
65535 AS 'smallUnsigned',
|
||||
16777215 AS 'mediumUnsigned',
|
||||
4294967295 AS 'intUnsigned',
|
||||
18446744073709551615 AS 'bigUnsigned'`
|
||||
|
||||
resultFrame, err := db.QueryFrames(context.Background(), "a", qry, []*data.Frame{})
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(expectedFrame, resultFrame, data.FrameTestCompareOptions()...); diff != "" {
|
||||
require.FailNowf(t, "Result mismatch (-want +got):%s\n", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryFramesDateTimeSelect(t *testing.T) {
|
||||
t.Skip("need a fix in go-mysql-server, and then handle the datetime strings (or figure out why strings and not time.Time)")
|
||||
expectedFrame := &data.Frame{
|
||||
RefID: "a",
|
||||
Name: "a",
|
||||
Fields: []*data.Field{
|
||||
data.NewField("ts", nil, []time.Time{}),
|
||||
},
|
||||
}
|
||||
|
||||
db := DB{}
|
||||
|
||||
// It doesn't like the T in the time string
|
||||
qry := `SELECT str_to_date('2025-02-03T03:00:00','%Y-%m-%dT%H:%i:%s') as ts`
|
||||
|
||||
// This comes back as a string, which needs to be dealt with?
|
||||
//qry := `SELECT str_to_date('2025-02-03-03:00:00','%Y-%m-%d-%H:%i:%s') as ts`
|
||||
|
||||
// This is a datetime(6), need to deal with that as well
|
||||
//qry := `SELECT current_timestamp() as ts`
|
||||
|
||||
f, err := db.QueryFrames(context.Background(), "b", qry, []*data.Frame{})
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(expectedFrame, f, data.FrameTestCompareOptions()...); diff != "" {
|
||||
require.FailNowf(t, "Result mismatch (-want +got):%s\n", diff)
|
||||
}
|
||||
}
|
||||
|
||||
// p is a utility for pointers from constants
|
||||
func p[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func setRefID(f *data.Frame, refID string) *data.Frame {
|
||||
f.RefID = refID
|
||||
return f
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
//go:build arm
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
)
|
||||
|
||||
type DB struct{}
|
||||
|
||||
// Stub out the QueryFrames method for ARM builds
|
||||
// See github.com/dolthub/go-mysql-server/issues/2837
|
||||
func (db *DB) QueryFrames(_ context.Context, _, _ string, _ []*data.Frame) (*data.Frame, error) {
|
||||
return nil, fmt.Errorf("sql expressions not supported in arm")
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
//go:build !arm
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
mysql "github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
)
|
||||
|
||||
var dbName = "frames"
|
||||
|
||||
// FramesDBProvider is a go-mysql-server DatabaseProvider that provides access to a set of Frames.
|
||||
type FramesDBProvider struct {
|
||||
db mysql.Database
|
||||
}
|
||||
|
||||
func (p *FramesDBProvider) Database(_ *mysql.Context, _ string) (mysql.Database, error) {
|
||||
return p.db, nil
|
||||
}
|
||||
|
||||
func (p *FramesDBProvider) HasDatabase(_ *mysql.Context, _ string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *FramesDBProvider) AllDatabases(_ *mysql.Context) []mysql.Database {
|
||||
return []mysql.Database{p.db}
|
||||
}
|
||||
|
||||
// NewFramesDBProvider creates a new FramesDBProvider with the given set of Frames.
|
||||
func NewFramesDBProvider(frames data.Frames) mysql.DatabaseProvider {
|
||||
fMap := make(map[string]mysql.Table, len(frames))
|
||||
for _, frame := range frames {
|
||||
fMap[frame.RefID] = &FrameTable{Frame: frame}
|
||||
}
|
||||
return &FramesDBProvider{
|
||||
db: &framesDB{
|
||||
frames: fMap,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// framesDB is a go-mysql-server Database that provides access to a set of Frames.
|
||||
type framesDB struct {
|
||||
frames map[string]mysql.Table
|
||||
}
|
||||
|
||||
func (db *framesDB) GetTableInsensitive(_ *mysql.Context, tblName string) (mysql.Table, bool, error) {
|
||||
tbl, ok := mysql.GetTableInsensitive(tblName, db.frames)
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
return tbl, ok, nil
|
||||
}
|
||||
|
||||
func (db *framesDB) GetTableNames(_ *mysql.Context) ([]string, error) {
|
||||
s := make([]string, 0, len(db.frames))
|
||||
for k := range db.frames {
|
||||
s = append(s, k)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (db *framesDB) Name() string {
|
||||
return dbName
|
||||
}
|
||||
@@ -0,0 +1,474 @@
|
||||
//go:build !arm
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
mysql "github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
// TODO: Should this accept a row limit and converters, like sqlutil.FrameFromRows?
|
||||
func convertToDataFrame(ctx *mysql.Context, iter mysql.RowIter, schema mysql.Schema) (*data.Frame, error) {
|
||||
f := &data.Frame{}
|
||||
// Create fields based on the schema
|
||||
for _, col := range schema {
|
||||
fT, err := MySQLColToFieldType(col)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
field := data.NewFieldFromFieldType(fT, 0)
|
||||
field.Name = col.Name
|
||||
f.Fields = append(f.Fields, field)
|
||||
}
|
||||
|
||||
// Iterate through the rows and append data to fields
|
||||
for {
|
||||
row, err := iter.Next(ctx)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading row: %v", err)
|
||||
}
|
||||
|
||||
for i, val := range row {
|
||||
v, err := fieldValFromRowVal(f.Fields[i].Type(), val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unexpected type for column %s: %w", schema[i].Name, err)
|
||||
}
|
||||
f.Fields[i].Append(v)
|
||||
}
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// MySQLColToFieldType converts a MySQL column to a data.FieldType
|
||||
func MySQLColToFieldType(col *mysql.Column) (data.FieldType, error) {
|
||||
var fT data.FieldType
|
||||
|
||||
switch col.Type {
|
||||
case types.Int8:
|
||||
fT = data.FieldTypeInt8
|
||||
case types.Uint8:
|
||||
fT = data.FieldTypeUint8
|
||||
case types.Int16:
|
||||
fT = data.FieldTypeInt16
|
||||
case types.Uint16:
|
||||
fT = data.FieldTypeUint16
|
||||
case types.Int32:
|
||||
fT = data.FieldTypeInt32
|
||||
case types.Uint32:
|
||||
fT = data.FieldTypeUint32
|
||||
case types.Int64:
|
||||
fT = data.FieldTypeInt64
|
||||
case types.Uint64:
|
||||
fT = data.FieldTypeUint64
|
||||
case types.Float64:
|
||||
fT = data.FieldTypeFloat64
|
||||
// StringType represents all string types, including VARCHAR and BLOB.
|
||||
case types.Text, types.LongText:
|
||||
fT = data.FieldTypeString
|
||||
case types.Timestamp:
|
||||
fT = data.FieldTypeTime
|
||||
case types.Datetime:
|
||||
fT = data.FieldTypeTime
|
||||
case types.Boolean:
|
||||
fT = data.FieldTypeBool
|
||||
default:
|
||||
if types.IsDecimal(col.Type) {
|
||||
fT = data.FieldTypeFloat64
|
||||
} else {
|
||||
return fT, fmt.Errorf("unsupported type for column %s of type %v", col.Name, col.Type)
|
||||
}
|
||||
}
|
||||
|
||||
if col.Nullable {
|
||||
fT = fT.NullableType()
|
||||
}
|
||||
|
||||
return fT, nil
|
||||
}
|
||||
|
||||
// Helper function to convert data.FieldType to types.Type
|
||||
func convertDataType(fieldType data.FieldType) mysql.Type {
|
||||
switch fieldType {
|
||||
case data.FieldTypeInt8, data.FieldTypeNullableInt8:
|
||||
return types.Int8
|
||||
case data.FieldTypeUint8, data.FieldTypeNullableUint8:
|
||||
return types.Uint8
|
||||
case data.FieldTypeInt16, data.FieldTypeNullableInt16:
|
||||
return types.Int16
|
||||
case data.FieldTypeUint16, data.FieldTypeNullableUint16:
|
||||
return types.Uint16
|
||||
case data.FieldTypeInt32, data.FieldTypeNullableInt32:
|
||||
return types.Int32
|
||||
case data.FieldTypeUint32, data.FieldTypeNullableUint32:
|
||||
return types.Uint32
|
||||
case data.FieldTypeInt64, data.FieldTypeNullableInt64:
|
||||
return types.Int64
|
||||
case data.FieldTypeUint64, data.FieldTypeNullableUint64:
|
||||
return types.Uint64
|
||||
case data.FieldTypeFloat32, data.FieldTypeNullableFloat32:
|
||||
return types.Float32
|
||||
case data.FieldTypeFloat64, data.FieldTypeNullableFloat64:
|
||||
return types.Float64
|
||||
case data.FieldTypeString, data.FieldTypeNullableString:
|
||||
return types.Text
|
||||
case data.FieldTypeBool, data.FieldTypeNullableBool:
|
||||
return types.Boolean
|
||||
case data.FieldTypeTime, data.FieldTypeNullableTime:
|
||||
return types.Timestamp
|
||||
default:
|
||||
fmt.Printf("------- Unsupported field type: %v", fieldType)
|
||||
return types.JSON
|
||||
}
|
||||
}
|
||||
|
||||
// fieldValFromRowVal converts a go-mysql-server row value to a data.field value
|
||||
//
|
||||
//nolint:gocyclo
|
||||
func fieldValFromRowVal(fieldType data.FieldType, val interface{}) (interface{}, error) {
|
||||
// the input val may be nil, it also may not be a pointer even if the fieldtype is a nullable pointer type
|
||||
if val == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch fieldType {
|
||||
// ----------------------------
|
||||
// Int8 / Nullable Int8
|
||||
// ----------------------------
|
||||
case data.FieldTypeInt8:
|
||||
v, ok := val.(int8)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int8", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableInt8:
|
||||
vP, ok := val.(*int8)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(int8)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int8 or *int8", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Uint8 / Nullable Uint8
|
||||
// ----------------------------
|
||||
case data.FieldTypeUint8:
|
||||
v, ok := val.(uint8)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint8", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableUint8:
|
||||
vP, ok := val.(*uint8)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(uint8)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint8 or *uint8", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Int16 / Nullable Int16
|
||||
// ----------------------------
|
||||
case data.FieldTypeInt16:
|
||||
v, ok := val.(int16)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int16", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableInt16:
|
||||
vP, ok := val.(*int16)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(int16)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int16 or *int16", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Uint16 / Nullable Uint16
|
||||
// ----------------------------
|
||||
case data.FieldTypeUint16:
|
||||
v, ok := val.(uint16)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint16", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableUint16:
|
||||
vP, ok := val.(*uint16)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(uint16)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint16 or *uint16", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Int32 / Nullable Int32
|
||||
// ----------------------------
|
||||
case data.FieldTypeInt32:
|
||||
v, ok := val.(int32)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int32", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableInt32:
|
||||
vP, ok := val.(*int32)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(int32)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int32 or *int32", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Uint32 / Nullable Uint32
|
||||
// ----------------------------
|
||||
case data.FieldTypeUint32:
|
||||
v, ok := val.(uint32)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint32", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableUint32:
|
||||
vP, ok := val.(*uint32)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(uint32)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint32 or *uint32", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Int64 / Nullable Int64
|
||||
// ----------------------------
|
||||
case data.FieldTypeInt64:
|
||||
v, ok := val.(int64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int64", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableInt64:
|
||||
vP, ok := val.(*int64)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(int64)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected int64 or *int64", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Uint64 / Nullable Uint64
|
||||
// ----------------------------
|
||||
case data.FieldTypeUint64:
|
||||
v, ok := val.(uint64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint64", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableUint64:
|
||||
vP, ok := val.(*uint64)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(uint64)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected uint64 or *uint64", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Float64 / Nullable Float64
|
||||
// ----------------------------
|
||||
case data.FieldTypeFloat64:
|
||||
// Accept float64 or decimal.Decimal, convert decimal.Decimal -> float64
|
||||
if v, ok := val.(float64); ok {
|
||||
return v, nil
|
||||
}
|
||||
if d, ok := val.(decimal.Decimal); ok {
|
||||
return d.InexactFloat64(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected float64 or decimal.Decimal", val, val)
|
||||
|
||||
case data.FieldTypeNullableFloat64:
|
||||
// Possibly already *float64
|
||||
if vP, ok := val.(*float64); ok {
|
||||
return vP, nil
|
||||
}
|
||||
// Possibly float64
|
||||
if v, ok := val.(float64); ok {
|
||||
return &v, nil
|
||||
}
|
||||
// Possibly decimal.Decimal
|
||||
if d, ok := val.(decimal.Decimal); ok {
|
||||
f := d.InexactFloat64()
|
||||
return &f, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected float64, *float64, or decimal.Decimal", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Time / Nullable Time
|
||||
// ----------------------------
|
||||
case data.FieldTypeTime:
|
||||
v, ok := val.(time.Time)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected time.Time", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableTime:
|
||||
vP, ok := val.(*time.Time)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(time.Time)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected time.Time or *time.Time", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// String / Nullable String
|
||||
// ----------------------------
|
||||
case data.FieldTypeString:
|
||||
v, ok := val.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected string", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableString:
|
||||
vP, ok := val.(*string)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(string)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected string or *string", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Bool / Nullable Bool
|
||||
// ----------------------------
|
||||
case data.FieldTypeBool:
|
||||
v, ok := val.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected bool", val, val)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case data.FieldTypeNullableBool:
|
||||
vP, ok := val.(*bool)
|
||||
if ok {
|
||||
return vP, nil
|
||||
}
|
||||
v, ok := val.(bool)
|
||||
if ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected value type for interface %v of type %T, expected bool or *bool", val, val)
|
||||
|
||||
// ----------------------------
|
||||
// Fallback / Unsupported
|
||||
// ----------------------------
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported field type %s for val %v", fieldType, val)
|
||||
}
|
||||
}
|
||||
|
||||
// Is the field nilAt the index. Can panic if out of range.
|
||||
// TODO: Maybe this should be a method on data.Field?
|
||||
func nilAt(field data.Field, at int) bool {
|
||||
if !field.Nullable() {
|
||||
return false
|
||||
}
|
||||
|
||||
switch field.Type() {
|
||||
case data.FieldTypeNullableInt8:
|
||||
v := field.At(at).(*int8)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableUint8:
|
||||
v := field.At(at).(*uint8)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableInt16:
|
||||
v := field.At(at).(*int16)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableUint16:
|
||||
v := field.At(at).(*uint16)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableInt32:
|
||||
v := field.At(at).(*int32)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableUint32:
|
||||
v := field.At(at).(*uint32)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableInt64:
|
||||
v := field.At(at).(*int64)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableUint64:
|
||||
v := field.At(at).(*uint64)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableFloat64:
|
||||
v := field.At(at).(*float64)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableString:
|
||||
v := field.At(at).(*string)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableTime:
|
||||
v := field.At(at).(*time.Time)
|
||||
return v == nil
|
||||
|
||||
case data.FieldTypeNullableBool:
|
||||
v := field.At(at).(*bool)
|
||||
return v == nil
|
||||
|
||||
default:
|
||||
// Either it's not a nullable type or it's unsupported
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
//go:build !arm
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
mysql "github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/grafana/grafana-plugin-sdk-go/data"
|
||||
)
|
||||
|
||||
// FrameTable fulfills the mysql.Table interface for a data.Frame.
|
||||
type FrameTable struct {
|
||||
Frame *data.Frame
|
||||
schema mysql.Schema
|
||||
}
|
||||
|
||||
// Name implements the sql.Nameable interface
|
||||
func (ft *FrameTable) Name() string {
|
||||
return ft.Frame.RefID
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer interface
|
||||
func (ft *FrameTable) String() string {
|
||||
return ft.Name()
|
||||
}
|
||||
|
||||
func schemaFromFrame(frame *data.Frame) mysql.Schema {
|
||||
schema := make(mysql.Schema, len(frame.Fields))
|
||||
|
||||
for i, field := range frame.Fields {
|
||||
schema[i] = &mysql.Column{
|
||||
Name: field.Name,
|
||||
Type: convertDataType(field.Type()),
|
||||
Nullable: field.Type().Nullable(),
|
||||
Source: strings.ToLower(frame.RefID),
|
||||
}
|
||||
}
|
||||
|
||||
return schema
|
||||
}
|
||||
|
||||
// Schema implements the mysql.Table interface
|
||||
func (ft *FrameTable) Schema() mysql.Schema {
|
||||
if ft.schema == nil {
|
||||
ft.schema = schemaFromFrame(ft.Frame)
|
||||
}
|
||||
return ft.schema
|
||||
}
|
||||
|
||||
// Collation implements the mysql.Table interface
|
||||
func (ft *FrameTable) Collation() mysql.CollationID {
|
||||
return mysql.Collation_Unspecified
|
||||
}
|
||||
|
||||
// Partitions implements the mysql.Table interface
|
||||
func (ft *FrameTable) Partitions(ctx *mysql.Context) (mysql.PartitionIter, error) {
|
||||
return &noopPartitionIter{}, nil
|
||||
}
|
||||
|
||||
// PartitionRows implements the mysql.Table interface
|
||||
func (ft *FrameTable) PartitionRows(ctx *mysql.Context, _ mysql.Partition) (mysql.RowIter, error) {
|
||||
return &rowIter{ft: ft, row: 0}, nil
|
||||
}
|
||||
|
||||
type rowIter struct {
|
||||
ft *FrameTable
|
||||
row int
|
||||
}
|
||||
|
||||
func (ri *rowIter) Next(_ *mysql.Context) (mysql.Row, error) {
|
||||
// We assume each field in the Frame has the same number of rows.
|
||||
numRows := 0
|
||||
if len(ri.ft.Frame.Fields) > 0 {
|
||||
numRows = ri.ft.Frame.Fields[0].Len()
|
||||
}
|
||||
|
||||
// If we've already exhausted all rows, return EOF
|
||||
if ri.row >= numRows {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
// Construct a Row (which is []interface{} under the hood) by pulling
|
||||
// the value from each column at the current row index.
|
||||
row := make(mysql.Row, len(ri.ft.Frame.Fields))
|
||||
for colIndex, field := range ri.ft.Frame.Fields {
|
||||
if nilAt(*field, ri.row) {
|
||||
continue
|
||||
}
|
||||
row[colIndex], _ = field.ConcreteAt(ri.row)
|
||||
}
|
||||
|
||||
ri.row++
|
||||
return row, nil
|
||||
}
|
||||
|
||||
// Close implements the mysql.RowIter interface.
|
||||
// In this no-op example, there isn't anything to do here.
|
||||
func (ri *rowIter) Close(*mysql.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type noopPartitionIter struct {
|
||||
done bool
|
||||
}
|
||||
|
||||
func (i *noopPartitionIter) Next(*mysql.Context) (mysql.Partition, error) {
|
||||
if !i.done {
|
||||
i.done = true
|
||||
return noopParition, nil
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
func (i *noopPartitionIter) Close(*mysql.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var noopParition = partition(nil)
|
||||
|
||||
type partition []byte
|
||||
|
||||
func (p partition) Key() []byte {
|
||||
return p
|
||||
}
|
||||
+37
-64
@@ -1,89 +1,62 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/vitess/go/vt/sqlparser"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/jeremywohl/flatten"
|
||||
)
|
||||
|
||||
const (
|
||||
TABLE_NAME = "table_name"
|
||||
ERROR = ".error"
|
||||
ERROR_MESSAGE = ".error_message"
|
||||
)
|
||||
|
||||
var logger = log.New("sql_expr")
|
||||
|
||||
// TablesList returns a list of tables for the sql statement
|
||||
func TablesList(rawSQL string) ([]string, error) {
|
||||
db := NewInMemoryDB()
|
||||
rawSQL = strings.Replace(rawSQL, "'", "''", -1)
|
||||
cmd := fmt.Sprintf("SELECT json_serialize_sql('%s')", rawSQL)
|
||||
ret, err := db.RunCommands([]string{cmd})
|
||||
stmt, err := sqlparser.Parse(rawSQL)
|
||||
if err != nil {
|
||||
logger.Error("error serializing sql", "error", err.Error(), "sql", rawSQL, "cmd", cmd)
|
||||
return nil, fmt.Errorf("error serializing sql: %s", err.Error())
|
||||
logger.Error("error parsing sql: %s", err.Error(), "sql", rawSQL)
|
||||
return nil, fmt.Errorf("error parsing sql: %s", err.Error())
|
||||
}
|
||||
|
||||
ast := []map[string]any{}
|
||||
err = json.Unmarshal([]byte(ret), &ast)
|
||||
if err != nil {
|
||||
logger.Error("error converting json sql to ast", "error", err.Error(), "ret", ret)
|
||||
return nil, fmt.Errorf("error converting json to ast: %s", err.Error())
|
||||
}
|
||||
tables := make(map[string]struct{})
|
||||
|
||||
return tablesFromAST(ast)
|
||||
}
|
||||
|
||||
// tablesFromAST returns a list of tables from the ast
|
||||
func tablesFromAST(ast []map[string]any) ([]string, error) {
|
||||
flat, err := flatten.Flatten(ast[0], "", flatten.DotStyle)
|
||||
if err != nil {
|
||||
logger.Error("error flattening ast", "error", err.Error(), "ast", ast)
|
||||
return nil, fmt.Errorf("error flattening ast: %s", err.Error())
|
||||
}
|
||||
|
||||
tables := []string{}
|
||||
for k, v := range flat {
|
||||
if strings.HasSuffix(k, ERROR) {
|
||||
v, ok := v.(bool)
|
||||
if ok && v {
|
||||
logger.Error("error in sql", "error", k)
|
||||
return nil, astError(k, flat)
|
||||
walkSubtree := func(node sqlparser.SQLNode) error {
|
||||
err = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
|
||||
switch v := node.(type) {
|
||||
case *sqlparser.AliasedTableExpr:
|
||||
if tableName, ok := v.Expr.(sqlparser.TableName); ok {
|
||||
tables[tableName.Name.String()] = struct{}{}
|
||||
}
|
||||
case *sqlparser.TableName:
|
||||
tables[v.Name.String()] = struct{}{}
|
||||
}
|
||||
return true, nil
|
||||
}, node)
|
||||
|
||||
if err != nil {
|
||||
logger.Error("error walking sql", "error", err, "node", node)
|
||||
return fmt.Errorf("failed to parse SQL expression: %w", err)
|
||||
}
|
||||
if strings.Contains(k, TABLE_NAME) {
|
||||
table, ok := v.(string)
|
||||
if ok && !existsInList(table, tables) {
|
||||
tables = append(tables, v.(string))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := walkSubtree(stmt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(tables))
|
||||
for table := range tables {
|
||||
// Remove 'dual' table if it exists
|
||||
// This is a special table in MySQL that always returns a single row with a single column
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/select.html#:~:text=You%20are%20permitted%20to%20specify%20DUAL%20as%20a%20dummy%20table%20name%20in%20situations%20where%20no%20tables%20are%20referenced
|
||||
if table != "dual" {
|
||||
result = append(result, table)
|
||||
}
|
||||
}
|
||||
sort.Strings(tables)
|
||||
|
||||
sort.Strings(result)
|
||||
|
||||
logger.Debug("tables found in sql", "tables", tables)
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func astError(k string, flat map[string]any) error {
|
||||
key := strings.Replace(k, ERROR, "", 1)
|
||||
message, ok := flat[key+ERROR_MESSAGE]
|
||||
if !ok {
|
||||
message = "unknown error in sql"
|
||||
}
|
||||
return fmt.Errorf("error in sql: %s", message)
|
||||
}
|
||||
|
||||
func existsInList(table string, list []string) bool {
|
||||
for _, t := range list {
|
||||
if t == table {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/vitess/go/vt/sqlparser"
|
||||
)
|
||||
|
||||
// AllowQuery parses the query and checks it against an allow list of allowed SQL nodes
|
||||
// and functions.
|
||||
func AllowQuery(rawSQL string) (bool, error) {
|
||||
s, err := sqlparser.Parse(rawSQL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error parsing sql: %s", err.Error())
|
||||
}
|
||||
|
||||
walkSubtree := func(node sqlparser.SQLNode) error {
|
||||
err := sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
|
||||
if !allowedNode(node) {
|
||||
if fT, ok := node.(*sqlparser.FuncExpr); ok {
|
||||
return false, fmt.Errorf("blocked function %s - not supported in queries", fT.Name)
|
||||
}
|
||||
return false, fmt.Errorf("blocked node %T - not supported in queries", node)
|
||||
}
|
||||
return true, nil
|
||||
}, node)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse SQL expression: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := walkSubtree(s); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// nolint:gocyclo,nakedret
|
||||
func allowedNode(node sqlparser.SQLNode) (b bool) {
|
||||
b = true // so don't have to return true in every case but default
|
||||
|
||||
switch v := node.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
return allowedFunction(v)
|
||||
|
||||
case *sqlparser.AsOf:
|
||||
return
|
||||
|
||||
case *sqlparser.AliasedExpr, *sqlparser.AliasedTableExpr:
|
||||
return
|
||||
|
||||
case *sqlparser.BinaryExpr:
|
||||
return
|
||||
|
||||
case sqlparser.ColIdent, *sqlparser.ColName, sqlparser.Columns:
|
||||
return
|
||||
|
||||
case sqlparser.Comments: // TODO: understand why some are pointer vs not
|
||||
return
|
||||
|
||||
case *sqlparser.CommonTableExpr:
|
||||
return
|
||||
|
||||
case *sqlparser.ComparisonExpr:
|
||||
return
|
||||
|
||||
case *sqlparser.ConvertExpr:
|
||||
return
|
||||
|
||||
case sqlparser.GroupBy:
|
||||
return
|
||||
|
||||
case *sqlparser.IndexHints:
|
||||
return
|
||||
|
||||
case *sqlparser.Into:
|
||||
return
|
||||
|
||||
case *sqlparser.JoinTableExpr, sqlparser.JoinCondition:
|
||||
return
|
||||
|
||||
case *sqlparser.Select, sqlparser.SelectExprs:
|
||||
return
|
||||
|
||||
case *sqlparser.StarExpr:
|
||||
return
|
||||
|
||||
case *sqlparser.SQLVal:
|
||||
return
|
||||
|
||||
case *sqlparser.Limit:
|
||||
return
|
||||
|
||||
case *sqlparser.Order, sqlparser.OrderBy:
|
||||
return
|
||||
|
||||
case *sqlparser.Over:
|
||||
return
|
||||
|
||||
case *sqlparser.Subquery:
|
||||
return
|
||||
|
||||
case sqlparser.TableName, sqlparser.TableExprs, sqlparser.TableIdent:
|
||||
return
|
||||
|
||||
case *sqlparser.With:
|
||||
return
|
||||
|
||||
case *sqlparser.Where:
|
||||
return
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// nolint:gocyclo,nakedret
|
||||
func allowedFunction(f *sqlparser.FuncExpr) (b bool) {
|
||||
b = true // so don't have to return true in every case but default
|
||||
|
||||
switch strings.ToLower(f.Name.String()) {
|
||||
case "sum", "avg", "count", "min", "max":
|
||||
return
|
||||
|
||||
case "coalesce":
|
||||
return
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAllowQuery(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
q string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "a big catch all for now",
|
||||
q: example_metrics_query,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := AllowQuery(tc.q)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var example_metrics_query = `WITH
|
||||
metrics_this_month AS (
|
||||
SELECT
|
||||
Month,
|
||||
namespace,
|
||||
sum(BillableSeries) AS billable_series
|
||||
FROM metrics
|
||||
WHERE
|
||||
Month = "2024-11"
|
||||
GROUP BY
|
||||
Month,
|
||||
namespace
|
||||
ORDER BY billable_series DESC
|
||||
),
|
||||
total_metrics AS (
|
||||
SELECT SUM(billable_series) AS metrics_billable_series_total
|
||||
FROM metrics_this_month
|
||||
),
|
||||
total_traces AS (
|
||||
-- "usage" is a reserved keyword in MySQL. Quote it with backticks.
|
||||
SELECT SUM(value) AS traces_usage_total
|
||||
FROM traces
|
||||
),
|
||||
usage_by_team AS (
|
||||
SELECT
|
||||
COALESCE(teams.team, 'unaccounted') AS team,
|
||||
1 + 0 AS team_count,
|
||||
-- Metrics
|
||||
SUM(COALESCE(metrics_this_month.billable_series, 0)) AS metrics_billable_series,
|
||||
-- Traces
|
||||
SUM(COALESCE(traces.value, 0)) AS traces_usage
|
||||
-- FROM teams
|
||||
-- FULL OUTER JOIN metrics_this_month
|
||||
FROM metrics_this_month
|
||||
FULL OUTER JOIN teams
|
||||
ON teams.namespace = metrics_this_month.namespace
|
||||
FULL OUTER JOIN traces
|
||||
ON teams.namespace = traces.namespace
|
||||
GROUP BY
|
||||
-- COALESCE(teams.team, 'unaccounted')
|
||||
teams.team
|
||||
ORDER BY metrics_billable_series DESC
|
||||
)
|
||||
|
||||
SELECT *
|
||||
FROM usage_by_team
|
||||
CROSS JOIN total_metrics
|
||||
CROSS JOIN total_traces`
|
||||
+124
-207
@@ -3,214 +3,131 @@ package sql
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := "select * from foo"
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
func TestTablesList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
expected []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "simple select",
|
||||
sql: "select * from foo",
|
||||
expected: []string{"foo"},
|
||||
},
|
||||
{
|
||||
name: "select with comma",
|
||||
sql: "select * from foo,bar",
|
||||
expected: []string{"bar", "foo"},
|
||||
},
|
||||
{
|
||||
name: "select with multiple commas",
|
||||
sql: "select * from foo,bar,baz",
|
||||
expected: []string{"bar", "baz", "foo"},
|
||||
},
|
||||
{
|
||||
name: "no table",
|
||||
sql: "select 1 as 'n'",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "json array",
|
||||
sql: "SELECT JSON_ARRAY(1, 2, 3) AS array_value",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "json extract",
|
||||
sql: "SELECT JSON_EXTRACT(JSON_ARRAY(1, 2, 3), '$[0]') AS first_element;",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "json int array",
|
||||
sql: "SELECT JSON_ARRAY(3, 2, 1) AS int_array;",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "subquery",
|
||||
sql: "select * from (select * from people limit 1) AS subquery",
|
||||
expected: []string{"people"},
|
||||
},
|
||||
{
|
||||
name: "join",
|
||||
sql: `select * from A
|
||||
JOIN B ON A.name = B.name
|
||||
LIMIT 10`,
|
||||
expected: []string{"A", "B"},
|
||||
},
|
||||
{
|
||||
name: "right join",
|
||||
sql: `select * from A
|
||||
RIGHT JOIN B ON A.name = B.name
|
||||
LIMIT 10`,
|
||||
expected: []string{"A", "B"},
|
||||
},
|
||||
{
|
||||
name: "alias with join",
|
||||
sql: `select * from A as X
|
||||
RIGHT JOIN B ON A.name = X.name
|
||||
LIMIT 10`,
|
||||
expected: []string{"A", "B"},
|
||||
},
|
||||
{
|
||||
name: "alias",
|
||||
sql: "select * from A as X LIMIT 10",
|
||||
expected: []string{"A"},
|
||||
},
|
||||
{
|
||||
name: "error case",
|
||||
sql: "select * from zzz aaa zzz",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "parens",
|
||||
sql: `SELECT t1.Col1,
|
||||
t2.Col1,
|
||||
t3.Col1
|
||||
FROM table1 AS t1
|
||||
LEFT JOIN (
|
||||
table2 AS t2
|
||||
INNER JOIN table3 AS t3 ON t3.Col1 = t2.Col1
|
||||
) ON t2.Col1 = t1.Col1;`,
|
||||
expected: []string{"table1", "table2", "table3"},
|
||||
},
|
||||
{
|
||||
name: "with clause",
|
||||
sql: `WITH top_products AS (
|
||||
SELECT * FROM products
|
||||
ORDER BY price DESC
|
||||
LIMIT 5
|
||||
)
|
||||
SELECT name, price
|
||||
FROM top_products;`,
|
||||
expected: []string{"products", "top_products"},
|
||||
},
|
||||
{
|
||||
name: "with quote",
|
||||
sql: "select *,'junk' from foo",
|
||||
expected: []string{"foo"},
|
||||
},
|
||||
{
|
||||
name: "with quote 2",
|
||||
sql: "SELECT json_serialize_sql('SELECT 1')",
|
||||
expected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, "foo", tables[0])
|
||||
}
|
||||
|
||||
func TestParseWithComma(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := "select * from foo,bar"
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, "bar", tables[0])
|
||||
assert.Equal(t, "foo", tables[1])
|
||||
}
|
||||
|
||||
func TestParseWithCommas(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := "select * from foo,bar,baz"
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, "bar", tables[0])
|
||||
assert.Equal(t, "baz", tables[1])
|
||||
assert.Equal(t, "foo", tables[2])
|
||||
}
|
||||
|
||||
func TestArray(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := "SELECT array_value(1, 2, 3)"
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 0, len(tables))
|
||||
}
|
||||
|
||||
func TestArray2(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := "SELECT array_value(1, 2, 3)[2]"
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 0, len(tables))
|
||||
}
|
||||
|
||||
func TestXxx(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := "SELECT [3, 2, 1]::INT[3];"
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 0, len(tables))
|
||||
}
|
||||
|
||||
func TestParseSubquery(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := "select * from (select * from people limit 1)"
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 1, len(tables))
|
||||
assert.Equal(t, "people", tables[0])
|
||||
}
|
||||
|
||||
func TestJoin(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := `select * from A
|
||||
JOIN B ON A.name = B.name
|
||||
LIMIT 10`
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 2, len(tables))
|
||||
assert.Equal(t, "A", tables[0])
|
||||
assert.Equal(t, "B", tables[1])
|
||||
}
|
||||
|
||||
func TestRightJoin(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := `select * from A
|
||||
RIGHT JOIN B ON A.name = B.name
|
||||
LIMIT 10`
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 2, len(tables))
|
||||
assert.Equal(t, "A", tables[0])
|
||||
assert.Equal(t, "B", tables[1])
|
||||
}
|
||||
|
||||
func TestAliasWithJoin(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := `select * from A as X
|
||||
RIGHT JOIN B ON A.name = X.name
|
||||
LIMIT 10`
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 2, len(tables))
|
||||
assert.Equal(t, "A", tables[0])
|
||||
assert.Equal(t, "B", tables[1])
|
||||
}
|
||||
|
||||
func TestAlias(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := `select * from A as X LIMIT 10`
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 1, len(tables))
|
||||
assert.Equal(t, "A", tables[0])
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := `select * from zzz aaa zzz`
|
||||
_, err := TablesList((sql))
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestParens(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := `SELECT t1.Col1,
|
||||
t2.Col1,
|
||||
t3.Col1
|
||||
FROM table1 AS t1
|
||||
LEFT JOIN (
|
||||
table2 AS t2
|
||||
INNER JOIN table3 AS t3 ON t3.Col1 = t2.Col1
|
||||
) ON t2.Col1 = t1.Col1;`
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 3, len(tables))
|
||||
assert.Equal(t, "table1", tables[0])
|
||||
assert.Equal(t, "table2", tables[1])
|
||||
assert.Equal(t, "table3", tables[2])
|
||||
}
|
||||
|
||||
func TestWith(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := `WITH
|
||||
|
||||
current_month AS (
|
||||
select
|
||||
distinct "Month(ISO)" as mth
|
||||
from A
|
||||
ORDER BY mth DESC
|
||||
LIMIT 1
|
||||
),
|
||||
|
||||
last_month_bill AS (
|
||||
select
|
||||
CAST (
|
||||
sum(
|
||||
CAST(BillableSeries AS INTEGER)
|
||||
) AS INTEGER
|
||||
) AS BillableSeries,
|
||||
"Month(ISO)",
|
||||
label_namespace
|
||||
-- , B.activeseries_count
|
||||
from A
|
||||
JOIN current_month
|
||||
ON current_month.mth = A."Month(ISO)"
|
||||
JOIN B
|
||||
ON B.namespace = A.label_namespace
|
||||
GROUP BY
|
||||
label_namespace,
|
||||
"Month(ISO)"
|
||||
ORDER BY BillableSeries DESC
|
||||
)
|
||||
|
||||
SELECT
|
||||
last_month_bill.*,
|
||||
BEE.activeseries_count
|
||||
FROM last_month_bill
|
||||
JOIN BEE
|
||||
ON BEE.namespace = last_month_bill.label_namespace`
|
||||
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 5, len(tables))
|
||||
assert.Equal(t, "A", tables[0])
|
||||
assert.Equal(t, "B", tables[1])
|
||||
assert.Equal(t, "BEE", tables[2])
|
||||
}
|
||||
|
||||
func TestWithQuote(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := "select *,'junk' from foo"
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, "foo", tables[0])
|
||||
}
|
||||
|
||||
func TestWithQuote2(t *testing.T) {
|
||||
t.Skip()
|
||||
sql := "SELECT json_serialize_sql('SELECT 1')"
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 0, len(tables))
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tables, err := TablesList(tc.sql)
|
||||
if tc.expectError {
|
||||
require.NotNil(t, err, "expected error for SQL: %s", tc.sql)
|
||||
} else {
|
||||
require.Nil(t, err, "unexpected error for SQL: %s", tc.sql)
|
||||
require.Equal(t, tc.expected, tables, "mismatched tables for SQL: %s", tc.sql)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user