Files
grafana/pkg/services/store/entity/sqlstash/querybuilder.go
Dan Cech 9c46e71d5a Storage: Support continue at specified resource version (#84868)
* support continue at specified resource version

* detect whether list continue pages need to use entity_history, remove BatchRead, expand selectQuery helper

* refactor continue token handling

* fix tests, increase history chunk size

* lint fix
2024-04-08 09:57:02 -04:00

222 lines
4.6 KiB
Go

package sqlstash
import (
"encoding/json"
"strings"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
)
// Direction is the sort order applied to a single ORDER BY column.
type Direction int

const (
	// Ascending renders as "ASC"; it is the zero value of Direction.
	Ascending Direction = 0
	// Descending renders as "DESC".
	Descending Direction = 1
)

// String returns the SQL keyword for d. Every value other than Descending,
// including out-of-range values, renders as "ASC".
func (d Direction) String() string {
	switch d {
	case Descending:
		return "DESC"
	default:
		return "ASC"
	}
}
// joinQuery is a raw JOIN clause. ToQuery writes query verbatim after
// "FROM <table> AS t" and appends args before any WHERE arguments.
type joinQuery struct {
	query string
	args  []any
}

// whereClause is a single predicate with "?" placeholders. ToQuery joins all
// clauses with " AND " under a single WHERE.
// NOTE: AddWhere builds this with a positional literal, so field order matters.
type whereClause struct {
	query string
	args  []any
}

// selectQuery accumulates the parts of a SELECT statement and renders them
// via ToQuery. The base table is always aliased as "t".
type selectQuery struct {
	dialect migrator.Dialect // used for identifier quoting and LIMIT/OFFSET rendering
	fields  []string         // SELECT list, rendered comma-joined (AddFields entries are already "t."-prefixed and quoted)
	from    string           // FROM table, rendered as "<from> AS t"
	joins   []joinQuery      // JOIN clauses, rendered in insertion order
	offset  int64            // only emitted when limit > 0
	limit   int64            // 0 (or negative) means no LIMIT/OFFSET is rendered
	// oneExtra makes ToQuery request limit+1 rows; presumably so callers can
	// detect whether another page exists — TODO confirm with callers.
	oneExtra bool
	where    []whereClause // predicates, ANDed together
	groupBy  []string      // raw column names; quoted and "t."-prefixed at render time
	orderBy  []string      // raw column names; quoted and "t."-prefixed at render time
	// direction is parallel to orderBy: direction[i] applies to orderBy[i].
	direction []Direction
}
// NewSelectQuery returns an empty query builder that selects from the given
// table (aliased as "t") and uses dialect for quoting and LIMIT rendering.
func NewSelectQuery(dialect migrator.Dialect, from string) *selectQuery {
	q := new(selectQuery)
	q.dialect = dialect
	q.from = from
	return q
}
// From replaces the table the query selects from; it is still aliased as "t".
func (q *selectQuery) From(from string) { q.from = from }

// SetLimit sets the maximum number of rows. A value <= 0 disables both
// LIMIT and OFFSET in the rendered query.
func (q *selectQuery) SetLimit(limit int64) { q.limit = limit }

// SetOffset sets the number of rows to skip. It only takes effect when a
// positive limit is also set.
func (q *selectQuery) SetOffset(offset int64) { q.offset = offset }

// SetOneExtra makes ToQuery ask for one row more than the configured limit.
func (q *selectQuery) SetOneExtra() { q.oneExtra = true }

// UnsetOneExtra reverts SetOneExtra so ToQuery uses the limit as-is.
func (q *selectQuery) UnsetOneExtra() { q.oneExtra = false }
// AddFields appends columns of the base table to the SELECT list. Each name
// is quoted by the dialect and qualified with the "t." alias.
func (q *selectQuery) AddFields(f ...string) {
	for i := range f {
		column := q.dialect.Quote(f[i])
		q.fields = append(q.fields, "t."+column)
	}
}
// AddRawFields appends SELECT-list expressions verbatim. They are neither
// quoted nor alias-qualified, so callers are responsible for both.
func (q *selectQuery) AddRawFields(f ...string) {
	q.fields = append(q.fields, f...)
}
// AddJoin appends a raw JOIN clause (e.g. "LEFT JOIN x ON ..."). Its
// placeholder arguments are bound before any WHERE arguments.
func (q *selectQuery) AddJoin(j string, args ...any) {
	join := joinQuery{args: args, query: j}
	q.joins = append(q.joins, join)
}
// AddWhere adds a predicate to the WHERE clause.
//
// If f contains a "?" it is treated as a raw predicate and used verbatim with
// val as its arguments. Otherwise f is a base-table column name and the
// predicate becomes `t.<quoted f>=?`.
func (q *selectQuery) AddWhere(f string, val ...any) {
	clause := f
	if !strings.Contains(f, "?") {
		// Bare column name: build an equality test against the base table.
		clause = "t." + q.dialect.Quote(f) + "=?"
	}
	q.where = append(q.where, whereClause{clause, val})
}
// AddWhereInSubquery restricts base-table column f to the rows produced by
// subquery, rendered as `t.<quoted f> IN (<subquery>)`. subqueryArgs are
// bound in place of the subquery's placeholders.
func (q *selectQuery) AddWhereInSubquery(f string, subquery string, subqueryArgs []any) {
	column := "t." + q.dialect.Quote(f)
	clause := whereClause{query: column + " IN (" + subquery + ")", args: subqueryArgs}
	q.where = append(q.where, clause)
}
// AddWhereIn restricts base-table column f to the given values.
//
// An empty vals adds no predicate at all (the query is NOT restricted to
// zero rows). A single value is delegated to AddWhere as an equality test.
// Two or more values render as `t.<quoted f> IN (?,?,...) `.
func (q *selectQuery) AddWhereIn(f string, vals []any) {
	switch len(vals) {
	case 0:
		return
	case 1:
		q.AddWhere(f, vals[0])
		return
	}

	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(vals)), ",")
	clause := "t." + q.dialect.Quote(f) + " IN (" + placeholders + ") "
	q.where = append(q.where, whereClause{clause, vals})
}
// ToAnyList converts a typed slice into a []any, e.g. for passing to
// AddWhereIn. The result is always non-nil, even for a nil input.
func ToAnyList[T any](input []T) []any {
	out := make([]any, 0, len(input))
	for _, item := range input {
		out = append(out, item)
	}
	return out
}
// sqlLikeEscape is the escape character passed to `LIKE ? ESCAPE ?`.
const sqlLikeEscape = "#"

// sqlLikeEscapeReplacer prefixes the LIKE wildcards ("%", "_") and the escape
// character itself with sqlLikeEscape so they match literally.
var sqlLikeEscapeReplacer = func() *strings.Replacer {
	var pairs []string
	for _, special := range []string{sqlLikeEscape, "%", "_"} {
		pairs = append(pairs, special, sqlLikeEscape+special)
	}
	return strings.NewReplacer(pairs...)
}()

// escapeJSONStringSQLLike returns s encoded as a JSON string literal
// (including its surrounding double quotes), with LIKE metacharacters
// escaped using sqlLikeEscape.
func escapeJSONStringSQLLike(s string) string {
	// json.Marshal cannot fail for a plain string value (invalid UTF-8 is
	// coerced to U+FFFD), so discarding the error is safe.
	quoted, _ := json.Marshal(s)
	return sqlLikeEscapeReplacer.Replace(string(quoted))
}
// AddWhereJsonContainsKV adds a predicate matching rows whose base-table
// column field textually contains the JSON object member "key":"value".
//
// The match is a LIKE against the serialized column, shaped as
// `{%"key":"value"%}`, with LIKE metacharacters in key and value escaped via
// sqlLikeEscape.
// NOTE(review): this assumes the column holds compact JSON (no whitespace
// around ':') encoded the same way json.Marshal encodes strings — confirm
// against the writers of this column.
func (q *selectQuery) AddWhereJsonContainsKV(field string, key string, value string) {
	// escapeJSONStringSQLLike returns the JSON string literal *including* its
	// surrounding double quotes (it is json.Marshal output), so no extra
	// quotes may be added here. Adding them would yield `""key"":""value""`,
	// which never appears in valid JSON and so never matches.
	escapedKey := escapeJSONStringSQLLike(key)
	escapedValue := escapeJSONStringSQLLike(value)
	pattern := "{%" + escapedKey + ":" + escapedValue + "%}"
	q.where = append(q.where, whereClause{
		"t." + q.dialect.Quote(field) + " LIKE ? ESCAPE ?",
		[]any{pattern, sqlLikeEscape},
	})
}
// AddGroupBy appends a base-table column to the GROUP BY list. The name is
// stored raw; ToQuery quotes it and qualifies it with "t.".
func (q *selectQuery) AddGroupBy(f string) {
	q.groupBy = append(q.groupBy, f)
}
// AddOrderBy appends a base-table column and its sort direction to the ORDER
// BY list. The name is stored raw; ToQuery quotes it and qualifies it with
// "t.". Appending to both parallel slices together keeps them the same length.
func (q *selectQuery) AddOrderBy(field string, direction Direction) {
	q.direction = append(q.direction, direction)
	q.orderBy = append(q.orderBy, field)
}
func (q *selectQuery) ToQuery() (string, []any) {
args := []any{}
sb := strings.Builder{}
sb.WriteString("SELECT ")
sb.WriteString(strings.Join(q.fields, ","))
sb.WriteString(" FROM ")
sb.WriteString(q.from)
sb.WriteString(" AS t")
for _, j := range q.joins {
sb.WriteString(" " + j.query)
args = append(args, j.args...)
}
// Templated where string
if len(q.where) > 0 {
sb.WriteString(" WHERE ")
for i, w := range q.where {
if i > 0 {
sb.WriteString(" AND ")
}
sb.WriteString(w.query)
args = append(args, w.args...)
}
}
if len(q.groupBy) > 0 {
sb.WriteString(" GROUP BY ")
for i, f := range q.groupBy {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString("t." + q.dialect.Quote(f))
}
}
if len(q.orderBy) > 0 && len(q.direction) == len(q.orderBy) {
sb.WriteString(" ORDER BY ")
for i, f := range q.orderBy {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString("t." + q.dialect.Quote(f))
sb.WriteString(" ")
sb.WriteString(q.direction[i].String())
}
}
limit := q.limit
if limit > 0 {
if q.oneExtra {
limit = limit + 1
}
sb.WriteString(q.dialect.LimitOffset(limit, q.offset))
}
return sb.String(), args
}