diff --git a/pkg/services/auth/idimpl/signer.go b/pkg/services/auth/idimpl/signer.go index 0a44026b7c5..ea6aaa8d920 100644 --- a/pkg/services/auth/idimpl/signer.go +++ b/pkg/services/auth/idimpl/signer.go @@ -10,14 +10,20 @@ import ( "github.com/grafana/grafana/pkg/services/signingkeys" ) +const idSignerKeyPrefix = "id" + var _ auth.IDSigner = (*LocalSigner)(nil) func ProvideLocalSigner(keyService signingkeys.Service) (*LocalSigner, error) { - key := keyService.GetServerPrivateKey() // FIXME: replace with signing specific key + id, key, err := keyService.GetOrCreatePrivateKey(context.Background(), idSignerKeyPrefix, jose.ES256) + if err != nil { + return nil, err + } + // FIXME: Handle key rotation signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, &jose.SignerOptions{ ExtraHeaders: map[jose.HeaderKey]interface{}{ - "kid": "default", // FIXME: replace with specific key id + "kid": id, }, }) if err != nil { diff --git a/pkg/services/authn/clients/ext_jwt.go b/pkg/services/authn/clients/ext_jwt.go index 1932598122f..5e2f7e35b05 100644 --- a/pkg/services/authn/clients/ext_jwt.go +++ b/pkg/services/authn/clients/ext_jwt.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" "golang.org/x/exp/slices" @@ -172,7 +173,13 @@ func (s *ExtendedJWT) verifyRFC9068Token(ctx context.Context, rawToken string) ( } var claims ExtendedJWTClaims - err = parsedToken.Claims(s.signingKeys.GetServerPublicKey(), &claims) + _, key, err := s.signingKeys.GetOrCreatePrivateKey(ctx, + signingkeys.ServerPrivateKeyID, jose.ES256) + if err != nil { + return nil, fmt.Errorf("failed to get public key: %w", err) + } + + err = parsedToken.Claims(key.Public(), &claims) if err != nil { return nil, fmt.Errorf("failed to verify the signature: %w", err) } diff --git a/pkg/services/authn/clients/ext_jwt_test.go b/pkg/services/authn/clients/ext_jwt_test.go index 33361cb58b6..7a29ad6d96a 100644 --- a/pkg/services/authn/clients/ext_jwt_test.go +++ b/pkg/services/authn/clients/ext_jwt_test.go @@ -2,6 +2,7 @@ package clients import ( "context" + "crypto" "crypto/rand" "crypto/rsa" "fmt" @@ -12,17 +13,19 @@ import ( "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/models/roletype" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/oauthserver" "github.com/grafana/grafana/pkg/services/oauthserver/oastest" + "github.com/grafana/grafana/pkg/services/signingkeys" "github.com/grafana/grafana/pkg/services/signingkeys/signingkeystest" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user/usertest" "github.com/grafana/grafana/pkg/setting" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var ( @@ -513,8 +516,9 @@ func setupTestCtx(t *testing.T, cfg *setting.Cfg) *testEnv { } } - signingKeysSvc := &signingkeystest.FakeSigningKeysService{} - signingKeysSvc.ExpectedServerPublicKey = &pk.PublicKey + signingKeysSvc := &signingkeystest.FakeSigningKeysService{ExpectedKeys: map[string]crypto.Signer{ + signingkeys.ServerPrivateKeyID: pk}, + } userSvc := &usertest.FakeUserService{} oauthSvc := &oastest.FakeService{} diff --git a/pkg/services/oauthserver/oasimpl/service.go b/pkg/services/oauthserver/oasimpl/service.go index 564c8e0f54b..4774243e33b 100644 --- a/pkg/services/oauthserver/oasimpl/service.go +++ b/pkg/services/oauthserver/oasimpl/service.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/go-jose/go-jose/v3" "github.com/ory/fosite" "github.com/ory/fosite/compose" "github.com/ory/fosite/storage" @@ -75,18 +76,6 @@ func ProvideService(router routing.RouteRegister, db db.DB, cfg *setting.Cfg, ScopeStrategy: fosite.WildcardScopeStrategy, } - privateKey := keySvc.GetServerPrivateKey() - - var publicKey any - switch k := privateKey.(type) { - case *rsa.PrivateKey: - publicKey = &k.PublicKey - case *ecdsa.PrivateKey: - publicKey = &k.PublicKey - default: - return nil, fmt.Errorf("unknown private key type %T", k) - } - s := &OAuth2ServiceImpl{ cache: localcache.New(cacheExpirationTime, cacheCleanupInterval), cfg: cfg, @@ -98,20 +87,20 @@ func ProvideService(router routing.RouteRegister, db db.DB, cfg *setting.Cfg, userService: userSvc, saService: svcAccSvc, teamService: teamSvc, - publicKey: publicKey, } api := api.NewAPI(router, s) api.RegisterAPIEndpoints() - s.oauthProvider = newProvider(config, s, privateKey) + s.oauthProvider = newProvider(config, s, keySvc) return s, nil } -func newProvider(config *fosite.Config, storage any, key any) fosite.OAuth2Provider { - keyGetter := func(context.Context) (any, error) { - return key, nil +func newProvider(config *fosite.Config, storage any, signingKeyService signingkeys.Service) fosite.OAuth2Provider { + keyGetter := func(ctx context.Context) (any, error) { + _, key, err := signingKeyService.GetOrCreatePrivateKey(ctx, signingkeys.ServerPrivateKeyID, jose.ES256) + return key, err } return compose.Compose( config, diff --git a/pkg/services/oauthserver/oasimpl/service_test.go b/pkg/services/oauthserver/oasimpl/service_test.go index 0fb441647da..325cc612002 100644 --- a/pkg/services/oauthserver/oasimpl/service_test.go +++ b/pkg/services/oauthserver/oasimpl/service_test.go @@ -2,6 +2,7 @@ package oasimpl import ( "context" + "crypto" "crypto/rand" "crypto/rsa" "encoding/base64" @@ -26,6 +27,7 @@ import ( "github.com/grafana/grafana/pkg/services/oauthserver/oastest" sa "github.com/grafana/grafana/pkg/services/serviceaccounts" satests "github.com/grafana/grafana/pkg/services/serviceaccounts/tests" + "github.com/grafana/grafana/pkg/services/signingkeys/signingkeystest" "github.com/grafana/grafana/pkg/services/team/teamtest" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user/usertest" @@ -89,7 +91,13 @@ func setupTestEnv(t *testing.T) *TestEnv { teamService: env.TeamService, publicKey: &pk.PublicKey, } - env.S.oauthProvider = newProvider(config, env.S, pk) + + env.S.oauthProvider = newProvider(config, env.S, &signingkeystest.FakeSigningKeysService{ + ExpectedKeys: map[string]crypto.Signer{ + "default": pk, + }, + ExpectedError: nil, + }) return env } diff --git a/pkg/services/oauthserver/oasimpl/token_test.go b/pkg/services/oauthserver/oasimpl/token_test.go index 7b90b89e044..40cdf53b4de 100644 --- a/pkg/services/oauthserver/oasimpl/token_test.go +++ b/pkg/services/oauthserver/oasimpl/token_test.go @@ -609,7 +609,7 @@ func TestOAuth2ServiceImpl_HandleTokenRequest(t *testing.T) { env.S.HandleTokenRequest(resp, req) - require.Equal(t, tt.wantCode, resp.Code) + require.Equal(t, tt.wantCode, resp.Code, resp.Body.String()) if tt.wantCode != http.StatusOK { return } diff --git a/pkg/services/secrets/migrator/migrator.go b/pkg/services/secrets/migrator/migrator.go index 0bbdef5fe83..75ff7de755c 100644 --- a/pkg/services/secrets/migrator/migrator.go +++ b/pkg/services/secrets/migrator/migrator.go @@ -43,6 +43,7 @@ func ProvideSecretsMigrator( b64Secret{simpleSecret: simpleSecret{tableName: "secrets", columnName: "value"}, hasUpdatedColumn: true, encoding: base64.RawStdEncoding}, jsonSecret{tableName: "data_source"}, jsonSecret{tableName: "plugin_setting"}, + b64Secret{simpleSecret: simpleSecret{tableName: "signing_key", columnName: "private_key"}, encoding: base64.StdEncoding}, alertingSecret{}, } diff --git a/pkg/services/signingkeys/signingkeys.go b/pkg/services/signingkeys/signingkeys.go index 35ef7bd3b3a..bbb7ceba529 100644 --- a/pkg/services/signingkeys/signingkeys.go +++ b/pkg/services/signingkeys/signingkeys.go @@ -8,27 +8,21 @@ package signingkeys import ( + "context" "crypto" "github.com/go-jose/go-jose/v3" ) +const ( + ServerPrivateKeyID = "default" +) + // Service provides functionality for managing signing keys used to sign and verify JWT tokens. // // The service is under active development and is not yet ready for production use. type Service interface { // GetJWKS returns the JSON Web Key Set (JWKS) with all the keys that can be used to verify tokens (public keys) - GetJWKS() jose.JSONWebKeySet - // GetJWK returns the JSON Web Key (JWK) with the specified key ID which can be used to verify tokens (public key) - GetJWK(keyID string) (jose.JSONWebKey, error) - // GetPublicKey returns the public key with the specified key ID - GetPublicKey(keyID string) (crypto.PublicKey, error) - // GetPrivateKey returns the private key with the specified key ID - GetPrivateKey(keyID string) (crypto.PrivateKey, error) - // GetServerPrivateKey returns the private key used to sign tokens - GetServerPrivateKey() crypto.PrivateKey - // GetServerPublicKey returns the public key used to verify tokens - GetServerPublicKey() crypto.PublicKey - // AddPrivateKey adds a private key to the service - AddPrivateKey(keyID string, privateKey crypto.PrivateKey) error + GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) + GetOrCreatePrivateKey(ctx context.Context, keyPrefix string, alg jose.SignatureAlgorithm) (string, crypto.Signer, error) } diff --git a/pkg/services/signingkeys/signingkeysimpl/service.go b/pkg/services/signingkeys/signingkeysimpl/service.go index 8dd51795c0d..eeeb026c8e6 100644 --- a/pkg/services/signingkeys/signingkeysimpl/service.go +++ b/pkg/services/signingkeys/signingkeysimpl/service.go @@ -1,37 +1,34 @@ package signingkeysimpl import ( + "context" "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "errors" + "strings" + "time" "github.com/go-jose/go-jose/v3" + "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/infra/remotecache" + "github.com/grafana/grafana/pkg/services/secrets" "github.com/grafana/grafana/pkg/services/signingkeys" -) - -const ( - serverPrivateKeyID = "default" + "github.com/grafana/grafana/pkg/services/signingkeys/signingkeystore" ) var _ signingkeys.Service = new(Service) -func ProvideEmbeddedSigningKeysService() (*Service, error) { +func ProvideEmbeddedSigningKeysService(dbStore db.DB, secretsService secrets.Service, + remoteCache remotecache.CacheStorage, +) (*Service, error) { s := &Service{ - log: log.New("auth.key_service"), - keys: map[string]crypto.Signer{}, - } - - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - s.log.Error("Error generating private key", "err", err) - return nil, signingkeys.ErrKeyGenerationFailed.Errorf("Error generating private key: %v", err) - } - - if err := s.AddPrivateKey(serverPrivateKeyID, privateKey); err != nil { - return nil, err + log: log.New("auth.key_service"), + store: signingkeystore.NewSigningKeyStore(dbStore, secretsService), + remoteCache: remoteCache, } return s, nil @@ -42,81 +39,49 @@ func ProvideEmbeddedSigningKeysService() (*Service, error) { // // The service is under active development and is not yet ready for production use. type Service struct { - log log.Logger - keys map[string]crypto.Signer + log log.Logger + store signingkeystore.SigningStore + remoteCache remotecache.CacheStorage } // GetJWKS returns the JSON Web Key Set (JWKS) with all the keys that can be used to verify tokens (public keys) -func (s *Service) GetJWKS() jose.JSONWebKeySet { - result := jose.JSONWebKeySet{} +func (s *Service) GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) { + jwks, err := s.store.GetJWKS(ctx) + return jwks, err +} - for keyID := range s.keys { - // Skip error check because keyID must be a valid key ID - jwk, _ := s.GetJWK(keyID) - result.Keys = append(result.Keys, jwk) +// GetOrCreatePrivateKey returns the private key with the specified key ID. If the key does not exist, it will be +// created with the specified algorithm. +// The key will be automatically rotated at the beginning of each month. The previous key will be kept for 30 days. +func (s *Service) GetOrCreatePrivateKey(ctx context.Context, + keyPrefix string, alg jose.SignatureAlgorithm) (string, crypto.Signer, error) { + if alg != jose.ES256 { + s.log.Error("Only ES256 is supported", "alg", alg) + return "", nil, signingkeys.ErrKeyGenerationFailed.Errorf("Only ES256 is supported: %v", alg) } - return result -} + keyID := keyMonthScopedID(keyPrefix, alg) + signer, err := s.store.GetPrivateKey(ctx, keyID) + if err == nil { + return keyID, signer, nil + } + s.log.Debug("Private key not found, generating new key", "keyID", keyID, "err", err) -// GetJWK returns the JSON Web Key (JWK) with the specified key ID which can be used to verify tokens (public key) -func (s *Service) GetJWK(keyID string) (jose.JSONWebKey, error) { - privateKey, ok := s.keys[keyID] - if !ok { - s.log.Error("The specified key was not found", "keyID", keyID) - return jose.JSONWebKey{}, signingkeys.ErrSigningKeyNotFound.Errorf("The specified key was not found: %s", keyID) + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + s.log.Error("Error generating private key", "err", err) + return "", nil, signingkeys.ErrKeyGenerationFailed.Errorf("Error generating private key: %v", err) } - result := jose.JSONWebKey{ - Key: privateKey.Public(), - Use: "sig", + expiry := time.Now().Add(30 * 24 * time.Hour) + if signer, err = s.store.AddPrivateKey(ctx, keyID, alg, privateKey, &expiry, false); err != nil && !errors.Is(err, signingkeys.ErrSigningKeyAlreadyExists) { + return "", nil, err } - return result, nil + return keyID, signer, nil } -// GetPublicKey returns the public key with the specified key ID -func (s *Service) GetPublicKey(keyID string) (crypto.PublicKey, error) { - privateKey, ok := s.keys[keyID] - if !ok { - s.log.Error("The specified key was not found", "keyID", keyID) - return nil, signingkeys.ErrSigningKeyNotFound.Errorf("The specified key was not found: %s", keyID) - } - - return privateKey.Public(), nil -} - -// GetPrivateKey returns the private key with the specified key ID -func (s *Service) GetPrivateKey(keyID string) (crypto.PrivateKey, error) { - privateKey, ok := s.keys[keyID] - if !ok { - s.log.Error("The specified key was not found", "keyID", keyID) - return nil, signingkeys.ErrSigningKeyNotFound.Errorf("The specified key was not found: %s", keyID) - } - - return privateKey, nil -} - -// AddPrivateKey adds a private key to the service -func (s *Service) AddPrivateKey(keyID string, privateKey crypto.PrivateKey) error { - if _, ok := s.keys[keyID]; ok { - s.log.Error("The specified key ID is already in use", "keyID", keyID) - return signingkeys.ErrSigningKeyAlreadyExists.Errorf("The specified key ID is already in use: %s", keyID) - } - s.keys[keyID] = privateKey.(crypto.Signer) - return nil -} - -// GetServerPrivateKey returns the private key used to sign tokens -func (s *Service) GetServerPrivateKey() crypto.PrivateKey { - // The server private key is always available - pk, _ := s.GetPrivateKey(serverPrivateKeyID) - return pk -} - -// GetServerPrivateKey returns the private key used to sign tokens -func (s *Service) GetServerPublicKey() crypto.PublicKey { - // The server public key is always available - publicKey, _ := s.GetPublicKey(serverPrivateKeyID) - return publicKey +func keyMonthScopedID(keyPrefix string, alg jose.SignatureAlgorithm) string { + keyID := keyPrefix + "-" + time.Now().UTC().Format("2006-01") + "-" + strings.ToLower(string(alg)) + return keyID } diff --git a/pkg/services/signingkeys/signingkeysimpl/service_test.go b/pkg/services/signingkeys/signingkeysimpl/service_test.go index 1cae303b2e6..6ff776053f8 100644 --- a/pkg/services/signingkeys/signingkeysimpl/service_test.go +++ b/pkg/services/signingkeys/signingkeysimpl/service_test.go @@ -1,17 +1,22 @@ package signingkeysimpl import ( - "crypto" + "context" "crypto/ecdsa" "crypto/x509" "encoding/json" "encoding/pem" - "io" + "fmt" "testing" + "time" "github.com/go-jose/go-jose/v3" - "github.com/grafana/grafana/pkg/infra/log" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/signingkeys" + "github.com/grafana/grafana/pkg/services/signingkeys/signingkeystore" ) const ( @@ -28,84 +33,23 @@ func getPrivateKey(t *testing.T) *ecdsa.PrivateKey { return privateKey.(*ecdsa.PrivateKey) } -func setupTestService(t *testing.T) *Service { - svc := &Service{ - log: log.NewNopLogger(), - keys: map[string]crypto.Signer{serverPrivateKeyID: getPrivateKey(t)}, - } - return svc -} - -func TestEmbeddedKeyService_GetJWK(t *testing.T) { - tests := []struct { - name string - keyID string - want jose.JSONWebKey - wantErr bool - }{ - {name: "creates a JSON Web Key successfully", - keyID: "default", - want: jose.JSONWebKey{ - Key: getPrivateKey(t).Public(), - Use: "sig", - }, - wantErr: false, - }, - {name: "returns error when the specified key was not found", - keyID: "not-existing-key-id", - want: jose.JSONWebKey{}, - wantErr: true, - }, - } - svc := setupTestService(t) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := svc.GetJWK(tt.keyID) - if tt.wantErr { - require.Error(t, err) - return - } - require.NoError(t, err) - require.Equal(t, got, tt.want) - }) - } -} - -func TestEmbeddedKeyService_GetJWK_OnlyPublicKeyShared(t *testing.T) { - svc := setupTestService(t) - jwk, err := svc.GetJWK("default") - - require.NoError(t, err) - - jwkJson, err := jwk.MarshalJSON() - require.NoError(t, err) - - kvs := make(map[string]any) - err = json.Unmarshal(jwkJson, &kvs) - require.NoError(t, err) - - // check that the private key is not shared - require.NotContains(t, kvs, "d") - require.NotContains(t, kvs, "p") - require.NotContains(t, kvs, "q") -} - -func TestEmbeddedKeyService_GetJWKS(t *testing.T) { - svc := &Service{ - log: log.NewNopLogger(), - keys: map[string]crypto.Signer{ - serverPrivateKeyID: getPrivateKey(t), - "other": getPrivateKey(t), - }, - } - jwk := svc.GetJWKS() - - require.Equal(t, 2, len(jwk.Keys)) -} - func TestEmbeddedKeyService_GetJWKS_OnlyPublicKeyShared(t *testing.T) { - svc := setupTestService(t) - jwks := svc.GetJWKS() + mockStore := signingkeystore.NewFakeStore() + + _, err := mockStore.AddPrivateKey(context.Background(), signingkeys.ServerPrivateKeyID, jose.ES256, getPrivateKey(t), nil, false) + require.NoError(t, err) + + _, err = mockStore.AddPrivateKey(context.Background(), "other", jose.ES256, getPrivateKey(t), nil, false) + require.NoError(t, err) + + svc := &Service{ + log: log.NewNopLogger(), + store: mockStore, + } + jwks, err := svc.GetJWKS(context.Background()) + require.NoError(t, err) + + require.Equal(t, 2, len(jwks.Keys)) jwksJson, err := json.Marshal(jwks) require.NoError(t, err) @@ -115,6 +59,7 @@ func TestEmbeddedKeyService_GetJWKS_OnlyPublicKeyShared(t *testing.T) { } var kvs keys + err = json.Unmarshal(jwksJson, &kvs) require.NoError(t, err) @@ -126,120 +71,34 @@ func TestEmbeddedKeyService_GetJWKS_OnlyPublicKeyShared(t *testing.T) { } } -func TestEmbeddedKeyService_GetPublicKey(t *testing.T) { - tests := []struct { - name string - keyID string - want crypto.PublicKey - wantErr bool - }{ - { - name: "returns the public key successfully", - keyID: "default", - want: getPrivateKey(t).Public(), - wantErr: false, - }, - { - name: "returns error when the specified key was not found", - keyID: "not-existent-key-id", - want: nil, - wantErr: true, - }, - } - svc := setupTestService(t) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := svc.GetPublicKey(tt.keyID) - if tt.wantErr { - require.Error(t, err) - return - } - require.NoError(t, err) - require.Equal(t, got, tt.want) - }) - } -} +func TestEmbeddedKeyService_GetOrCreatePrivateKey(t *testing.T) { + mockStore := signingkeystore.NewFakeStore() -func TestEmbeddedKeyService_GetPrivateKey(t *testing.T) { - tests := []struct { - name string - keyID string - want crypto.PrivateKey - wantErr bool - }{ - { - name: "returns the private key successfully", - keyID: "default", - want: getPrivateKey(t), - wantErr: false, - }, - { - name: "returns error when the specified key was not found", - keyID: "not-existent-key-id", - want: nil, - wantErr: true, - }, + svc := &Service{ + log: log.NewNopLogger(), + store: mockStore, } - svc := setupTestService(t) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := svc.GetPrivateKey(tt.keyID) - if tt.wantErr { - require.Error(t, err) - return - } - require.NoError(t, err) - require.Equal(t, got, tt.want) - }) - } -} -func TestEmbeddedKeyService_AddPrivateKey(t *testing.T) { - tests := []struct { - name string - keyID string - wantErr bool - }{ - { - name: "adds the private key successfully", - keyID: "new-key-id", - wantErr: false, - }, - { - name: "returns error when the specified key is already in the store", - keyID: serverPrivateKeyID, - wantErr: true, - }, - } - svc := setupTestService(t) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := svc.AddPrivateKey(tt.keyID, &dummyPrivateKey{}) - if tt.wantErr { - require.Error(t, err) - return - } - require.NoError(t, err) - }) - } -} + wantedKeyID := keyMonthScopedID("test", jose.ES256) + assert.Equal(t, wantedKeyID, fmt.Sprintf("test-%s-es256", time.Now().UTC().Format("2006-01"))) -func TestProvideEmbeddedSigningKeysService(t *testing.T) { - s, err := ProvideEmbeddedSigningKeysService() + // only ES256 is supported + _, _, err := svc.GetOrCreatePrivateKey(context.Background(), "test", jose.RS256) + require.Error(t, err) + + // first call should generate a key + _, key, err := svc.GetOrCreatePrivateKey(context.Background(), "test", jose.ES256) require.NoError(t, err) - require.NotNil(t, s) + require.NotNil(t, key) - // Verify that ProvideEmbeddedSigningKeysService generates an ECDSA private key by default - require.IsType(t, &ecdsa.PrivateKey{}, s.GetServerPrivateKey()) -} + assert.Contains(t, mockStore.PrivateKeys, wantedKeyID) -type dummyPrivateKey struct { -} + // second call should return the same key + id, key2, err := svc.GetOrCreatePrivateKey(context.Background(), "test", jose.ES256) + require.NoError(t, err) + require.NotNil(t, key2) + require.Equal(t, key, key2) + require.Equal(t, wantedKeyID, id) -func (d dummyPrivateKey) Public() crypto.PublicKey { - return "" -} - -func (d dummyPrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - return nil, nil + assert.Len(t, mockStore.PrivateKeys, 1) } diff --git a/pkg/services/signingkeys/signingkeystest/fake.go b/pkg/services/signingkeys/signingkeystest/fake.go index f3f1c4648e3..48620bfc0b1 100644 --- a/pkg/services/signingkeys/signingkeystest/fake.go +++ b/pkg/services/signingkeys/signingkeystest/fake.go @@ -1,54 +1,48 @@ package signingkeystest import ( + "context" "crypto" + "time" "github.com/go-jose/go-jose/v3" ) type FakeSigningKeysService struct { - ExpectedJSONWebKeySet jose.JSONWebKeySet - ExpectedJSONWebKey jose.JSONWebKey - ExpectedKeys map[string]crypto.Signer - ExpectedServerPrivateKey crypto.PrivateKey - ExpectedServerPublicKey crypto.PublicKey - ExpectedError error + ExpectedJSONWebKeySet jose.JSONWebKeySet + ExpectedJSONWebKey jose.JSONWebKey + ExpectedKeys map[string]crypto.Signer + ExpectedError error } -func (s *FakeSigningKeysService) GetJWKS() jose.JSONWebKeySet { - return s.ExpectedJSONWebKeySet -} - -// GetJWK returns the JSON Web Key (JWK) with the specified key ID which can be used to verify tokens (public key) -func (s *FakeSigningKeysService) GetJWK(keyID string) (jose.JSONWebKey, error) { - return s.ExpectedJSONWebKey, s.ExpectedError +func (s *FakeSigningKeysService) GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) { + return s.ExpectedJSONWebKeySet, nil } // GetPublicKey returns the public key with the specified key ID -func (s *FakeSigningKeysService) GetPublicKey(keyID string) (crypto.PublicKey, error) { +func (s *FakeSigningKeysService) GetPublicKey(ctx context.Context, keyID string) (crypto.PublicKey, error) { return s.ExpectedKeys[keyID].Public(), s.ExpectedError } // GetPrivateKey returns the private key with the specified key ID -func (s *FakeSigningKeysService) GetPrivateKey(keyID string) (crypto.PrivateKey, error) { +func (s *FakeSigningKeysService) GetPrivateKey(ctx context.Context, keyID string) (crypto.PrivateKey, error) { return s.ExpectedKeys[keyID], s.ExpectedError } -// GetServerPrivateKey returns the private key used to sign tokens -func (s *FakeSigningKeysService) GetServerPrivateKey() crypto.PrivateKey { - return s.ExpectedServerPrivateKey -} - -// GetServerPublicKey returns the public key used to verify tokens -func (s *FakeSigningKeysService) GetServerPublicKey() crypto.PublicKey { - return s.ExpectedServerPublicKey -} - // AddPrivateKey adds a private key to the service -func (s *FakeSigningKeysService) AddPrivateKey(keyID string, privateKey crypto.PrivateKey) error { +func (s *FakeSigningKeysService) AddPrivateKey(ctx context.Context, keyID string, + privateKey crypto.Signer, alg jose.SignatureAlgorithm, expiresAt *time.Time, force bool) error { if s.ExpectedError != nil { return s.ExpectedError } - s.ExpectedKeys[keyID] = privateKey.(crypto.Signer) + s.ExpectedKeys[keyID] = privateKey return nil } + +func (s *FakeSigningKeysService) GetOrCreatePrivateKey(ctx context.Context, + keyPrefix string, alg jose.SignatureAlgorithm) (string, crypto.Signer, error) { + if s.ExpectedError != nil { + return "", nil, s.ExpectedError + } + return keyPrefix, s.ExpectedKeys[keyPrefix], nil +} diff --git a/pkg/services/signingkeys/signingkeystore/fake.go b/pkg/services/signingkeys/signingkeystore/fake.go new file mode 100644 index 00000000000..63fed1841df --- /dev/null +++ b/pkg/services/signingkeys/signingkeystore/fake.go @@ -0,0 +1,62 @@ +package signingkeystore + +import ( + "context" + "crypto" + "fmt" + "time" + + "github.com/go-jose/go-jose/v3" +) + +type FakeStore struct { + PrivateKeys map[string]crypto.Signer + jwks jose.JSONWebKeySet +} + +func NewFakeStore() *FakeStore { + return &FakeStore{ + PrivateKeys: make(map[string]crypto.Signer), + jwks: jose.JSONWebKeySet{}, + } +} + +func (s *FakeStore) GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) { + return s.jwks, nil +} + +func (s *FakeStore) AddPrivateKey(ctx context.Context, keyID string, alg jose.SignatureAlgorithm, + privateKey crypto.Signer, expiresAt *time.Time, force bool) (crypto.Signer, error) { + if !force { + if key, ok := s.PrivateKeys[keyID]; ok { + if !hasExpired(key) { + return nil, fmt.Errorf("key already exists and has not expired") + } + } + } + + s.PrivateKeys[keyID] = privateKey + + jwk := jose.JSONWebKey{ + Key: privateKey.Public(), + Algorithm: string(alg), + KeyID: keyID, + Use: "sig", + } + + s.jwks.Keys = append(s.jwks.Keys, jwk) + + return privateKey, nil +} + +func (s *FakeStore) GetPrivateKey(ctx context.Context, keyID string) (crypto.Signer, error) { + if key, ok := s.PrivateKeys[keyID]; ok { + return key, nil + } + + return nil, fmt.Errorf("key not found") +} + +func hasExpired(key crypto.Signer) bool { + return false +} diff --git a/pkg/services/signingkeys/signingkeystore/store.go b/pkg/services/signingkeys/signingkeystore/store.go new file mode 100644 index 00000000000..d89489be977 --- /dev/null +++ b/pkg/services/signingkeys/signingkeystore/store.go @@ -0,0 +1,215 @@ +package signingkeystore + +import ( + "context" + "crypto" + "crypto/x509" + "database/sql" + "encoding/base64" + "encoding/pem" + "errors" + "time" + + "github.com/go-jose/go-jose/v3" + + "github.com/grafana/grafana/pkg/infra/db" + "github.com/grafana/grafana/pkg/services/secrets" + "github.com/grafana/grafana/pkg/services/signingkeys" + "github.com/grafana/grafana/pkg/services/sqlstore/session" +) + +type SigningStore interface { + // GetJWKS returns the JSON Web Key Set for the service + GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) + // AddPrivateKey adds a private key to the service. If the key already exists, it will be updated if force is true. + // If force is false, the key will only be updated if it has expired. If the key does not exist, it will be added. + // If expiresAt is nil, the key will not expire. Retrieve the result key with GetPrivateKey. + AddPrivateKey(ctx context.Context, keyID string, alg jose.SignatureAlgorithm, + privateKey crypto.Signer, expiresAt *time.Time, force bool) (crypto.Signer, error) + // GetPrivateKey returns the private key with the specified key ID + GetPrivateKey(ctx context.Context, keyID string) (crypto.Signer, error) +} + +var _ SigningStore = (*Store)(nil) + +type Store struct { + dbStore db.DB + secretsService secrets.Service +} + +type SigningKey struct { + ID int64 `json:"-" db:"id"` + KeyID string `json:"key_id" db:"key_id"` + PrivateKey []byte `json:"private_key" db:"private_key"` + AddedAt time.Time `json:"added_at" db:"added_at"` + ExpiresAt *time.Time `json:"expires_at" db:"expires_at"` + Alg jose.SignatureAlgorithm `json:"alg" db:"alg"` +} + +func NewSigningKeyStore(dbStore db.DB, secretsService secrets.Service) *Store { + return &Store{ + dbStore: dbStore, + secretsService: secretsService, + } +} + +// GetJWKS returns the JSON Web Key Set (JWKS) for the service. Expired keys will not be returned. +func (s *Store) GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) { + keySet := jose.JSONWebKeySet{} + + keys := []*SigningKey{} + if err := s.dbStore.GetSqlxSession().Select(ctx, + &keys, "SELECT * FROM signing_key WHERE expires_at IS NULL OR expires_at > ?", time.Now()); err != nil { + return keySet, err + } + + for _, key := range keys { + assertedKey, err := s.decodePrivateKey(ctx, key) + if err != nil { + return keySet, err + } + + keySet.Keys = append(keySet.Keys, jose.JSONWebKey{ + Key: assertedKey.Public(), + Algorithm: string(key.Alg), + KeyID: key.KeyID, + Use: "sig", + }) + } + + return keySet, nil +} + +// AddPrivateKey adds a private key to the service. +func (s *Store) AddPrivateKey(ctx context.Context, + keyID string, alg jose.SignatureAlgorithm, privateKey crypto.Signer, expiresAt *time.Time, force bool) (crypto.Signer, error) { + privateKeyPEM, err := s.encodePrivateKey(ctx, privateKey) + if err != nil { + return nil, err + } + + key := &SigningKey{ + KeyID: keyID, + PrivateKey: privateKeyPEM, + AddedAt: time.Now(), + Alg: alg, + ExpiresAt: expiresAt, + } + + dbSession := s.dbStore.GetSqlxSession() + var signer crypto.Signer + err = dbSession.WithTransaction(ctx, func(tx *session.SessionTx) error { + existingKey := SigningKey{} + err := tx.Get(ctx, &existingKey, "SELECT * FROM signing_key WHERE key_id = ?", keyID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + + if len(existingKey.PrivateKey) == 0 { + _, err = tx.Exec(ctx, + "INSERT INTO signing_key (key_id, private_key, added_at, alg, expires_at) VALUES (?, ?, ?, ?, ?)", + key.KeyID, key.PrivateKey, key.AddedAt, key.Alg, key.ExpiresAt) + signer = privateKey + return err + } + + if force || (existingKey.ExpiresAt != nil && existingKey.ExpiresAt.Before(time.Now())) { + _, err = tx.Exec(ctx, + "UPDATE signing_key SET private_key = ?, added_at = ?, alg = ?, expires_at = ? WHERE key_id = ?", + key.PrivateKey, key.AddedAt, key.Alg, key.ExpiresAt, key.KeyID) + signer = privateKey + return err + } + + signer, err = s.decodePrivateKey(ctx, &existingKey) + if err != nil { + return err + } + + return signingkeys.ErrSigningKeyAlreadyExists.Errorf("The specified key already exists: %s", keyID) + }) + return signer, err +} + +// GetPrivateKey returns the private key with the specified key ID. Expired keys will not be returned. +func (s *Store) GetPrivateKey(ctx context.Context, keyID string) (crypto.Signer, error) { + key := &SigningKey{} + err := s.dbStore.GetSqlxSession().Get(ctx, key, + "SELECT * FROM signing_key WHERE key_id = ?", keyID) + if err != nil { + return nil, err + } + + // Bail out if key has expired + if key.ExpiresAt != nil && key.ExpiresAt.Before(time.Now()) { + return nil, signingkeys.ErrSigningKeyNotFound.Errorf("The specified key was not found: %s", keyID) + } + + signKey, err := s.decodePrivateKey(ctx, key) + if err != nil { + return nil, err + } + + return signKey, nil +} + +func (s *Store) encodePrivateKey(ctx context.Context, privateKey crypto.Signer) ([]byte, error) { + // Encode private key to binary format + pKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return nil, err + } + + // Encode private key to PEM format + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: pKeyBytes, + }) + + encrypted, err := s.secretsService.Encrypt(ctx, privateKeyPEM, secrets.WithoutScope()) + if err != nil { + return nil, err + } + + encoded := make([]byte, base64.StdEncoding.EncodedLen(len(encrypted))) + base64.StdEncoding.Encode(encoded, encrypted) + return encoded, nil +} + +func (s *Store) decodePrivateKey(ctx context.Context, signingKey *SigningKey) (crypto.Signer, error) { + // Bail out if empty string since it'll cause a segfault in Decrypt + if len(signingKey.PrivateKey) == 0 { + return nil, errors.New("private key is empty") + } + + payload := make([]byte, base64.StdEncoding.DecodedLen(len(signingKey.PrivateKey))) + _, err := base64.StdEncoding.Decode(payload, signingKey.PrivateKey) + if err != nil { + return nil, err + } + + decrypted, err := s.secretsService.Decrypt(ctx, payload) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(decrypted) + if block == nil { + return nil, errors.New("failed to decode private key PEM") + } + + if block.Type != "PRIVATE KEY" { + return nil, errors.New("invalid block type") + } + + parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + + assertedKey, ok := parsedKey.(crypto.Signer) + if !ok { + return nil, errors.New("failed to assert private key as crypto.Signer") + } + return assertedKey, nil +} diff --git a/pkg/services/signingkeys/signingkeystore/store_test.go b/pkg/services/signingkeys/signingkeystore/store_test.go new file mode 100644 index 00000000000..dd1088486f4 --- /dev/null +++ b/pkg/services/signingkeys/signingkeystore/store_test.go @@ -0,0 +1,199 @@ +package signingkeystore + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "testing" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/grafana/grafana/pkg/infra/db" + "github.com/grafana/grafana/pkg/services/secrets/fakes" + "github.com/grafana/grafana/pkg/services/signingkeys" +) + +func TestIntegrationSigningKeyStore(t *testing.T) { + ctx := context.Background() + + testCases := []struct { + name string + keyFunc func() (crypto.Signer, error) + keyID string + alg jose.SignatureAlgorithm + expected jose.JSONWebKey + }{ + { + name: "RSA key", + keyFunc: func() (crypto.Signer, error) { + return rsa.GenerateKey(rand.Reader, 2048) + }, + keyID: "test-rsa-key", + alg: jose.RS256, + expected: jose.JSONWebKey{ + Key: &rsa.PublicKey{}, + Algorithm: "RS256", + KeyID: "test-rsa-key", + Use: "sig", + }, + }, + { + name: "Elliptic Curve key", + keyFunc: func() (crypto.Signer, error) { + return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + }, + keyID: "test-ec-key", + alg: jose.ES256, + expected: jose.JSONWebKey{ + Key: &ecdsa.PublicKey{}, + Algorithm: "ES256", + KeyID: "test-ec-key", + Use: "sig", + }, + }, + } + + for _, tc := range testCases { + dbStore := db.InitTestDB(t) + secretSvc := fakes.NewFakeSecretsService() + store := NewSigningKeyStore(dbStore, secretSvc) + + t.Run(tc.name, func(t *testing.T) { + key, err := tc.keyFunc() + assert.NoError(t, err) + + _, err = store.AddPrivateKey(ctx, tc.keyID, tc.alg, key, nil, true) + assert.NoError(t, err) + + retrievedKey, err := store.GetPrivateKey(ctx, tc.keyID) + assert.NoError(t, err) + + assert.Equal(t, key.Public(), retrievedKey.Public()) + + jwks, err := store.GetJWKS(ctx) + assert.NoError(t, err) + + require.Len(t, jwks.Keys, 1) + assert.Equal(t, key.Public(), jwks.Keys[0].Key) + assert.Equal(t, tc.expected.Algorithm, jwks.Keys[0].Algorithm) + assert.Equal(t, tc.expected.KeyID, jwks.Keys[0].KeyID) + assert.Equal(t, tc.expected.Use, jwks.Keys[0].Use) + }) + } +} + +func TestIntegrationAddPrivateKey(t *testing.T) { + ctx := context.Background() + + dbStore := db.InitTestDB(t) + secretSvc := fakes.NewFakeSecretsService() + store := NewSigningKeyStore(dbStore, secretSvc) + + key1 := generateRSAKey(t) + key2 := generateECKey(t) + key3 := generateECKey(t) + + testCases := []struct { + name string + keyID string + alg jose.SignatureAlgorithm + privateKey crypto.Signer + expiresAt *time.Time + force bool + expectedErr error + expectedKey crypto.Signer + expectedGot crypto.Signer + }{ + { + name: "Add new private key", + keyID: "test-key-1", + alg: jose.RS256, + privateKey: key1, + force: false, + expectedKey: key1, + expectedGot: key1, + }, + { + name: "Add new private key with expiration", + keyID: "test-key-2", + alg: jose.ES256, + privateKey: key2, + expiresAt: &[]time.Time{time.Now().Add(24 * time.Hour)}[0], + force: false, + expectedKey: key2, + expectedGot: key2, + }, + { + name: "Fail to replace unexpired key", + keyID: "test-key-1", + alg: jose.RS256, + privateKey: key3, + expiresAt: &[]time.Time{time.Now().Add(-24 * time.Hour)}[0], + force: false, + expectedErr: signingkeys.ErrSigningKeyAlreadyExists, + expectedKey: key1, + expectedGot: key1, + }, + { + name: "Replace key1 private key with force, already expired", + keyID: "test-key-1", + alg: jose.ES256, + privateKey: key3, + expiresAt: &[]time.Time{time.Now().Add(-24 * time.Hour)}[0], + force: true, + expectedKey: nil, + expectedGot: key3, + }, + { + name: "Replace key1 private key with no force, is expired", + keyID: "test-key-1", + alg: jose.ES256, + privateKey: key1, + expiresAt: &[]time.Time{time.Now().Add(24 * time.Hour)}[0], + force: false, + expectedKey: nil, + expectedGot: key1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := store.AddPrivateKey(ctx, tc.keyID, tc.alg, tc.privateKey, tc.expiresAt, tc.force) + if tc.expectedErr != nil { + assert.ErrorIs(t, err, tc.expectedErr) + } else { + assert.NoError(t, err) + } + + if tc.expectedGot != nil { + assert.Equal(t, tc.expectedGot.Public(), got.Public()) + } else { + assert.Nil(t, got) + } + + if tc.expectedKey != nil { + retrievedKey, err := store.GetPrivateKey(ctx, tc.keyID) + assert.NoError(t, err) + assert.Equal(t, tc.expectedKey.Public(), retrievedKey.Public()) + } + }) + } +} + +func generateRSAKey(t *testing.T) *rsa.PrivateKey { + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + return key +} + +func generateECKey(t *testing.T) *ecdsa.PrivateKey { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + return key +} diff --git a/pkg/services/sqlstore/migrations/migrations.go b/pkg/services/sqlstore/migrations/migrations.go index dd98d9822fc..56df4323cc9 100644 --- a/pkg/services/sqlstore/migrations/migrations.go +++ b/pkg/services/sqlstore/migrations/migrations.go @@ -5,6 +5,7 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore/migrations/accesscontrol" "github.com/grafana/grafana/pkg/services/sqlstore/migrations/anonservice" "github.com/grafana/grafana/pkg/services/sqlstore/migrations/oauthserver" + "github.com/grafana/grafana/pkg/services/sqlstore/migrations/signingkeys" "github.com/grafana/grafana/pkg/services/sqlstore/migrations/ualert" . "github.com/grafana/grafana/pkg/services/sqlstore/migrator" ) @@ -99,6 +100,7 @@ func (*OSSMigrations) AddMigration(mg *Migrator) { } anonservice.AddMigration(mg) + signingkeys.AddMigration(mg) } func addStarMigrations(mg *Migrator) { diff --git a/pkg/services/sqlstore/migrations/signingkeys/migrations.go b/pkg/services/sqlstore/migrations/signingkeys/migrations.go new file mode 100644 index 00000000000..ec4f087faf9 --- /dev/null +++ b/pkg/services/sqlstore/migrations/signingkeys/migrations.go @@ -0,0 +1,23 @@ +package signingkeys + +import "github.com/grafana/grafana/pkg/services/sqlstore/migrator" + +func AddMigration(mg *migrator.Migrator) { + var signingKeysV1 = migrator.Table{ + Name: "signing_key", + Columns: []*migrator.Column{ + {Name: "id", Type: migrator.DB_BigInt, IsPrimaryKey: true, IsAutoIncrement: true}, + {Name: "key_id", Type: migrator.DB_NVarchar, Length: 255, Nullable: false}, + {Name: "private_key", Type: migrator.DB_Text, Nullable: false}, + {Name: "added_at", Type: migrator.DB_DateTime, Nullable: false}, + {Name: "expires_at", Type: migrator.DB_DateTime, Nullable: true}, + {Name: "alg", Type: migrator.DB_NVarchar, Length: 255, Nullable: false}, + }, + Indices: []*migrator.Index{ + {Cols: []string{"key_id"}, Type: migrator.UniqueIndex}, + }, + } + + mg.AddMigration("create signing_key table", migrator.NewAddTableMigration(signingKeysV1)) + mg.AddMigration("add unique index signing_key.key_id", migrator.NewAddIndexMigration(signingKeysV1, signingKeysV1.Indices[0])) +}