diff --git a/pkg/infra/usagestats/service.go b/pkg/infra/usagestats/service.go index 6c82e1e714f..d4b5a574c29 100644 --- a/pkg/infra/usagestats/service.go +++ b/pkg/infra/usagestats/service.go @@ -2,6 +2,7 @@ package usagestats import ( "context" + "fmt" "time" "github.com/grafana/grafana/pkg/bus" @@ -22,7 +23,7 @@ func init() { } type UsageStats interface { - GetUsageReport() (UsageReport, error) + GetUsageReport(ctx context.Context) (UsageReport, error) RegisterMetric(name string, fn MetricFunc) } @@ -38,8 +39,9 @@ type UsageStatsService struct { log log.Logger - oauthProviders map[string]bool - externalMetrics map[string]MetricFunc + oauthProviders map[string]bool + externalMetrics map[string]MetricFunc + concurrentUserStatsCache memoConcurrentUserStats } func (uss *UsageStatsService) Init() error { @@ -60,7 +62,7 @@ func (uss *UsageStatsService) Run(ctx context.Context) error { for { select { case <-onceEveryDayTick.C: - if err := uss.sendUsageStats(); err != nil { + if err := uss.sendUsageStats(ctx); err != nil { metricsLogger.Warn("Failed to send usage stats", "err", err) } case <-everyMinuteTicker.C: @@ -70,3 +72,43 @@ func (uss *UsageStatsService) Run(ctx context.Context) error { } } } + +type memoConcurrentUserStats struct { + stats *concurrentUsersStats + + memoized time.Time +} + +const concurrentUserStatsCacheLifetime = time.Hour + +func (uss *UsageStatsService) GetConcurrentUsersStats(ctx context.Context) (*concurrentUsersStats, error) { + memoizationPeriod := time.Now().Add(-concurrentUserStatsCacheLifetime) + if !uss.concurrentUserStatsCache.memoized.Before(memoizationPeriod) { + return uss.concurrentUserStatsCache.stats, nil + } + + uss.concurrentUserStatsCache.stats = &concurrentUsersStats{} + err := uss.SQLStore.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + // Retrieves concurrent users stats as a histogram. Buckets are accumulative and upper bound is inclusive. + rawSQL := ` +SELECT + COUNT(CASE WHEN tokens <= 3 THEN 1 END) AS bucket_le_3, + COUNT(CASE WHEN tokens <= 6 THEN 1 END) AS bucket_le_6, + COUNT(CASE WHEN tokens <= 9 THEN 1 END) AS bucket_le_9, + COUNT(CASE WHEN tokens <= 12 THEN 1 END) AS bucket_le_12, + COUNT(CASE WHEN tokens <= 15 THEN 1 END) AS bucket_le_15, + COUNT(1) AS bucket_le_inf +FROM (select count(1) as tokens from user_auth_token group by user_id) uat;` + _, err := sess.SQL(rawSQL).Get(uss.concurrentUserStatsCache.stats) + if err != nil { + return err + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to get concurrent users stats from database: %w", err) + } + + uss.concurrentUserStatsCache.memoized = time.Now() + return uss.concurrentUserStatsCache.stats, nil +} diff --git a/pkg/infra/usagestats/types.go b/pkg/infra/usagestats/types.go new file mode 100644 index 00000000000..df1a7bf3bb8 --- /dev/null +++ b/pkg/infra/usagestats/types.go @@ -0,0 +1,10 @@ +package usagestats + +type concurrentUsersStats struct { + BucketLE3 int32 `xorm:"bucket_le_3"` + BucketLE6 int32 `xorm:"bucket_le_6"` + BucketLE9 int32 `xorm:"bucket_le_9"` + BucketLE12 int32 `xorm:"bucket_le_12"` + BucketLE15 int32 `xorm:"bucket_le_15"` + BucketLEInf int32 `xorm:"bucket_le_inf"` +} diff --git a/pkg/infra/usagestats/usage_stats.go b/pkg/infra/usagestats/usage_stats.go index 7ca719f376e..2fcf58f9b44 100644 --- a/pkg/infra/usagestats/usage_stats.go +++ b/pkg/infra/usagestats/usage_stats.go @@ -2,6 +2,7 @@ package usagestats import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -27,7 +28,7 @@ type UsageReport struct { Packaging string `json:"packaging"` } -func (uss *UsageStatsService) GetUsageReport() (UsageReport, error) { +func (uss *UsageStatsService) GetUsageReport(ctx context.Context) (UsageReport, error) { version := strings.ReplaceAll(setting.BuildVersion, ".", "_") metrics := map[string]interface{}{} @@ -185,6 +186,21 @@ func (uss *UsageStatsService) GetUsageReport() (UsageReport, error) { metrics["stats.auth_enabled."+authType+".count"] = enabledValue } + // Get concurrent users stats as histogram + concurrentUsersStats, err := uss.GetConcurrentUsersStats(ctx) + if err != nil { + metricsLogger.Error("Failed to get concurrent users stats", "error", err) + return report, err + } + + // Histogram is cumulative and metric name has a postfix of le_"" + metrics["stats.auth_token_per_user_le_3"] = concurrentUsersStats.BucketLE3 + metrics["stats.auth_token_per_user_le_6"] = concurrentUsersStats.BucketLE6 + metrics["stats.auth_token_per_user_le_9"] = concurrentUsersStats.BucketLE9 + metrics["stats.auth_token_per_user_le_12"] = concurrentUsersStats.BucketLE12 + metrics["stats.auth_token_per_user_le_15"] = concurrentUsersStats.BucketLE15 + metrics["stats.auth_token_per_user_le_inf"] = concurrentUsersStats.BucketLEInf + return report, nil } @@ -203,14 +219,14 @@ func (uss *UsageStatsService) RegisterMetric(name string, fn MetricFunc) { uss.externalMetrics[name] = fn } -func (uss *UsageStatsService) sendUsageStats() error { +func (uss *UsageStatsService) sendUsageStats(ctx context.Context) error { if !setting.ReportingEnabled { return nil } metricsLogger.Debug(fmt.Sprintf("Sending anonymous usage stats to %s", usageStatsURL)) - report, err := uss.GetUsageReport() + report, err := uss.GetUsageReport(ctx) if err != nil { return err } diff --git a/pkg/infra/usagestats/usage_stats_service_test.go b/pkg/infra/usagestats/usage_stats_service_test.go new file mode 100644 index 00000000000..94149cf4f75 --- /dev/null +++ b/pkg/infra/usagestats/usage_stats_service_test.go @@ -0,0 +1,106 @@ +package usagestats + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "testing" + "time" + + "github.com/grafana/grafana/pkg/bus" + "github.com/grafana/grafana/pkg/services/licensing" + "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUsageStatsService_GetConcurrentUsersStats(t *testing.T) { + sqlStore := sqlstore.InitTestDB(t) + uss := &UsageStatsService{ + Bus: bus.New(), + SQLStore: sqlStore, + License: &licensing.OSSLicensingService{}, + } + + createConcurrentTokens(t, sqlStore) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(func() { + cancel() + }) + + actualResult, err := uss.GetConcurrentUsersStats(ctx) + require.NoError(t, err) + + expectedCachedResult := &concurrentUsersStats{ + BucketLE3: 1, + BucketLE6: 2, + BucketLE9: 3, + BucketLE12: 4, + BucketLE15: 5, + BucketLEInf: 6, + } + assert.Equal(t, expectedCachedResult, actualResult) + + createToken(t, 8, sqlStore) + require.NoError(t, err) + + actualResult, err = uss.GetConcurrentUsersStats(ctx) + require.NoError(t, err) + assert.Equal(t, expectedCachedResult, actualResult) +} + +func createToken(t *testing.T, uID int, sqlStore *sqlstore.SQLStore) { + t.Helper() + token, err := util.RandomHex(16) + require.NoError(t, err) + + tokenWithSecret := fmt.Sprintf("%ssecret%d", token, uID) + hashBytes := sha256.Sum256([]byte(tokenWithSecret)) + hashedToken := hex.EncodeToString(hashBytes[:]) + + now := time.Now().Unix() + + userAuthToken := userAuthToken{ + UserID: int64(uID), + AuthToken: hashedToken, + PrevAuthToken: hashedToken, + ClientIP: "192.168.10.11", + UserAgent: "Mozilla", + RotatedAt: now, + CreatedAt: now, + UpdatedAt: now, + SeenAt: 0, + AuthTokenSeen: false, + } + + err = sqlStore.WithDbSession(context.Background(), func(dbSession *sqlstore.DBSession) error { + _, err = dbSession.Insert(&userAuthToken) + return err + }) + require.NoError(t, err) +} + +func createConcurrentTokens(t *testing.T, sqlStore *sqlstore.SQLStore) { + t.Helper() + for u := 1; u <= 6; u++ { + for tkn := 1; tkn <= u*3; tkn++ { + createToken(t, u, sqlStore) + } + } +} + +type userAuthToken struct { + UserID int64 `xorm:"user_id"` + AuthToken string + PrevAuthToken string + UserAgent string + ClientIP string `xorm:"client_ip"` + AuthTokenSeen bool + SeenAt int64 + RotatedAt int64 + CreatedAt int64 + UpdatedAt int64 +} diff --git a/pkg/infra/usagestats/usage_stats_test.go b/pkg/infra/usagestats/usage_stats_test.go index 356fd6b108e..8a0acefbd55 100644 --- a/pkg/infra/usagestats/usage_stats_test.go +++ b/pkg/infra/usagestats/usage_stats_test.go @@ -2,6 +2,7 @@ package usagestats import ( "bytes" + "context" "errors" "io/ioutil" "runtime" @@ -159,6 +160,7 @@ func TestMetrics(t *testing.T) { return nil }) + createConcurrentTokens(t, uss.SQLStore) uss.AlertingUsageStats = &alertingUsageMock{} var wg sync.WaitGroup @@ -186,12 +188,12 @@ func TestMetrics(t *testing.T) { "grafana_com": true, } - err := uss.sendUsageStats() + err := uss.sendUsageStats(context.Background()) require.NoError(t, err) t.Run("Given reporting not enabled and sending usage stats", func(t *testing.T) { setting.ReportingEnabled = false - err := uss.sendUsageStats() + err := uss.sendUsageStats(context.Background()) require.NoError(t, err) t.Run("Should not gather stats or call http endpoint", func(t *testing.T) { @@ -212,7 +214,7 @@ func TestMetrics(t *testing.T) { setting.Packaging = "deb" wg.Add(1) - err := uss.sendUsageStats() + err := uss.sendUsageStats(context.Background()) require.NoError(t, err) t.Run("Should gather stats and call http endpoint", func(t *testing.T) { @@ -291,6 +293,13 @@ func TestMetrics(t *testing.T) { assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_grafana_com.count").MustInt()) assert.Equal(t, 1, metrics.Get("stats.packaging.deb.count").MustInt()) + + assert.Equal(t, 1, metrics.Get("stats.auth_token_per_user_le_3").MustInt()) + assert.Equal(t, 2, metrics.Get("stats.auth_token_per_user_le_6").MustInt()) + assert.Equal(t, 3, metrics.Get("stats.auth_token_per_user_le_9").MustInt()) + assert.Equal(t, 4, metrics.Get("stats.auth_token_per_user_le_12").MustInt()) + assert.Equal(t, 5, metrics.Get("stats.auth_token_per_user_le_15").MustInt()) + assert.Equal(t, 6, metrics.Get("stats.auth_token_per_user_le_inf").MustInt()) }) }) }) @@ -419,12 +428,26 @@ func TestMetrics(t *testing.T) { return nil }) + createConcurrentTokens(t, uss.SQLStore) + + t.Run("Should include metrics for concurrent users", func(t *testing.T) { + report, err := uss.GetUsageReport(context.Background()) + require.NoError(t, err) + + assert.Equal(t, int32(1), report.Metrics["stats.auth_token_per_user_le_3"]) + assert.Equal(t, int32(2), report.Metrics["stats.auth_token_per_user_le_6"]) + assert.Equal(t, int32(3), report.Metrics["stats.auth_token_per_user_le_9"]) + assert.Equal(t, int32(4), report.Metrics["stats.auth_token_per_user_le_12"]) + assert.Equal(t, int32(5), report.Metrics["stats.auth_token_per_user_le_15"]) + assert.Equal(t, int32(6), report.Metrics["stats.auth_token_per_user_le_inf"]) + }) + t.Run("Should include external metrics", func(t *testing.T) { uss.RegisterMetric(metricName, func() (interface{}, error) { return 1, nil }) - report, err := uss.GetUsageReport() + report, err := uss.GetUsageReport(context.Background()) assert.Nil(t, err, "Expected no error") metric := report.Metrics[metricName]