152 lines
3.7 KiB
Go
152 lines
3.7 KiB
Go
package sqlstash
|
|
|
|
import (
|
|
"context"
|
|
"crypto/md5"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"text/template"
|
|
|
|
"github.com/grafana/grafana/pkg/apimachinery/identity"
|
|
"github.com/grafana/grafana/pkg/services/store/entity/db"
|
|
"github.com/grafana/grafana/pkg/storage/unified/sql/sqltemplate"
|
|
)
|
|
|
|
// createETag derives an entity tag from the given payloads. It feeds meta,
// body and status (in that order) into a single MD5 digest and returns the
// digest encoded as a hexadecimal string.
func createETag(body []byte, meta []byte, status []byte) string {
	// TODO: can we change this to something more modern like sha256?
	digest := md5.New()
	for _, part := range [][]byte{meta, body, status} {
		// hash.Hash documents that Write never returns an error.
		_, _ = digest.Write(part)
	}

	return hex.EncodeToString(digest.Sum(nil))
}
|
|
|
|
// getCurrentUser returns a string identifying the user making a request with
|
|
// the given context.
|
|
func getCurrentUser(ctx context.Context) (string, error) {
|
|
user, err := identity.GetRequester(ctx)
|
|
if err != nil || user == nil {
|
|
return "", fmt.Errorf("%w: %w", ErrUserNotFoundInContext, err)
|
|
}
|
|
|
|
return user.GetUID().String(), nil
|
|
}
|
|
|
|
// ptrOr scans its arguments in order and returns the first one that is not
// nil. When every argument is nil, or no arguments are given, it returns a
// pointer to a freshly allocated zero value of E, so the result is never nil.
func ptrOr[P ~*E, E any](ps ...P) P {
	for i := 0; i < len(ps); i++ {
		if candidate := ps[i]; candidate != nil {
			return candidate
		}
	}

	fallback := new(E)
	return P(fallback)
}
|
|
|
|
// sliceOr returns the first slice in vals that has at least one element, or a
// new empty (non-nil) slice when every argument is nil or empty, or when no
// arguments are given.
//
// Elements are never compared, so E is only constrained by `any`; this allows
// slices whose element type is not comparable (slices, maps, funcs).
func sliceOr[S ~[]E, E any](vals ...S) S {
	for _, s := range vals {
		if len(s) > 0 {
			return s
		}
	}

	return S{}
}
|
|
|
|
// mapOr walks its arguments in order and returns the first map holding at
// least one entry. If all arguments are nil or empty, or none are given, it
// returns a newly created empty map, so the result is never nil.
func mapOr[M ~map[K]V, K comparable, V any](vals ...M) M {
	for i := range vals {
		if candidate := vals[i]; len(candidate) != 0 {
			return candidate
		}
	}

	return make(M)
}
|
|
|
|
// queryRow uses `req` as input and output for a single-row returning query
|
|
// generated with `tmpl`, and executed in `x`.
|
|
func queryRow[T any](ctx context.Context, x db.ContextExecer, tmpl *template.Template, req sqltemplate.WithResults[T]) (T, error) {
|
|
var zero T
|
|
|
|
if err := req.Validate(); err != nil {
|
|
return zero, fmt.Errorf("query: invalid request for template %q: %w",
|
|
tmpl.Name(), err)
|
|
}
|
|
|
|
rawQuery, err := sqltemplate.Execute(tmpl, req)
|
|
if err != nil {
|
|
return zero, fmt.Errorf("execute template: %w", err)
|
|
}
|
|
query := sqltemplate.FormatSQL(rawQuery)
|
|
|
|
row := x.QueryRowContext(ctx, query, req.GetArgs()...)
|
|
if err := row.Err(); err != nil {
|
|
return zero, SQLError{
|
|
Err: err,
|
|
CallType: "QueryRow",
|
|
TemplateName: tmpl.Name(),
|
|
arguments: req.GetArgs(),
|
|
ScanDest: req.GetScanDest(),
|
|
Query: query,
|
|
RawQuery: rawQuery,
|
|
}
|
|
}
|
|
|
|
return scanRow(row, req)
|
|
}
|
|
|
|
// scanner is the minimal row-reading contract consumed by scanRow: one Scan
// method that receives the destinations to fill and reports any error.
type scanner interface {
	Scan(dst ...any) error
}
|
|
|
|
// scanRow is used on *sql.Row and *sql.Rows, and is factored out here not to
|
|
// improving code reuse, but rather for ease of testing.
|
|
func scanRow[T any](sc scanner, req sqltemplate.WithResults[T]) (zero T, err error) {
|
|
if err = sc.Scan(req.GetScanDest()...); err != nil {
|
|
return zero, fmt.Errorf("row scan: %w", err)
|
|
}
|
|
|
|
res, err := req.Results()
|
|
if err != nil {
|
|
return zero, fmt.Errorf("row results: %w", err)
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// exec uses `req` as input for a non-data returning query generated with
|
|
// `tmpl`, and executed in `x`.
|
|
func exec(ctx context.Context, x db.ContextExecer, tmpl *template.Template, req sqltemplate.SQLTemplateIface) (sql.Result, error) {
|
|
if err := req.Validate(); err != nil {
|
|
return nil, fmt.Errorf("exec: invalid request for template %q: %w",
|
|
tmpl.Name(), err)
|
|
}
|
|
|
|
rawQuery, err := sqltemplate.Execute(tmpl, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("execute template: %w", err)
|
|
}
|
|
query := sqltemplate.FormatSQL(rawQuery)
|
|
|
|
res, err := x.ExecContext(ctx, query, req.GetArgs()...)
|
|
if err != nil {
|
|
return nil, SQLError{
|
|
Err: err,
|
|
CallType: "Exec",
|
|
TemplateName: tmpl.Name(),
|
|
arguments: req.GetArgs(),
|
|
Query: query,
|
|
RawQuery: rawQuery,
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|