diff --git a/service/go.mod b/service/go.mod index d6fff68233..550fe7fc4e 100644 --- a/service/go.mod +++ b/service/go.mod @@ -175,7 +175,7 @@ require ( golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sync v0.15.0 // indirect golang.org/x/sys v0.33.0 // indirect - golang.org/x/text v0.26.0 // indirect + golang.org/x/text v0.26.0 google.golang.org/genproto/googleapis/api v0.0.0-20250519155744-55703ea1f237 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 // indirect sigs.k8s.io/yaml v1.4.0 // indirect diff --git a/service/integration/keymanagement_test.go b/service/integration/keymanagement_test.go index 9118935030..b2bad5ccd7 100644 --- a/service/integration/keymanagement_test.go +++ b/service/integration/keymanagement_test.go @@ -3,6 +3,7 @@ package integration import ( "context" "log/slog" + "strings" "testing" "github.com/google/uuid" @@ -12,6 +13,8 @@ import ( "github.com/opentdf/platform/service/internal/fixtures" "github.com/opentdf/platform/service/pkg/db" "github.com/stretchr/testify/suite" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) var ( @@ -95,6 +98,24 @@ func (s *KeyManagementSuite) Test_CreateProviderConfig_DuplicateName_Fails() { s.Nil(pc) } +func (s *KeyManagementSuite) Test_CreateProviderConfig_CapitalizedName_Succeeds() { + pcIDs := make([]string, 0) + defer func() { + s.deleteTestProviderConfigs(pcIDs) + }() + providerName := strings.ToUpper(testProvider) + pc := s.createTestProviderConfig(providerName, validProviderConfig, nil) + pcIDs = append(pcIDs, pc.GetId()) + + pcGet, err := s.db.PolicyClient.GetProviderConfig(s.ctx, &keymanagement.GetProviderConfigRequest_Name{ + Name: testProvider, + }) + s.Require().NoError(err) + s.NotNil(pcGet) + s.Equal(testProvider, pcGet.GetName()) // Expect name to be lowercased + s.Equal(validProviderConfig, pcGet.GetConfigJson()) +} + func (s *KeyManagementSuite) Test_GetProviderConfig_WithId_Succeeds() { pcIDs := make([]string, 0) defer func() { @@ -119,10 +140,30 @@ func (s *KeyManagementSuite) Test_GetProviderConfig_WithName_Succeeds() { pcIDs = append(pcIDs, pc.GetId()) pc, err := s.db.PolicyClient.GetProviderConfig(s.ctx, &keymanagement.GetProviderConfigRequest_Name{ - Name: pc.GetName(), + Name: testProvider, }) s.Require().NoError(err) s.NotNil(pc) + s.Equal(testProvider, pc.GetName()) + s.Equal(validProviderConfig, pc.GetConfigJson()) +} + +func (s *KeyManagementSuite) Test_GetProviderConfig_MixedCaseName_Succeeds() { + pcIDs := make([]string, 0) + defer func() { + s.deleteTestProviderConfigs(pcIDs) + }() + mixedCaseName := cases.Title(language.English).String(testProvider) // "Test-provider" + pc := s.createTestProviderConfig(mixedCaseName, validProviderConfig, nil) + pcIDs = append(pcIDs, pc.GetId()) + + pcGet, err := s.db.PolicyClient.GetProviderConfig(s.ctx, &keymanagement.GetProviderConfigRequest_Name{ + Name: testProvider, // search with lowercase name + }) + s.Require().NoError(err) + s.NotNil(pcGet) + s.Equal(testProvider, pcGet.GetName()) // Expect name to be lowercased + s.Equal(validProviderConfig, pcGet.GetConfigJson()) } func (s *KeyManagementSuite) Test_GetProviderConfig_InvalidIdentifier_Fails() { @@ -195,7 +236,7 @@ func (s *KeyManagementSuite) Test_UpdateProviderConfig_ExtendsMetadata_Succeeds( }) pcIDs = append(pcIDs, pc.GetId()) s.NotNil(pc) - s.Equal(testProvider, pc.GetName()) + s.Equal(strings.ToLower(testProvider), pc.GetName()) s.Equal(validProviderConfig, pc.GetConfigJson()) s.Equal(validLabels, pc.GetMetadata().GetLabels()) @@ -292,6 +333,49 @@ func (s *KeyManagementSuite) Test_UpdateProviderConfig_ConfigNotFound_Fails() { s.Nil(pc) } +func (s *KeyManagementSuite) Test_UpdateProviderConfig_UpdatesConfigJson_And_Name_Succeeds() { + pcIDs := make([]string, 0) + defer func() { + s.deleteTestProviderConfigs(pcIDs) + }() + pc := s.createTestProviderConfig(testProvider, validProviderConfig, nil) + pcIDs = append(pcIDs, pc.GetId()) + s.NotNil(pc) + s.Equal(testProvider, pc.GetName()) + s.Equal(validProviderConfig, pc.GetConfigJson()) + + pc, err := s.db.PolicyClient.UpdateProviderConfig(s.ctx, &keymanagement.UpdateProviderConfigRequest{ + Id: pc.GetId(), + ConfigJson: validProviderConfig2, + Name: testProvider2, + }) + s.Require().NoError(err) + s.NotNil(pc) + s.Equal(testProvider2, pc.GetName()) + s.Equal(validProviderConfig2, pc.GetConfigJson()) +} + +func (s *KeyManagementSuite) Test_UpdateProviderConfig_UpdatesConfigName_Succeeds() { + pcIDs := make([]string, 0) + defer func() { + s.deleteTestProviderConfigs(pcIDs) + }() + pc := s.createTestProviderConfig(testProvider, validProviderConfig, nil) + pcIDs = append(pcIDs, pc.GetId()) + s.NotNil(pc) + s.Equal(testProvider, pc.GetName()) + s.Equal(validProviderConfig, pc.GetConfigJson()) + + pc, err := s.db.PolicyClient.UpdateProviderConfig(s.ctx, &keymanagement.UpdateProviderConfigRequest{ + Id: pc.GetId(), + Name: strings.ToUpper(testProvider2), + }) + s.Require().NoError(err) + s.NotNil(pc) + s.Equal(testProvider2, pc.GetName()) + s.Equal(validProviderConfig, pc.GetConfigJson()) +} + func (s *KeyManagementSuite) Test_DeleteProviderConfig_Succeeds() { pc := s.createTestProviderConfig(testProvider, validProviderConfig, nil) s.NotNil(pc) diff --git a/service/policy/db/key_management.go b/service/policy/db/key_management.go index 74e26cea9d..a5dd5c608c 100644 --- a/service/policy/db/key_management.go +++ b/service/policy/db/key_management.go @@ -2,7 +2,6 @@ package db import ( "context" - "encoding/json" "errors" "fmt" "strings" @@ -55,7 +54,7 @@ func (c PolicyDBClient) GetProviderConfig(ctx context.Context, identifier any) ( } params = getProviderConfigParams{ID: id} case *keymanagement.GetProviderConfigRequest_Name: - name := pgtypeText(i.Name) + name := pgtypeText(strings.ToLower(i.Name)) if !name.Valid { return nil, db.ErrSelectIdentifierInvalid } @@ -70,11 +69,6 @@ func (c PolicyDBClient) GetProviderConfig(ctx context.Context, identifier any) ( return nil, db.WrapIfKnownInvalidQueryErr(err) } - mappedMetadata := make(map[string]any) - if err = json.Unmarshal(pcRow.Metadata, &mappedMetadata); err != nil { - return nil, err - } - metadata := &common.Metadata{} if err = unmarshalMetadata(pcRow.Metadata, metadata); err != nil { return nil, err @@ -142,7 +136,7 @@ func (c PolicyDBClient) UpdateProviderConfig(ctx context.Context, r *keymanageme id := r.GetId() // if extend we need to merge the metadata - metadataJSON, metadata, err := db.MarshalUpdateMetadata(r.GetMetadata(), r.GetMetadataUpdateBehavior(), func() (*common.Metadata, error) { + metadataJSON, _, err := db.MarshalUpdateMetadata(r.GetMetadata(), r.GetMetadataUpdateBehavior(), func() (*common.Metadata, error) { a, err := c.GetProviderConfig(ctx, &keymanagement.GetProviderConfigRequest_Id{ Id: r.GetId(), }) @@ -171,12 +165,9 @@ func (c PolicyDBClient) UpdateProviderConfig(ctx context.Context, r *keymanageme c.logger.Warn("UpdateProviderConfig updated more than one row", "count", count) } - return &policy.KeyProviderConfig{ - Id: id, - Name: name, - ConfigJson: config, - Metadata: metadata, - }, nil + return c.GetProviderConfig(ctx, &keymanagement.GetProviderConfigRequest_Id{ + Id: id, + }) } func (c PolicyDBClient) DeleteProviderConfig(ctx context.Context, id string) (*policy.KeyProviderConfig, error) {