Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
15e7968
feat(sdk): Enable base key support.
c-r33d Jun 10, 2025
1e3acd3
changes.
c-r33d Jun 10, 2025
3e8b84b
check,
c-r33d Jun 10, 2025
914a069
change public_key.
c-r33d Jun 10, 2025
4cad58f
feat(sdk): Base key support.
c-r33d Jun 11, 2025
d7c56a8
Merge branch 'main' into feat/DSPX-1132-base-keys-sdk
c-r33d Jun 11, 2025
01f474a
upgrade go.
c-r33d Jun 11, 2025
85b398e
upgrade go.
c-r33d Jun 11, 2025
1d38d46
linting.
c-r33d Jun 11, 2025
98096a0
linting.
c-r33d Jun 11, 2025
e790de5
remove file.
c-r33d Jun 11, 2025
13ecb7e
linting
c-r33d Jun 11, 2025
b22e072
fix tests.
c-r33d Jun 11, 2025
ef61431
tidy
c-r33d Jun 11, 2025
af54977
test.
c-r33d Jun 12, 2025
594e68a
Merge branch 'main' into feat/DSPX-1132-base-keys-sdk
c-r33d Jun 12, 2025
ad667fb
Merge branch 'main' into feat/DSPX-1132-base-keys-sdk
c-r33d Jun 12, 2025
f8fea08
fix small keys bug.
c-r33d Jun 13, 2025
b7d9a28
linting.
c-r33d Jun 13, 2025
d35e9d6
fix underflow?
c-r33d Jun 13, 2025
e9bfc93
fix tests.
c-r33d Jun 13, 2025
d6d63cb
fix linting./
c-r33d Jun 13, 2025
4711f7c
fix conditional.
c-r33d Jun 16, 2025
d072160
fix conditional.
c-r33d Jun 16, 2025
e84d98d
fix conditional.
c-r33d Jun 16, 2025
9cf2de7
fix test.
c-r33d Jun 16, 2025
eb5dad7
refactor.
c-r33d Jun 16, 2025
c00a90f
refactor.
c-r33d Jun 16, 2025
5ed7eb6
refactor.
c-r33d Jun 16, 2025
c5922f3
update.
c-r33d Jun 16, 2025
c3c5d17
update to enum.
c-r33d Jun 16, 2025
8d76873
linting
c-r33d Jun 16, 2025
29e47ac
feat(sdk): Support pulling keys based on kid.
c-r33d Jun 17, 2025
61d97a9
change.
c-r33d Jun 17, 2025
4582441
Merge branch 'main' into feat/DSPX-1266-key-splits-kids
c-r33d Jun 18, 2025
dc2294d
merge conflicts.
c-r33d Jun 18, 2025
016bda8
merge conflict.
c-r33d Jun 18, 2025
a4b2b7b
fix issue where roundtrip failed from split plan.
c-r33d Jun 18, 2025
f3857d8
iterate through map.
c-r33d Jun 23, 2025
d539740
kid not defined.
c-r33d Jun 23, 2025
3e71e6a
fix test.
c-r33d Jun 23, 2025
ff605f6
fix test.
c-r33d Jun 23, 2025
cfe67a1
change order.
c-r33d Jun 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions sdk/basekey.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ func getBaseKey(ctx context.Context, s SDK) (*policy.SimpleKasKey, error) {

baseKeyStructure, ok := configMap[baseKeyWellKnown]
if !ok {
return nil, errBaseKeyNotFound
return nil, ErrBaseKeyNotFound
}

baseKeyMap, ok := baseKeyStructure.(map[string]interface{})
if !ok {
return nil, errBaseKeyInvalidFormat
return nil, ErrBaseKeyInvalidFormat
}

simpleKasKey, err := parseSimpleKasKey(baseKeyMap)
Expand All @@ -97,28 +97,28 @@ func parseSimpleKasKey(baseKeyMap map[string]interface{}) (*policy.SimpleKasKey,
simpleKasKey := &policy.SimpleKasKey{}

if len(baseKeyMap) == 0 {
return nil, errBaseKeyEmpty
return nil, ErrBaseKeyEmpty
}

publicKey, ok := baseKeyMap[baseKeyPublicKey].(map[string]interface{})
if !ok {
return nil, errBaseKeyInvalidFormat
return nil, ErrBaseKeyInvalidFormat
}

alg, ok := publicKey[baseKeyAlg].(string)
if !ok {
return nil, errBaseKeyInvalidFormat
return nil, ErrBaseKeyInvalidFormat
}
publicKey[baseKeyAlg] = getKasKeyAlg(alg)
baseKeyMap[baseKeyPublicKey] = publicKey
configJSON, err := json.Marshal(baseKeyMap)
if err != nil {
return nil, errors.Join(errMarshalBaseKeyFailed, err)
return nil, errors.Join(ErrMarshalBaseKeyFailed, err)
}

err = protojson.Unmarshal(configJSON, simpleKasKey)
if err != nil {
return nil, errors.Join(errUnmarshalBaseKeyFailed, err)
return nil, errors.Join(ErrUnmarshalBaseKeyFailed, err)
}
return simpleKasKey, nil
}
10 changes: 5 additions & 5 deletions sdk/basekey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (s *BaseKeyTestSuite) TestGetBaseKeyMissingBaseKey() {
s.Require().True(mockService.called)
s.Require().Error(err)
s.Require().Nil(baseKey)
s.Require().Contains(err.Error(), errBaseKeyNotFound.Error())
s.Require().ErrorIs(err, ErrBaseKeyNotFound)
}

func (s *BaseKeyTestSuite) TestGetBaseKeyInvalidBaseKeyFormat() {
Expand All @@ -267,7 +267,7 @@ func (s *BaseKeyTestSuite) TestGetBaseKeyInvalidBaseKeyFormat() {
s.Require().True(mockService.called)
s.Require().Error(err)
s.Require().Nil(baseKey)
s.Require().ErrorContains(err, errBaseKeyInvalidFormat.Error())
s.Require().ErrorIs(err, ErrBaseKeyInvalidFormat)
}

func (s *BaseKeyTestSuite) TestGetBaseKeyEmptyBaseKey() {
Expand All @@ -286,7 +286,7 @@ func (s *BaseKeyTestSuite) TestGetBaseKeyEmptyBaseKey() {
s.Require().True(mockService.called)
s.Require().Error(err)
s.Require().Nil(baseKey)
s.Require().ErrorContains(err, errBaseKeyEmpty.Error())
s.Require().ErrorIs(err, ErrBaseKeyEmpty)
}

func (s *BaseKeyTestSuite) TestGetBaseKeyMissingPublicKey() {
Expand All @@ -308,7 +308,7 @@ func (s *BaseKeyTestSuite) TestGetBaseKeyMissingPublicKey() {
s.Require().True(mockService.called)
s.Require().Error(err)
s.Require().Nil(baseKey)
s.Require().ErrorContains(err, errBaseKeyInvalidFormat.Error())
s.Require().ErrorIs(err, ErrBaseKeyInvalidFormat)
}

func (s *BaseKeyTestSuite) TestGetBaseKeyInvalidPublicKey() {
Expand All @@ -330,5 +330,5 @@ func (s *BaseKeyTestSuite) TestGetBaseKeyInvalidPublicKey() {
s.Require().True(mockService.called)
s.Require().Error(err)
s.Require().Nil(baseKey)
s.Require().ErrorContains(err, errBaseKeyInvalidFormat.Error())
s.Require().ErrorIs(err, ErrBaseKeyInvalidFormat)
}
10 changes: 5 additions & 5 deletions sdk/basekeyerrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package sdk
import "errors"

var (
errBaseKeyNotFound = errors.New("base key not found in well-known configuration")
errBaseKeyInvalidFormat = errors.New("base key has invalid format")
errBaseKeyEmpty = errors.New("base key is empty or not provided")
errMarshalBaseKeyFailed = errors.New("failed to marshal base key configuration")
errUnmarshalBaseKeyFailed = errors.New("failed to unmarshal base key configuration")
ErrBaseKeyNotFound = errors.New("base key not found in well-known configuration")
ErrBaseKeyInvalidFormat = errors.New("base key has invalid format")
ErrBaseKeyEmpty = errors.New("base key is empty or not provided")
ErrMarshalBaseKeyFailed = errors.New("failed to marshal base key configuration")
ErrUnmarshalBaseKeyFailed = errors.New("failed to unmarshal base key configuration")
)
28 changes: 20 additions & 8 deletions sdk/granter.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,23 +388,23 @@ func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkcon
def := pair.GetAttribute()

if def != nil {
storeKeysToCache(def.GetGrants(), keyCache)
storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache)
}
v := pair.GetValue()
gType := noKeysFound
if v != nil {
gType = grants.addAllGrants(fqn, v, def)
storeKeysToCache(v.GetGrants(), keyCache)
storeKeysToCache(v.GetGrants(), v.GetKasKeys(), keyCache)
}

// If no more specific grant was found, then add the value grants
if gType == noKeysFound && def != nil {
gType = grants.addAllGrants(fqn, def, def)
storeKeysToCache(def.GetGrants(), keyCache)
storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache)
}
if gType == noKeysFound && def.GetNamespace() != nil {
grants.addAllGrants(fqn, def.GetNamespace(), def)
storeKeysToCache(def.GetNamespace().GetGrants(), keyCache)
storeKeysToCache(def.GetNamespace().GetGrants(), def.GetNamespace().GetKasKeys(), keyCache)
}
}

Expand All @@ -429,7 +429,7 @@ func algProto2String(e policy.KasPublicKeyAlgEnum) string {
return ""
}

func storeKeysToCache(kases []*policy.KeyAccessServer, c *kasKeyCache) {
func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasKey, c *kasKeyCache) {
for _, kas := range kases {
keys := kas.GetPublicKey().GetCached().GetKeys()
if len(keys) == 0 {
Expand All @@ -445,6 +445,18 @@ func storeKeysToCache(kases []*policy.KeyAccessServer, c *kasKeyCache) {
})
}
}
for _, key := range keys {
alg, err := formatAlg(key.GetPublicKey().GetAlgorithm())
if err != nil {
continue
}
c.store(KASInfo{
URL: key.GetKasUri(),
KID: key.GetPublicKey().GetKid(),
Algorithm: alg,
PublicKey: key.GetPublicKey().GetPem(),
})
}
}

// Given a policy (list of data attributes or tags),
Expand Down Expand Up @@ -472,16 +484,16 @@ func newGranterFromAttributes(keyCache *kasKeyCache, attrs ...*policy.Value) (gr
}

if grants.addAllGrants(fqn, v, def) != noKeysFound {
storeKeysToCache(v.GetGrants(), keyCache)
storeKeysToCache(v.GetGrants(), v.GetKasKeys(), keyCache)
continue
}
// If no more specific grant was found, then add the attr grants
if grants.addAllGrants(fqn, def, def) != noKeysFound {
storeKeysToCache(def.GetGrants(), keyCache)
storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache)
continue
}
grants.addAllGrants(fqn, namespace, def)
storeKeysToCache(namespace.GetGrants(), keyCache)
storeKeysToCache(namespace.GetGrants(), namespace.GetKasKeys(), keyCache)
}

return grants, nil
Expand Down
23 changes: 14 additions & 9 deletions sdk/granter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const (
specifiedKas = "https://attr.kas.com/"
evenMoreSpecificKas = "https://value.kas.com/"
lessSpecificKas = "https://namespace.kas.com/"
fakePem = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQ...\n-----END PUBLIC KEY-----\n"
fakePem = mockRSAPublicKey1
)

var (
Expand Down Expand Up @@ -74,6 +74,8 @@ var (
MP, _ = NewAttributeNameFQN("https://virtru.com/attr/mapped")
mpa, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/a")
mpb, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/b")
mpc, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/c")
mpd, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/d")
mpu, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/unspecified")
)

Expand Down Expand Up @@ -248,7 +250,7 @@ func mockSimpleKasKey(kas, kid string) *policy.SimpleKasKey {
}
var alg policy.Algorithm
switch kid {
case "r1":
case "r0", "r1", "r3":
alg = policy.Algorithm_ALGORITHM_RSA_2048
case "r2":
alg = policy.Algorithm_ALGORITHM_RSA_4096
Expand All @@ -262,7 +264,7 @@ func mockSimpleKasKey(kas, kid string) *policy.SimpleKasKey {
PublicKey: &policy.SimpleKasPublicKey{
Algorithm: alg,
Kid: kid,
Pem: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQ...\n-----END PUBLIC KEY-----\n",
Pem: fakePem,
},
}
}
Expand Down Expand Up @@ -339,17 +341,20 @@ func mockValueFor(fqn AttributeValueFQN) *policy.Value {
switch strings.ToLower(fqn.Value()) {
case "a":
p.KasKeys = make([]*policy.SimpleKasKey, 1)
p.Grants = make([]*policy.KeyAccessServer, 1)
p.KasKeys[0] = mockSimpleKasKey(evenMoreSpecificKas, "r2")
p.Grants[0] = mockGrant(evenMoreSpecificKas, "r2")
p.Grants[0].PublicKey = createPublicKey("r2", fakePem, policy.KasPublicKeyAlgEnum_KAS_PUBLIC_KEY_ALG_ENUM_RSA_2048)

case "b":
p.KasKeys = make([]*policy.SimpleKasKey, 1)
p.Grants = make([]*policy.KeyAccessServer, 1)
p.KasKeys[0] = mockSimpleKasKey(evenMoreSpecificKas, "e1")
p.Grants[0] = mockGrant(evenMoreSpecificKas, "e1")
p.Grants[0].PublicKey = createPublicKey("e1", fakePem, policy.KasPublicKeyAlgEnum_KAS_PUBLIC_KEY_ALG_ENUM_RSA_2048)

case "c":
p.KasKeys = make([]*policy.SimpleKasKey, 1)
p.KasKeys[0] = mockSimpleKasKey(evenMoreSpecificKas, "r0")

case "d":
p.KasKeys = make([]*policy.SimpleKasKey, 1)
p.KasKeys[0] = mockSimpleKasKey(evenMoreSpecificKas, "r3")

case "unspecified":
// defaults only
default:
Expand Down
21 changes: 15 additions & 6 deletions sdk/kas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ func (k *KASClient) getRewrapRequest(reqs []*kas.UnsignedRewrapRequest_WithPolic
}

type kasKeyRequest struct {
url, algorithm string
url, algorithm, kid string
}

type timeStampedKASInfo struct {
Expand All @@ -399,10 +399,19 @@ func (c *kasKeyCache) clear() {
c.c = make(map[kasKeyRequest]timeStampedKASInfo)
}

func (c *kasKeyCache) get(url, algorithm string) *KASInfo {
cacheKey := kasKeyRequest{url, algorithm}
func (c *kasKeyCache) get(url, algorithm, kid string) *KASInfo {
cacheKey := kasKeyRequest{url, algorithm, kid}
now := time.Now()
cv, ok := c.c[cacheKey]
if !ok && kid == "" {
for k, v := range c.c {
if k.url == url && k.algorithm == algorithm {
cv = v
ok = true
break
}
}
}
if !ok {
return nil
}
Expand All @@ -415,13 +424,13 @@ func (c *kasKeyCache) get(url, algorithm string) *KASInfo {
}

func (c *kasKeyCache) store(ki KASInfo) {
cacheKey := kasKeyRequest{ki.URL, ki.Algorithm}
cacheKey := kasKeyRequest{ki.URL, ki.Algorithm, ki.KID}
c.c[cacheKey] = timeStampedKASInfo{ki, time.Now()}
}

func (s SDK) getPublicKey(ctx context.Context, kasurl, algorithm string) (*KASInfo, error) {
func (s SDK) getPublicKey(ctx context.Context, kasurl, algorithm, kidToFind string) (*KASInfo, error) {
if s.kasKeyCache != nil {
if cachedValue := s.kasKeyCache.get(kasurl, algorithm); nil != cachedValue {
if cachedValue := s.kasKeyCache.get(kasurl, algorithm, kidToFind); nil != cachedValue {
return cachedValue, nil
}
}
Expand Down
Loading
Loading