Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/ocrypto/protected_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func TestAESProtectedKey_Export_EncapsulatorError(t *testing.T) {
protectedKey, err := NewAESProtectedKey(key)
require.NoError(t, err)

// Since Export now calls Encrypt, make Encrypt return an error
mockEncapsulator := &mockEncapsulator{
encryptFunc: func(_ []byte) ([]byte, error) {
return nil, assert.AnError
Expand Down
37 changes: 25 additions & 12 deletions service/internal/security/basic_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (b *BasicManager) Name() string {
return BasicManagerName
}

func (b *BasicManager) Decrypt(ctx context.Context, keyDetails trust.KeyDetails, ciphertext []byte, ephemeralPublicKey []byte) (trust.ProtectedKey, error) {
func (b *BasicManager) Decrypt(ctx context.Context, keyDetails trust.KeyDetails, ciphertext []byte, ephemeralPublicKey []byte) (ocrypto.ProtectedKey, error) {
// Implementation of Decrypt method

// Get Private Key
Expand All @@ -74,7 +74,11 @@ func (b *BasicManager) Decrypt(ctx context.Context, keyDetails trust.KeyDetails,
if err != nil {
return nil, fmt.Errorf("failed to decrypt with RSA: %w", err)
}
return NewInProcessAESKey(plaintext), nil
protectedKey, err := ocrypto.NewAESProtectedKey(plaintext)
if err != nil {
return nil, fmt.Errorf("failed to create protected key: %w", err)
}
return protectedKey, nil
case ocrypto.EC256Key, ocrypto.EC384Key, ocrypto.EC521Key:
ecPrivKey, err := ocrypto.ECPrivateKeyFromPem(privKey)
if err != nil {
Expand All @@ -88,13 +92,17 @@ func (b *BasicManager) Decrypt(ctx context.Context, keyDetails trust.KeyDetails,
if err != nil {
return nil, fmt.Errorf("failed to decrypt with ephemeral key: %w", err)
}
return NewInProcessAESKey(plaintext), nil
protectedKey, err := ocrypto.NewAESProtectedKey(plaintext)
if err != nil {
return nil, fmt.Errorf("failed to create protected key: %w", err)
}
return protectedKey, nil
}

return nil, fmt.Errorf("unsupported algorithm: %s", keyDetails.Algorithm())
}

func (b *BasicManager) DeriveKey(ctx context.Context, keyDetails trust.KeyDetails, ephemeralPublicKeyBytes []byte, curve elliptic.Curve) (trust.ProtectedKey, error) {
func (b *BasicManager) DeriveKey(ctx context.Context, keyDetails trust.KeyDetails, ephemeralPublicKeyBytes []byte, curve elliptic.Curve) (ocrypto.ProtectedKey, error) {
// Implementation of DeriveKey method
privateKeyCtx, err := keyDetails.ExportPrivateKey(ctx)
if err != nil {
Expand Down Expand Up @@ -130,22 +138,27 @@ func (b *BasicManager) DeriveKey(ctx context.Context, keyDetails trust.KeyDetail
if err != nil {
return nil, fmt.Errorf("failed to calculate HKDF: %w", err)
}
return NewInProcessAESKey(key), nil
protectedKey, err := ocrypto.NewAESProtectedKey(key)
if err != nil {
return nil, fmt.Errorf("failed to create protected key: %w", err)
}
return protectedKey, nil
}

type OCEncapsulator struct {
ocrypto.PublicKeyEncryptor
}

func (e *OCEncapsulator) Encapsulate(dek trust.ProtectedKey) ([]byte, error) {
ipk, ok := dek.(*InProcessAESKey)
if !ok {
return nil, errors.New("invalid DEK type for encapsulation")
}
return e.Encrypt(ipk.rawKey)
func (e *OCEncapsulator) Encapsulate(dek ocrypto.ProtectedKey) ([]byte, error) {
// Delegate to the ProtectedKey to avoid exposing raw key material
return dek.Export(e)
}

func (e *OCEncapsulator) PublicKeyAsPEM() (string, error) {
return e.PublicKeyEncryptor.PublicKeyInPemFormat()
}

func (b *BasicManager) GenerateECSessionKey(_ context.Context, ephemeralPublicKey string) (trust.Encapsulator, error) {
func (b *BasicManager) GenerateECSessionKey(_ context.Context, ephemeralPublicKey string) (ocrypto.Encapsulator, error) {
pke, err := ocrypto.FromPublicPEMWithSalt(ephemeralPublicKey, NanoVersionSalt(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create public key encryptor: %w", err)
Expand Down
34 changes: 30 additions & 4 deletions service/internal/security/basic_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ func (m *MockEncapsulator) EphemeralKey() []byte {
return nil
}

// noOpEncapsulator is a test encapsulator that returns raw key data without encryption
type noOpEncapsulator struct{}

func (n *noOpEncapsulator) Encapsulate(pk ocrypto.ProtectedKey) ([]byte, error) {
// Delegate to ProtectedKey to avoid accessing raw key directly
return pk.Export(n)
}

func (n *noOpEncapsulator) Encrypt(data []byte) ([]byte, error) {
return data, nil
}

func (n *noOpEncapsulator) PublicKeyAsPEM() (string, error) {
return "", nil
}

func (n *noOpEncapsulator) EphemeralKey() []byte {
return nil
}

// Helper function to wrap a key with AES-GCM
func wrapKeyWithAESGCM(keyToWrap []byte, rootKey []byte) (string, error) {
gcm, err := ocrypto.NewAESGcm(rootKey)
Expand Down Expand Up @@ -276,7 +296,7 @@ func TestBasicManager_Decrypt(t *testing.T) {
bm, err := NewBasicManager(log, testCache, rootKeyHex)
require.NoError(t, err)

samplePayload := []byte("secret payload")
samplePayload := []byte("secret payload16") // 16 bytes for valid AES key

t.Run("successful RSA decryption", func(t *testing.T) {
mockDetails := new(MockKeyDetails)
Expand All @@ -298,7 +318,9 @@ func TestBasicManager_Decrypt(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, protectedKey)

decryptedPayload, err := protectedKey.Export(nil)
// Use noOpEncapsulator to get raw key data for testing
noOpEnc := &noOpEncapsulator{}
decryptedPayload, err := protectedKey.Export(noOpEnc)
require.NoError(t, err)
assert.Equal(t, samplePayload, decryptedPayload)
})
Expand All @@ -324,7 +346,9 @@ func TestBasicManager_Decrypt(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, protectedKey)

decryptedPayload, err := protectedKey.Export(nil)
// Use noOpEncapsulator to get raw key data for testing
noOpEnc := &noOpEncapsulator{}
decryptedPayload, err := protectedKey.Export(noOpEnc)
require.NoError(t, err)
assert.Equal(t, samplePayload, decryptedPayload)
})
Expand Down Expand Up @@ -445,7 +469,9 @@ func TestBasicManager_DeriveKey(t *testing.T) {
expectedDerivedKey, err := ocrypto.CalculateHKDF(NanoVersionSalt(), expectedSharedSecret)
require.NoError(t, err)

actualDerivedKey, err := protectedKey.Export(nil)
// Use noOpEncapsulator to get raw key data for testing
noOpEnc := &noOpEncapsulator{}
actualDerivedKey, err := protectedKey.Export(noOpEnc)
require.NoError(t, err)
assert.Equal(t, expectedDerivedKey, actualDerivedKey)
})
Expand Down
114 changes: 22 additions & 92 deletions service/internal/security/in_process_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"crypto"
"crypto/elliptic"
"crypto/hmac"
"crypto/sha256"
"errors"
"fmt"
"log/slog"
Expand All @@ -17,88 +15,6 @@ import (

const inProcessSystemName = "opentdf.io/in-process"

// InProcessAESKey implements the trust.ProtectedKey interface with an in-memory secret key
type InProcessAESKey struct {
rawKey []byte
logger *slog.Logger
}

var _ trust.ProtectedKey = (*InProcessAESKey)(nil)

// NewInProcessAESKey creates a new instance of StandardUnwrappedKey
func NewInProcessAESKey(rawKey []byte) *InProcessAESKey {
return &InProcessAESKey{
rawKey: rawKey,
logger: slog.Default(),
}
}

func (k *InProcessAESKey) DecryptAESGCM(iv []byte, body []byte, tagSize int) ([]byte, error) {
aesGcm, err := ocrypto.NewAESGcm(k.rawKey)
if err != nil {
return nil, err
}

decryptedData, err := aesGcm.DecryptWithIVAndTagSize(iv, body, tagSize)
if err != nil {
return nil, err
}

return decryptedData, nil
}

// Export returns the raw key data, optionally encrypting it with the provided trust.Encapsulator
func (k *InProcessAESKey) Export(encapsulator trust.Encapsulator) ([]byte, error) {
if encapsulator == nil {
if k.logger != nil {
k.logger.Warn("exporting raw key data without encryption")
}
return k.rawKey, nil
}

// If an encryptor is provided, encrypt the key data before returning
encryptedKey, err := encapsulator.Encapsulate(k)
if err != nil {
if k.logger != nil {
k.logger.Warn("failed to encrypt key data for export", slog.Any("err", err))
}
return nil, err
}

return encryptedKey, nil
}

// VerifyBinding checks if the policy binding matches the given policy data
func (k *InProcessAESKey) VerifyBinding(ctx context.Context, policy, policyBinding []byte) error {
if len(k.rawKey) == 0 {
return errors.New("key data is empty")
}

actualHMAC, err := k.generateHMACDigest(ctx, policy)
if err != nil {
return fmt.Errorf("unable to generate policy hmac: %w", err)
}

if !hmac.Equal(actualHMAC, policyBinding) {
return errors.New("policy hmac mismatch")
}

return nil
}

// generateHMACDigest is a helper to generate an HMAC digest from a message using the key
func (k *InProcessAESKey) generateHMACDigest(ctx context.Context, msg []byte) ([]byte, error) {
mac := hmac.New(sha256.New, k.rawKey)
_, err := mac.Write(msg)
if err != nil {
if k.logger != nil {
k.logger.WarnContext(ctx, "failed to compute hmac")
}
return nil, errors.New("policy hmac")
}
return mac.Sum(nil), nil
}

func convertPEMToJWK(_ string) (string, error) {
// Implement the conversion logic here or use an external library if available.
// For now, return a placeholder error to indicate the function is not implemented.
Expand Down Expand Up @@ -304,9 +220,12 @@ func (a *InProcessProvider) ListKeysWith(ctx context.Context, opts trust.ListKey
}

// Decrypt implements the unified decryption method for both RSA and EC
func (a *InProcessProvider) Decrypt(ctx context.Context, keyDetails trust.KeyDetails, ciphertext []byte, ephemeralPublicKey []byte) (trust.ProtectedKey, error) {
func (a *InProcessProvider) Decrypt(ctx context.Context, keyDetails trust.KeyDetails, ciphertext []byte, ephemeralPublicKey []byte) (ocrypto.ProtectedKey, error) {
kid := string(keyDetails.ID())

var protectedKey ocrypto.ProtectedKey
var err error

// Try to determine the key type
keyType, err := a.determineKeyType(ctx, kid)
if err != nil {
Expand All @@ -325,7 +244,7 @@ func (a *InProcessProvider) Decrypt(ctx context.Context, keyDetails trust.KeyDet
if len(ephemeralPublicKey) == 0 {
return nil, errors.New("ephemeral public key is required for EC decryption")
}
rawKey, err = a.cryptoProvider.ECDecrypt(ctx, kid, ephemeralPublicKey, ciphertext)
protectedKey, err = a.cryptoProvider.ECDecrypt(ctx, kid, ephemeralPublicKey, ciphertext)

default:
return nil, errors.New("unsupported key algorithm")
Expand All @@ -335,16 +254,27 @@ func (a *InProcessProvider) Decrypt(ctx context.Context, keyDetails trust.KeyDet
return nil, err
}

return &InProcessAESKey{
rawKey: rawKey,
logger: a.logger,
}, nil
if protectedKey == nil {
protectedKey, err = ocrypto.NewAESProtectedKey(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to create protected key: %w", err)
}
}

return protectedKey, nil
}

// DeriveKey generates a symmetric key for NanoTDF
func (a *InProcessProvider) DeriveKey(_ context.Context, keyDetails trust.KeyDetails, ephemeralPublicKeyBytes []byte, curve elliptic.Curve) (trust.ProtectedKey, error) {
func (a *InProcessProvider) DeriveKey(_ context.Context, keyDetails trust.KeyDetails, ephemeralPublicKeyBytes []byte, curve elliptic.Curve) (ocrypto.ProtectedKey, error) {
k, err := a.cryptoProvider.GenerateNanoTDFSymmetricKey(string(keyDetails.ID()), ephemeralPublicKeyBytes, curve)
return NewInProcessAESKey(k), err
if err != nil {
return nil, err
}
protectedKey, err := ocrypto.NewAESProtectedKey(k)
if err != nil {
return nil, fmt.Errorf("failed to create protected key: %w", err)
}
return protectedKey, nil
}

// GenerateECSessionKey generates a session key for NanoTDF
Expand Down
12 changes: 5 additions & 7 deletions service/internal/security/standard_crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,16 +430,17 @@ func NanoVersionSalt() []byte {
}

// ECDecrypt uses hybrid ECIES to decrypt the data.
func (s *StandardCrypto) ECDecrypt(ctx context.Context, keyID string, ephemeralPublicKey, ciphertext []byte) ([]byte, error) {
func (s *StandardCrypto) ECDecrypt(ctx context.Context, keyID string, ephemeralPublicKey, ciphertext []byte) (ocrypto.ProtectedKey, error) {
unwrappedKey, err := s.Decrypt(ctx, trust.KeyIdentifier(keyID), ciphertext, ephemeralPublicKey)
if err != nil {
return nil, err
}
return unwrappedKey.Export(nil)

return unwrappedKey, nil
}

// Decrypt implements the SecurityProvider Decrypt method
func (s *StandardCrypto) Decrypt(_ context.Context, keyID trust.KeyIdentifier, ciphertext []byte, ephemeralPublicKey []byte) (trust.ProtectedKey, error) {
func (s *StandardCrypto) Decrypt(_ context.Context, keyID trust.KeyIdentifier, ciphertext []byte, ephemeralPublicKey []byte) (ocrypto.ProtectedKey, error) {
kid := string(keyID)
ska, ok := s.keysByID[kid]
if !ok {
Expand Down Expand Up @@ -488,8 +489,5 @@ func (s *StandardCrypto) Decrypt(_ context.Context, keyID trust.KeyIdentifier, c
return nil, fmt.Errorf("unsupported key type for key ID [%s]", kid)
}

return &InProcessAESKey{
rawKey: rawKey,
logger: slog.Default(),
}, nil
return ocrypto.NewAESProtectedKey(rawKey)
}
4 changes: 2 additions & 2 deletions service/kas/access/publicKey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ func (m *MockSecurityProvider) ListKeysWith(_ context.Context, opts trust.ListKe
return keys, nil
}

func (m *MockSecurityProvider) Decrypt(_ context.Context, _ trust.KeyDetails, _, _ []byte) (trust.ProtectedKey, error) {
func (m *MockSecurityProvider) Decrypt(_ context.Context, _ trust.KeyDetails, _, _ []byte) (ocrypto.ProtectedKey, error) {
return nil, errors.New("not implemented for tests")
}

func (m *MockSecurityProvider) DeriveKey(_ context.Context, _ trust.KeyDetails, _ []byte, _ elliptic.Curve) (trust.ProtectedKey, error) {
func (m *MockSecurityProvider) DeriveKey(_ context.Context, _ trust.KeyDetails, _ []byte, _ elliptic.Curve) (ocrypto.ProtectedKey, error) {
return nil, errors.New("not implemented for tests")
}

Expand Down
8 changes: 4 additions & 4 deletions service/kas/access/rewrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type entityInfo struct {

type kaoResult struct {
ID string
DEK trust.ProtectedKey
DEK ocrypto.ProtectedKey
Encapped []byte
Error error

Expand Down Expand Up @@ -482,7 +482,7 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned
continue
}

var dek trust.ProtectedKey
var dek ocrypto.ProtectedKey
var err error
switch kao.GetKeyAccessObject().GetKeyType() {
case "ec-wrapped":
Expand Down Expand Up @@ -827,7 +827,7 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
failAllKaos(requests, results, err400("keypair mismatch"))
return "", results
}
sessionKeyPEM, err := sessionKey.PublicKeyInPemFormat()
sessionKeyPEM, err := sessionKey.PublicKeyAsPEM()
if err != nil {
p.Logger.WarnContext(ctx, "failure in PublicKeyToPem", slog.Any("error", err))
failAllKaos(requests, results, err500(""))
Expand Down Expand Up @@ -955,7 +955,7 @@ func (p *Provider) verifyNanoRewrapRequests(ctx context.Context, req *kaspb.Unsi
return nil, results
}

func extractNanoPolicy(symmetricKey trust.ProtectedKey, header sdk.NanoTDFHeader) (*Policy, error) {
func extractNanoPolicy(symmetricKey ocrypto.ProtectedKey, header sdk.NanoTDFHeader) (*Policy, error) {
const (
kIvLen = 12
)
Expand Down
Loading
Loading