Skip to content
246 changes: 246 additions & 0 deletions service/kas/key_indexer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
package kas

import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/opentdf/platform/protocol/go/policy"
"github.com/opentdf/platform/protocol/go/policy/kasregistry"
"github.com/opentdf/platform/sdk"
"github.com/opentdf/platform/service/logger"
"github.com/opentdf/platform/service/trust"
)

var ErrNoActiveKeyForAlgorithm = errors.New("no active key found for specified algorithm")

// Used for reaching out to platform to get keys
type KeyIndexer struct {
// SDK is the SDK instance used to interact with the platform
sdk *sdk.SDK
// KasURI
kasURI string
// Logger is the logger instance used for logging
log *logger.Logger
}

// platformKeyAdapter is an adapter for KeyDetails, where keys come from the platform
type KeyAdapter struct {
key *policy.KasKey
log *logger.Logger
}

func NewPlatformKeyIndexer(sdk *sdk.SDK, kasURI string, l *logger.Logger) *KeyIndexer {
return &KeyIndexer{
sdk: sdk,
kasURI: kasURI,
log: l,
}
}

func convertAlgToEnum(alg string) (policy.Algorithm, error) {
switch alg {
case "rsa:2048":
return policy.Algorithm_ALGORITHM_RSA_2048, nil
case "rsa:4096":
return policy.Algorithm_ALGORITHM_RSA_4096, nil
case "ec:secp256r1":
return policy.Algorithm_ALGORITHM_EC_P256, nil
case "ec:secp384r1":
return policy.Algorithm_ALGORITHM_EC_P384, nil
case "ec:secp521r1":
return policy.Algorithm_ALGORITHM_EC_P521, nil
default:
return policy.Algorithm_ALGORITHM_UNSPECIFIED, fmt.Errorf("unsupported algorithm: %s", alg)
}
}

func (p *KeyIndexer) FindKeyByAlgorithm(ctx context.Context, algorithm string, _ bool) (trust.KeyDetails, error) {
alg, err := convertAlgToEnum(algorithm)
if err != nil {
return nil, err
}

req := &kasregistry.ListKeysRequest{
KeyAlgorithm: alg,
KasFilter: &kasregistry.ListKeysRequest_KasUri{
KasUri: p.kasURI,
},
}
resp, err := p.sdk.KeyAccessServerRegistry.ListKeys(ctx, req)
if err != nil {
return nil, err
}

// Find active key.
var activeKey *policy.KasKey
for _, key := range resp.GetKasKeys() {
if key.GetKey().GetKeyStatus() == policy.KeyStatus_KEY_STATUS_ACTIVE {
activeKey = key
break
}
}
if activeKey == nil {
return nil, ErrNoActiveKeyForAlgorithm
}

return &KeyAdapter{
key: activeKey,
log: p.log,
}, nil
}

func (p *KeyIndexer) FindKeyByID(ctx context.Context, id trust.KeyIdentifier) (trust.KeyDetails, error) {
req := &kasregistry.GetKeyRequest{
Identifier: &kasregistry.GetKeyRequest_Key{
Key: &kasregistry.KasKeyIdentifier{
Identifier: &kasregistry.KasKeyIdentifier_Uri{
Uri: p.kasURI,
},
Kid: string(id),
},
},
}

resp, err := p.sdk.KeyAccessServerRegistry.GetKey(ctx, req)
if err != nil {
return nil, err
}

return &KeyAdapter{
key: resp.GetKasKey(),
log: p.log,
}, nil
}

func (p *KeyIndexer) ListKeys(ctx context.Context) ([]trust.KeyDetails, error) {
req := &kasregistry.ListKeysRequest{
KasFilter: &kasregistry.ListKeysRequest_KasUri{
KasUri: p.kasURI,
},
}
resp, err := p.sdk.KeyAccessServerRegistry.ListKeys(ctx, req)
if err != nil {
return nil, err
}

keys := make([]trust.KeyDetails, len(resp.GetKasKeys()))
for i, key := range resp.GetKasKeys() {
keys[i] = &KeyAdapter{
key: key,
log: p.log,
}
}

return keys, nil
}

func (p *KeyAdapter) ID() trust.KeyIdentifier {
return trust.KeyIdentifier(p.key.GetKey().GetKeyId())
}

// Might need to convert this to a standard format
func (p *KeyAdapter) Algorithm() string {
return p.key.GetKey().GetKeyAlgorithm().String()
}

func (p *KeyAdapter) IsLegacy() bool {
return false
}

// This will point to the correct "manager"
func (p *KeyAdapter) System() string {
var mode string
if p.key.GetKey().GetProviderConfig() != nil {
mode = p.key.GetKey().GetProviderConfig().GetName()
}
return mode
}

func pemToPublicKey(publicPEM string) (*rsa.PublicKey, error) {
// Decode the PEM data
block, _ := pem.Decode([]byte(publicPEM))
if block == nil || block.Type != "PUBLIC KEY" {
return nil, errors.New("failed to decode PEM block or incorrect PEM type")
}

// Parse the public key
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %w", err)
}

// Assert type and return
rsaPub, ok := pub.(*rsa.PublicKey)
if !ok {
return nil, errors.New("not an RSA public key")
}

return rsaPub, nil
}

// Repurpose of the StandardCrypto function
func rsaPublicKeyAsJSON(_ context.Context, publicPEM string) (string, error) {
pubKey, err := pemToPublicKey(publicPEM)
if err != nil {
return "", err
}

rsaPublicKeyJwk, err := jwk.FromRaw(pubKey)
if err != nil {
return "", fmt.Errorf("jwk.FromRaw: %w", err)
}

// Convert the public key to JSON format
pubKeyJSON, err := json.Marshal(rsaPublicKeyJwk)
if err != nil {
return "", err
}

return string(pubKeyJSON), nil
}

// Repurpose of the StandardCrypto function
func convertPEMToJWK(_ string) (string, error) {
return "", errors.New("convertPEMToJWK function is not implemented")
}

func (p *KeyAdapter) ExportPublicKey(ctx context.Context, format trust.KeyType) (string, error) {
publicKeyCtx := p.key.GetKey().GetPublicKeyCtx()

// Decode the base64-encoded public key
decodedPubKey, err := base64.StdEncoding.DecodeString(publicKeyCtx.GetPem())
if err != nil {
return "", err
}

switch format {
case trust.KeyTypeJWK:
// For JWK format (currently only supported for RSA)
if p.key.GetKey().GetKeyAlgorithm() == policy.Algorithm_ALGORITHM_RSA_2048 ||
p.key.GetKey().GetKeyAlgorithm() == policy.Algorithm_ALGORITHM_RSA_4096 {
return rsaPublicKeyAsJSON(ctx, string(decodedPubKey))
}
// For EC keys, we return the public key in PEM format
jwkKey, err := convertPEMToJWK(string(decodedPubKey))
if err != nil {
return "", err
}

return jwkKey, nil
case trust.KeyTypePKCS8:
return string(decodedPubKey), nil
default:
return "", errors.New("unsupported key type")
}
}

func (p *KeyAdapter) ExportCertificate(_ context.Context) (string, error) {
return "", errors.New("not implemented")
}
76 changes: 76 additions & 0 deletions service/kas/key_indexer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package kas

import (
"context"
"testing"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/opentdf/platform/lib/ocrypto"
"github.com/opentdf/platform/protocol/go/policy"
"github.com/opentdf/platform/service/trust"
"github.com/stretchr/testify/suite"
)

type KeyIndexTestSuite struct {
suite.Suite
rsaKey trust.KeyDetails
}

func (s *KeyIndexTestSuite) SetupTest() {
s.rsaKey = &KeyAdapter{
key: &policy.KasKey{
KasId: "test-kas-id",
Key: &policy.AsymmetricKey{
Id: "test-id",
KeyId: "test-key-id",
KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048,
KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE,
KeyMode: policy.KeyMode_KEY_MODE_CONFIG_ROOT_KEY,
PublicKeyCtx: &policy.KasPublicKeyCtx{
Pem: "LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQ0FROEFNSUlCQ2dLQ0FRRUF3SEw0TkVrOFpDa0JzNjZXQVpWagpIS3NseDRseWdmaXN3aW42RUx5OU9OczZLVDRYa1crRGxsdExtck14bHZkbzVRaDg1UmFZS01mWUdDTWtPM0dGCkFsK0JOeWFOM1kwa0N1QjNPU2ErTzdyMURhNVZteVVuaEJNbFBrYnVPY1Y0cjlLMUhOSGd3eDl2UFp3RjRpQW8KQStEY1VBcWFEeHlvYjV6enNGZ0hUNjJHLzdLdEtiZ2hYT1dCanRUYUl1ZHpsK2FaSjFPemY0U1RkOXhST2QrMQordVo2VG1ocmFEUm9zdDUrTTZUN0toL2lGWk40TTFUY2hwWXU1TDhKR2tVaG9YaEdZcHUrMGczSzlqYlh6RVh5CnpJU3VXN2d6SGRWYUxvcnBkQlNkRHpOWkNvTFVoL0U1T3d5TFZFQkNKaDZJVUtvdWJ5WHVucnIxQnJmK2tLbEsKeHdJREFRQUIKLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg==",
},
ProviderConfig: &policy.KeyProviderConfig{
Id: "test-provider-id",
Name: "openbao",
},
},
},
}
}
func (s *KeyIndexTestSuite) TearDownTest() {}

func (s *KeyIndexTestSuite) TestKeyDetails() {
s.Equal("test-key-id", string(s.rsaKey.ID()))
s.Equal("ALGORITHM_RSA_2048", s.rsaKey.Algorithm())
s.False(s.rsaKey.IsLegacy())
s.Equal("openbao", s.rsaKey.System())
}

func (s *KeyIndexTestSuite) TestKeyExportPublicKey_JWKFormat() {
// Export JWK format
jwkString, err := s.rsaKey.ExportPublicKey(context.Background(), trust.KeyTypeJWK)
s.Require().NoError(err)
s.Require().NotEmpty(jwkString)

rsaKey, err := jwk.ParseKey([]byte(jwkString))
s.Require().NoError(err)
s.Require().NotNil(rsaKey)
}

func (s *KeyIndexTestSuite) TestKeyExportPublicKey_PKCSFormat() {
// Export JWK format
pem, err := s.rsaKey.ExportPublicKey(context.Background(), trust.KeyTypePKCS8)
s.Require().NoError(err)
s.Require().NotEmpty(pem)

keyAdapter, ok := s.rsaKey.(*KeyAdapter)
s.Require().True(ok)
pubCtx := keyAdapter.key.GetKey().GetPublicKeyCtx()
s.Require().NotEmpty(pubCtx)
base64Pem := ocrypto.Base64Encode([]byte(pem))
s.Equal(pubCtx.GetPem(), string(base64Pem))
}

func TestNewPlatformKeyIndexTestSuite(t *testing.T) {
suite.Run(t, new(KeyIndexTestSuite))
}
Loading