diff --git a/.gitattributes b/.gitattributes index 2669ff174a..2bf7db77c1 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,5 @@ docs/grpc/** linguist-generated=true docs/openapi/** linguist-generated=true +sdk/sdkconnect/** linguist-generated=true service/policy/db/*.sql.go linguist-generated=true -service/policy/db/models.go linguist-generated=true \ No newline at end of file +service/policy/db/models.go linguist-generated=true diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 4a4ace3686..fb568c81ef 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -390,6 +390,8 @@ jobs: - run: cd service && go get github.com/pseudomuto/protoc-gen-doc/cmd/protoc-gen-doc - run: cd service && go install github.com/pseudomuto/protoc-gen-doc/cmd/protoc-gen-doc - run: make proto-generate + - name: generate connect wrappers + run: make connect-wrapper-generate - name: Restore go.mod after installing protoc-gen-doc run: git restore {service,protocol/go}/go.{mod,sum} - name: validate go mod tidy @@ -399,7 +401,7 @@ jobs: git restore go.sum - run: git diff - run: git diff-files --ignore-submodules - - name: Check that make proto-generate has run before PR submission; see above for error details + - name: Check that make proto-generate and connect-wrapper-generate have run before PR submission; see above for error details run: git diff-files --quiet --ignore-submodules ci: diff --git a/Makefile b/Makefile index e7313c6e31..9867407efd 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ # make # To run all lint checks: `LINT_OPTIONS= make lint` -.PHONY: all build clean docker-build fix fmt go-lint license lint proto-generate proto-lint sdk/sdk test tidy toolcheck +.PHONY: all build clean connect-wrapper-generate docker-build fix fmt go-lint license lint proto-generate proto-lint sdk/sdk test tidy toolcheck MODS=protocol/go lib/ocrypto lib/fixtures lib/flattening lib/identifier sdk service examples HAND_MODS=lib/ocrypto lib/fixtures lib/flattening lib/identifier sdk service examples @@ -71,6 +71,9 @@ proto-generate: buf generate buf.build/grpc-ecosystem/grpc-gateway -o tmp-gen --template buf.gen.grpc.docs.yaml buf generate buf.build/grpc-ecosystem/grpc-gateway -o tmp-gen --template buf.gen.openapi.docs.yaml +connect-wrapper-generate: + go run ./sdk/internal/codegen + policy-sql-gen: @which sqlc > /dev/null || { echo "sqlc not found, please install it: https://docs.sqlc.dev/en/stable/overview/install.html"; exit 1; } sqlc generate -f service/policy/db/sqlc.yaml @@ -95,7 +98,7 @@ clean: for m in $(MODS); do (cd $$m && go clean) || exit 1; done rm -f opentdf examples/examples -build: proto-generate opentdf sdk/sdk examples/examples +build: proto-generate connect-wrapper-generate opentdf sdk/sdk examples/examples opentdf: $(shell find service) go build -o opentdf -v service/main.go diff --git a/examples/cmd/benchmark.go b/examples/cmd/benchmark.go index 37cd2215e6..106d0b2d23 100644 --- a/examples/cmd/benchmark.go +++ b/examples/cmd/benchmark.go @@ -106,7 +106,11 @@ func runBenchmark(cmd *cobra.Command, _ []string) error { return err } nanoTDFConfig.EnableECDSAPolicyBinding() - err = nanoTDFConfig.SetKasURL(fmt.Sprintf("http://%s/kas", "localhost:8080")) + if insecurePlaintextConn || strings.HasPrefix(platformEndpoint, "http://") { + err = nanoTDFConfig.SetKasURL(fmt.Sprintf("http://%s/kas", "localhost:8080")) + } else { + err = nanoTDFConfig.SetKasURL(fmt.Sprintf("https://%s/kas", "localhost:8080")) + } if err != nil { return err } @@ -123,15 +127,26 @@ func runBenchmark(cmd *cobra.Command, _ []string) error { // } // } } else { - tdf, err := client.CreateTDF( - out, in, - sdk.WithDataAttributes(dataAttributes...), - sdk.WithKasInformation( + opts := []sdk.TDFOption{sdk.WithDataAttributes(dataAttributes...), sdk.WithAutoconfigure(false)} + if insecurePlaintextConn || strings.HasPrefix(platformEndpoint, "http://") { + opts = append(opts, sdk.WithKasInformation( sdk.KASInfo{ - URL: "http://" + "localhost:8080", + URL: "http://localhost:8080", PublicKey: "", }), - sdk.WithAutoconfigure(false)) + ) + } else { + opts = append(opts, sdk.WithKasInformation( + sdk.KASInfo{ + URL: "https://localhost:8080", + PublicKey: "", + }), + ) + } + tdf, err := client.CreateTDF( + out, in, + opts..., + ) if err != nil { return err } diff --git a/examples/cmd/benchmark_bulk.go b/examples/cmd/benchmark_bulk.go index 8fcfffa48b..b20a0c5456 100644 --- a/examples/cmd/benchmark_bulk.go +++ b/examples/cmd/benchmark_bulk.go @@ -61,7 +61,12 @@ func runBenchmarkBulk(cmd *cobra.Command, _ []string) error { return err } nanoTDFConfig.EnableECDSAPolicyBinding() - err = nanoTDFConfig.SetKasURL(fmt.Sprintf("http://%s/kas", "localhost:8080")) + // if plaintext or platform endpoint is http, set kas url to http, otherwise https + if insecurePlaintextConn || strings.HasPrefix(platformEndpoint, "http://") { + err = nanoTDFConfig.SetKasURL(fmt.Sprintf("http://%s/kas", "localhost:8080")) + } else { + err = nanoTDFConfig.SetKasURL(fmt.Sprintf("https://%s/kas", "localhost:8080")) + } if err != nil { return err } @@ -78,15 +83,26 @@ func runBenchmarkBulk(cmd *cobra.Command, _ []string) error { } } } else { - tdf, err := client.CreateTDF( - out, in, - sdk.WithDataAttributes(dataAttributes...), - sdk.WithKasInformation( + opts := []sdk.TDFOption{sdk.WithDataAttributes(dataAttributes...), sdk.WithAutoconfigure(false)} + if insecurePlaintextConn || strings.HasPrefix(platformEndpoint, "http://") { + opts = append(opts, sdk.WithKasInformation( sdk.KASInfo{ - URL: "http://" + "localhost:8080", + URL: "http://localhost:8080", PublicKey: "", }), - sdk.WithAutoconfigure(false)) + ) + } else { + opts = append(opts, sdk.WithKasInformation( + sdk.KASInfo{ + URL: "https://localhost:8080", + PublicKey: "", + }), + ) + } + tdf, err := client.CreateTDF( + out, in, + opts..., + ) if err != nil { return err } diff --git a/examples/cmd/examples.go b/examples/cmd/examples.go index cd74421de3..32d7f403a3 100644 --- a/examples/cmd/examples.go +++ b/examples/cmd/examples.go @@ -28,7 +28,7 @@ func init() { log.SetFlags(log.LstdFlags | log.Llongfile) f := ExamplesCmd.PersistentFlags() f.StringVarP(&clientCredentials, "creds", "", "opentdf-sdk:secret", "client id:secret credentials") - f.StringVarP(&platformEndpoint, "platformEndpoint", "e", "http://localhost:8080", "Platform Endpoint") + f.StringVarP(&platformEndpoint, "platformEndpoint", "e", "https://localhost:8080", "Platform Endpoint") f.StringVarP(&tokenEndpoint, "tokenEndpoint", "t", "http://localhost:8888/auth/realms/opentdf/protocol/openid-connect/token", "OAuth token endpoint") f.BoolVar(&storeCollectionHeaders, "storeCollectionHeaders", false, "Store collection headers") f.BoolVar(&insecurePlaintextConn, "insecurePlaintextConn", false, "Use insecure plaintext connection") @@ -39,6 +39,7 @@ func newSDK() (*sdk.SDK, error) { resolver.SetDefaultScheme("passthrough") opts := []sdk.Option{} if insecurePlaintextConn { + platformEndpoint = strings.Replace(platformEndpoint, "https://", "http://", 1) opts = append(opts, sdk.WithInsecurePlaintextConn()) } if insecureSkipVerify { diff --git a/sdk/audit/metadata_adding_interceptor.go b/sdk/audit/metadata_adding_interceptor.go index f042e4f530..b4bc9fb20d 100644 --- a/sdk/audit/metadata_adding_interceptor.go +++ b/sdk/audit/metadata_adding_interceptor.go @@ -3,6 +3,7 @@ package audit import ( "context" + "connectrpc.com/connect" "github.com/google/uuid" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "google.golang.org/grpc" @@ -46,3 +47,33 @@ func MetadataAddingClientInterceptor( return err } + +func MetadataAddingConnectInterceptor() connect.UnaryInterceptorFunc { + return connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + // Only apply to outgoing client requests + if !req.Spec().IsClient { + return next(ctx, req) + } + + // Get any existing request ID from context + requestID, ok := ctx.Value(RequestIDContextKey).(uuid.UUID) + if !ok || requestID == uuid.Nil { + requestID = uuid.New() + } + req.Header().Set(string(RequestIDHeaderKey), requestID.String()) + + // Add the request IP to a custom header so it is preserved + if requestIP, okIP := ctx.Value(RequestIPContextKey).(string); okIP { + req.Header().Set(string(RequestIPHeaderKey), requestIP) + } + + // Add the actor ID from the request so it is preserved if we need it + if actorID, okAct := ctx.Value(ActorIDContextKey).(string); okAct { + req.Header().Set(string(ActorIDHeaderKey), actorID) + } + + return next(ctx, req) + } + }) +} diff --git a/sdk/audit/metadata_adding_interceptor_test.go b/sdk/audit/metadata_adding_interceptor_test.go index 8c0b44dda3..25c384cc8c 100644 --- a/sdk/audit/metadata_adding_interceptor_test.go +++ b/sdk/audit/metadata_adding_interceptor_test.go @@ -3,47 +3,54 @@ package audit import ( "context" "net" + "net/http" + "net/http/httptest" "testing" + "connectrpc.com/connect" "github.com/google/uuid" "github.com/opentdf/platform/protocol/go/kas" + "github.com/opentdf/platform/protocol/go/kas/kasconnect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" "google.golang.org/grpc/test/bufconn" ) -type FakeAccessServiceServer struct { +type FakeAccessServiceServerConnect struct { requestID uuid.UUID requestIP string actorID string - kas.UnimplementedAccessServiceServer + kasconnect.UnimplementedAccessServiceHandler } -func (f *FakeAccessServiceServer) PublicKey(ctx context.Context, _ *kas.PublicKeyRequest) (*kas.PublicKeyResponse, error) { - if md, ok := metadata.FromIncomingContext(ctx); ok { - requestIDFromMetadata := md.Get(string(RequestIDHeaderKey)) - if len(requestIDFromMetadata) > 0 { - f.requestID, _ = uuid.Parse(requestIDFromMetadata[0]) - } +func (f *FakeAccessServiceServerConnect) PublicKey(_ context.Context, req *connect.Request[kas.PublicKeyRequest]) (*connect.Response[kas.PublicKeyResponse], error) { + requestIDFromHeader := req.Header().Get(string(RequestIDHeaderKey)) + if requestIDFromHeader != "" { + f.requestID, _ = uuid.Parse(requestIDFromHeader) + } - requestIPFromMetadata := md.Get(string(RequestIPHeaderKey)) - if len(requestIPFromMetadata) > 0 { - f.requestIP = requestIPFromMetadata[0] - } + requestIPFromHeader := req.Header().Get(string(RequestIPHeaderKey)) + if requestIPFromHeader != "" { + f.requestIP = requestIPFromHeader + } - actorIDFromMetadata := md.Get(string(ActorIDHeaderKey)) - if len(actorIDFromMetadata) > 0 { - f.actorID = actorIDFromMetadata[0] - } + actorIDFromHeader := req.Header().Get(string(ActorIDHeaderKey)) + if actorIDFromHeader != "" { + f.actorID = actorIDFromHeader } - return &kas.PublicKeyResponse{}, nil + return connect.NewResponse(&kas.PublicKeyResponse{}), nil } func TestAddingAuditMetadataToOutgoingRequest(t *testing.T) { - server := FakeAccessServiceServer{} - client, stop := runServer(&server) - defer stop() + serverConnect := FakeAccessServiceServerConnect{} + serverGrpc := FakeAccessServiceServer{} + clientConnect, stopC := runConnectServer(&serverConnect) + defer stopC() + clientGrpc, stopG := runServer(&serverGrpc) + defer stopG() contextRequestID := uuid.New() contextActorID := "actorID" @@ -51,38 +58,101 @@ func TestAddingAuditMetadataToOutgoingRequest(t *testing.T) { ctx = context.WithValue(ctx, RequestIDContextKey, contextRequestID) ctx = context.WithValue(ctx, ActorIDContextKey, contextActorID) - _, err := client.PublicKey(ctx, &kas.PublicKeyRequest{}) + _, err := clientConnect.PublicKey(ctx, connect.NewRequest(&kas.PublicKeyRequest{})) + require.NoError(t, err) + _, err = clientGrpc.PublicKey(ctx, &kas.PublicKeyRequest{}) + require.NoError(t, err) + + for _, ids := range []struct { + actorID string + requestID uuid.UUID + }{ + {requestID: serverConnect.requestID, actorID: serverConnect.actorID}, + {requestID: serverGrpc.requestID, actorID: serverGrpc.actorID}, + } { + assert.Equal(t, contextRequestID, ids.requestID, "request ID did not match") + assert.Equal(t, contextActorID, ids.actorID, "actor ID did not match") + } +} + +func TestIsOKWithNoContextValues(t *testing.T) { + serverConnect := FakeAccessServiceServerConnect{} + serverGrpc := FakeAccessServiceServer{} + clientConnect, stopC := runConnectServer(&serverConnect) + defer stopC() + clientGrpc, stopG := runServer(&serverGrpc) + defer stopG() + + _, err := clientConnect.PublicKey(t.Context(), connect.NewRequest(&kas.PublicKeyRequest{})) if err != nil { t.Fatalf("error making call: %v", err) } - - if server.requestID != contextRequestID { - t.Fatalf("request ID did not match: %v", server.requestID) + _, err = clientGrpc.PublicKey(t.Context(), &kas.PublicKeyRequest{}) + if err != nil { + t.Fatalf("error making call: %v", err) } - if server.actorID != contextActorID { - t.Fatalf("actor ID did not match: %v", server.actorID) + for _, ids := range []struct { + actorID string + requestID uuid.UUID + }{ + {requestID: serverConnect.requestID, actorID: serverConnect.actorID}, + {requestID: serverGrpc.requestID, actorID: serverGrpc.actorID}, + } { + generatedRequestIDConnect, err := uuid.Parse(ids.requestID.String()) + if err != nil || generatedRequestIDConnect == uuid.Nil { + t.Fatalf("did not generate request ID: %v", err) + } + + if ids.actorID != "" { + t.Fatalf("actor ID not defaulted correctly: %v", ids.actorID) + } } } -func TestIsOKWithNoContextValues(t *testing.T) { - server := FakeAccessServiceServer{} - client, stop := runServer(&server) - defer stop() +func runConnectServer(f *FakeAccessServiceServerConnect) (kasconnect.AccessServiceClient, func()) { + mux := http.NewServeMux() + path, handler := kasconnect.NewAccessServiceHandler(f) + mux.Handle(path, handler) - _, err := client.PublicKey(t.Context(), &kas.PublicKeyRequest{}) - if err != nil { - t.Fatalf("error making call: %v", err) - } + server := httptest.NewServer(mux) + + client := kasconnect.NewAccessServiceClient( + server.Client(), + server.URL, + connect.WithInterceptors(MetadataAddingConnectInterceptor()), + ) - generatedRequestID, err := uuid.Parse(server.requestID.String()) - if err != nil || generatedRequestID == uuid.Nil { - t.Fatalf("did not generate request ID: %v", err) + return client, func() { + server.Close() } +} - if server.actorID != "" { - t.Fatalf("actor ID not defaulted correctly: %v", server.actorID) +type FakeAccessServiceServer struct { + requestID uuid.UUID + requestIP string + actorID string + kas.UnimplementedAccessServiceServer +} + +func (f *FakeAccessServiceServer) PublicKey(ctx context.Context, _ *kas.PublicKeyRequest) (*kas.PublicKeyResponse, error) { + if md, ok := metadata.FromIncomingContext(ctx); ok { + requestIDFromMetadata := md.Get(string(RequestIDHeaderKey)) + if len(requestIDFromMetadata) > 0 { + f.requestID, _ = uuid.Parse(requestIDFromMetadata[0]) + } + + requestIPFromMetadata := md.Get(string(RequestIPHeaderKey)) + if len(requestIPFromMetadata) > 0 { + f.requestIP = requestIPFromMetadata[0] + } + + actorIDFromMetadata := md.Get(string(ActorIDHeaderKey)) + if len(actorIDFromMetadata) > 0 { + f.actorID = actorIDFromMetadata[0] + } } + return &kas.PublicKeyResponse{}, nil } func runServer(f *FakeAccessServiceServer) (kas.AccessServiceClient, func()) { diff --git a/sdk/auth/token_adding_interceptor.go b/sdk/auth/token_adding_interceptor.go index 489f1b8ce4..95f7f0620f 100644 --- a/sdk/auth/token_adding_interceptor.go +++ b/sdk/auth/token_adding_interceptor.go @@ -11,6 +11,7 @@ import ( "net/http" "time" + "connectrpc.com/connect" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" @@ -82,6 +83,38 @@ func (i TokenAddingInterceptor) AddCredentials( return err } +func (i TokenAddingInterceptor) AddCredentialsConnect() connect.UnaryInterceptorFunc { + return connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + return func( + ctx context.Context, + req connect.AnyRequest, + ) (connect.AnyResponse, error) { + accessToken, err := i.tokenSource.AccessToken(ctx, i.httpClient) + if err != nil { + slog.ErrorContext(ctx, "error getting access token", "error", err) + return nil, connect.NewError(connect.CodeUnauthenticated, err) + } + + // Add Authorization header + req.Header().Set("Authorization", fmt.Sprintf("DPoP %s", accessToken)) + + // Add DPoP header if possible + dpopTok, err := i.GetDPoPToken(req.Spec().Procedure, http.MethodPost, string(accessToken)) + if err == nil { + req.Header().Set("DPoP", dpopTok) + } else { + // since we don't have a setting about whether DPoP is in use on the client and this request _could_ succeed if + // they are talking to a server where DPoP is not required we will just let this through. this method is extremely + // unlikely to fail so hopefully this isn't confusing + slog.ErrorContext(ctx, "error getting DPoP token for outgoing request. Request will not have DPoP token", "error", err) + } + + // Proceed with the RPC + return next(ctx, req) + } + }) +} + func (i TokenAddingInterceptor) GetDPoPToken(path, method, accessToken string) (string, error) { tok, err := i.tokenSource.MakeToken(func(key jwk.Key) ([]byte, error) { jtiBytes := make([]byte, JTILength) diff --git a/sdk/auth/token_adding_interceptor_test.go b/sdk/auth/token_adding_interceptor_test.go index 8d99eed54f..a27802e5b8 100644 --- a/sdk/auth/token_adding_interceptor_test.go +++ b/sdk/auth/token_adding_interceptor_test.go @@ -11,13 +11,16 @@ import ( "errors" "net" "net/http" + "net/http/httptest" "testing" + "connectrpc.com/connect" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/protocol/go/kas" + "github.com/opentdf/platform/protocol/go/kas/kasconnect" "github.com/opentdf/platform/sdk/httputil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -42,73 +45,109 @@ func TestAddingTokensToOutgoingRequest(t *testing.T) { key: key, accessToken: "thisisafakeaccesstoken", } - server := FakeAccessServiceServer{} + serverConnect := FakeAccessServiceServerConnect{} + serverGrpc := FakeAccessServiceServer{} oo := NewTokenAddingInterceptorWithClient(&ts, httputil.SafeHTTPClientWithTLSConfig(&tls.Config{ MinVersion: tls.VersionTLS12, })) - client, stop := runServer(&server, oo) - defer stop() + clientConnect, stopC := runConnectServer(&serverConnect, oo) + defer stopC() - _, err = client.PublicKey(t.Context(), &kas.PublicKeyRequest{}) - require.NoError(t, err, "error making call") - - assert.ElementsMatch(t, server.accessToken, []string{"DPoP thisisafakeaccesstoken"}) - require.Len(t, server.dpopToken, 1, "incorrect dpop token headers") - - dpopToken := server.dpopToken[0] - alg, ok := key.Algorithm().(jwa.SignatureAlgorithm) - assert.True(t, ok, "got a bad signing algorithm") - - _, err = jws.Verify([]byte(dpopToken), jws.WithKey(alg, key)) - require.NoError(t, err, "error verifying signature") - - parsedSignature, _ := jws.Parse([]byte(dpopToken)) - require.Len(t, parsedSignature.Signatures(), 1, "incorrect number of signatures") - - sig := parsedSignature.Signatures()[0] - tokenKey, ok := sig.ProtectedHeaders().Get("jwk") - require.True(t, ok, "didn't get jwk token key") - tkkey, ok := tokenKey.(jwk.Key) - require.True(t, ok, "wrong type for jwk token key", tokenKey) - - tp, _ := tkkey.Thumbprint(crypto.SHA256) - ktp, _ := key.Thumbprint(crypto.SHA256) - assert.Equal(t, tp, ktp, "got the wrong key from the token") - - parsedToken, _ := jwt.Parse([]byte(dpopToken), jwt.WithVerify(false)) + clientGrpc, stopG := runServer(&serverGrpc, oo) + defer stopG() - method, ok := parsedToken.Get("htm") - require.True(t, ok, "error getting htm claim") - assert.Equal(t, http.MethodPost, method, "got a bad method") - - path, ok := parsedToken.Get("htu") - require.True(t, ok, "error getting htu claim") - assert.Equal(t, "/kas.AccessService/PublicKey", path, "got a bad path") - - h := sha256.New() - h.Write([]byte("thisisafakeaccesstoken")) - expectedHash := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(h.Sum(nil)) + _, err = clientConnect.PublicKey(t.Context(), connect.NewRequest(&kas.PublicKeyRequest{})) + require.NoError(t, err, "error making call") + _, err = clientGrpc.PublicKey(t.Context(), &kas.PublicKeyRequest{}) + require.NoError(t, err, "error making call") - ath, ok := parsedToken.Get("ath") - require.True(t, ok, "error getting ath claim") - assert.Equal(t, expectedHash, ath, "invalid ath claim in token") + for _, server := range []struct { + accessToken []string + dpopToken []string + }{ + {accessToken: serverConnect.accessToken, dpopToken: serverConnect.dpopToken}, + {accessToken: serverGrpc.accessToken, dpopToken: serverGrpc.dpopToken}, + } { + assert.ElementsMatch(t, server.accessToken, []string{"DPoP thisisafakeaccesstoken"}) + require.Len(t, server.dpopToken, 1, "incorrect dpop token headers") + alg, ok := key.Algorithm().(jwa.SignatureAlgorithm) + assert.True(t, ok, "got a bad signing algorithm") + + dpopToken := server.dpopToken[0] + _, err = jws.Verify([]byte(dpopToken), jws.WithKey(alg, key)) + require.NoError(t, err, "error verifying signature") + + parsedSignature, _ := jws.Parse([]byte(dpopToken)) + require.Len(t, parsedSignature.Signatures(), 1, "incorrect number of signatures") + + sig := parsedSignature.Signatures()[0] + tokenKey, ok := sig.ProtectedHeaders().Get("jwk") + require.True(t, ok, "didn't get jwk token key") + tkkey, ok := tokenKey.(jwk.Key) + require.True(t, ok, "wrong type for jwk token key", tokenKey) + + tp, _ := tkkey.Thumbprint(crypto.SHA256) + ktp, _ := key.Thumbprint(crypto.SHA256) + assert.Equal(t, tp, ktp, "got the wrong key from the token") + + parsedToken, _ := jwt.Parse([]byte(dpopToken), jwt.WithVerify(false)) + + method, ok := parsedToken.Get("htm") + require.True(t, ok, "error getting htm claim") + assert.Equal(t, http.MethodPost, method, "got a bad method") + + path, ok := parsedToken.Get("htu") + require.True(t, ok, "error getting htu claim") + assert.Equal(t, "/kas.AccessService/PublicKey", path, "got a bad path") + + h := sha256.New() + h.Write([]byte("thisisafakeaccesstoken")) + expectedHash := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(h.Sum(nil)) + + ath, ok := parsedToken.Get("ath") + require.True(t, ok, "error getting ath claim") + assert.Equal(t, expectedHash, ath, "invalid ath claim in token") + } } func Test_InvalidCredentials_DoesNotSendMessage(t *testing.T) { ts := FakeTokenSource{key: nil, accessToken: ""} - server := FakeAccessServiceServer{} + serverConnect := FakeAccessServiceServerConnect{} + serverGrpc := FakeAccessServiceServer{} oo := NewTokenAddingInterceptorWithClient(&ts, httputil.SafeHTTPClientWithTLSConfig(&tls.Config{ MinVersion: tls.VersionTLS12, })) - client, stop := runServer(&server, oo) - defer stop() + clientConnect, stopC := runConnectServer(&serverConnect, oo) + defer stopC() + clientGrpc, stopG := runServer(&serverGrpc, oo) + defer stopG() - _, err := client.PublicKey(t.Context(), &kas.PublicKeyRequest{}) + _, err := clientConnect.PublicKey(t.Context(), connect.NewRequest(&kas.PublicKeyRequest{})) + require.Error(t, err, "should not have sent message because the token source returned an error") + _, err = clientGrpc.PublicKey(t.Context(), &kas.PublicKeyRequest{}) require.Error(t, err, "should not have sent message because the token source returned an error") } +type FakeAccessServiceServerConnect struct { + accessToken []string + dpopToken []string + dpopKey jwk.Key + kasconnect.UnimplementedAccessServiceHandler +} + +func (f *FakeAccessServiceServerConnect) PublicKey(ctx context.Context, req *connect.Request[kas.PublicKeyRequest]) (*connect.Response[kas.PublicKeyResponse], error) { + f.accessToken = []string{req.Header().Get("authorization")} + f.dpopToken = []string{req.Header().Get("dpop")} + var ok bool + f.dpopKey, ok = ctx.Value("dpop-jwk").(jwk.Key) + if !ok { + f.dpopKey = nil + } + return connect.NewResponse(&kas.PublicKeyResponse{}), nil +} + type FakeAccessServiceServer struct { accessToken []string dpopToken []string @@ -156,6 +195,26 @@ func (fts *FakeTokenSource) MakeToken(f func(jwk.Key) ([]byte, error)) ([]byte, return f(fts.key) } +func runConnectServer( + f *FakeAccessServiceServerConnect, oo TokenAddingInterceptor, +) (kasconnect.AccessServiceClient, func()) { + mux := http.NewServeMux() + path, handler := kasconnect.NewAccessServiceHandler(f) + mux.Handle(path, handler) + + server := httptest.NewServer(mux) + + client := kasconnect.NewAccessServiceClient( + server.Client(), + server.URL, + connect.WithInterceptors(oo.AddCredentialsConnect()), + ) + + return client, func() { + server.Close() + } +} + func runServer( //nolint:ireturn // this is pretty concrete f *FakeAccessServiceServer, oo TokenAddingInterceptor, ) (kas.AccessServiceClient, func()) { diff --git a/sdk/bulk.go b/sdk/bulk.go index e8080dbb0f..20e1421d47 100644 --- a/sdk/bulk.go +++ b/sdk/bulk.go @@ -167,7 +167,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { } } - kasClient := newKASClient(s.dialOptions, s.tokenSource, s.kasSessionKey) + kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey) allRewrapResp := make(map[string][]kaoResult) var err error for kasurl, rewrapRequests := range kasRewrapRequests { diff --git a/sdk/go.mod b/sdk/go.mod index 0949d034d3..0addc797cd 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -5,6 +5,7 @@ go 1.24.0 toolchain go1.24.2 require ( + connectrpc.com/connect v1.18.1 github.com/Masterminds/semver/v3 v3.3.1 github.com/google/uuid v1.6.0 github.com/gowebpki/jcs v1.0.1 @@ -17,6 +18,7 @@ require ( github.com/testcontainers/testcontainers-go v0.34.0 github.com/xeipuuv/gojsonschema v1.2.0 golang.org/x/oauth2 v0.30.0 + golang.org/x/tools v0.33.0 google.golang.org/grpc v1.72.1 google.golang.org/protobuf v1.36.6 ) @@ -75,7 +77,7 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect - github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect @@ -85,10 +87,12 @@ require ( go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - golang.org/x/crypto v0.36.0 // indirect - golang.org/x/net v0.38.0 // indirect - golang.org/x/sys v0.31.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/crypto v0.38.0 // indirect + golang.org/x/mod v0.24.0 // indirect + golang.org/x/net v0.40.0 // indirect + golang.org/x/sync v0.14.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/sdk/go.sum b/sdk/go.sum index 42ceb622b0..42d229e54b 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -1,5 +1,7 @@ buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1 h1:LEXWFH/xZ5oOWrC3oOtHbUyBdzRWMCPpAQmKC9v05mA= buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1/go.mod h1:XF+P8+RmfdufmIYpGUC+6bF7S+IlmHDEnCrO3OXaUAQ= +connectrpc.com/connect v1.18.1 h1:PAg7CjSAGvscaf6YZKUefjoih5Z/qYkyaTrBW8xvYPw= +connectrpc.com/connect v1.18.1/go.mod h1:0292hj1rnx8oFrStN7cB4jjVBeqs+Yx5yDIC2prWDO8= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= @@ -160,9 +162,8 @@ github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFA github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= -github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= -github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= @@ -198,12 +199,14 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -214,8 +217,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -223,6 +226,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -240,24 +245,24 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= +golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= @@ -267,6 +272,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/sdk/granter.go b/sdk/granter.go index 5e0f32c5ab..b45c55b640 100644 --- a/sdk/granter.go +++ b/sdk/granter.go @@ -12,6 +12,7 @@ import ( "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/sdk/sdkconnect" ) var ErrInvalid = errors.New("invalid type") @@ -221,7 +222,7 @@ func (r granter) byAttribute(fqn AttributeValueFQN) *keyAccessGrant { } // Gets a list of directory of KAS grants for a list of attribute FQNs -func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as attributes.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) { +func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkconnect.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) { fqnsStr := make([]string, len(fqns)) for i, v := range fqns { fqnsStr[i] = v.String() diff --git a/sdk/granter_test.go b/sdk/granter_test.go index f800c6aa9c..e90b146055 100644 --- a/sdk/granter_test.go +++ b/sdk/granter_test.go @@ -11,9 +11,9 @@ import ( "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/sdk/sdkconnect" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc" ) const ( @@ -503,14 +503,14 @@ func TestReasonerConstructAttributeBoolean(t *testing.T) { var listAttributeResp attributes.ListAttributesResponse type mockAttributesClient struct { - attributes.AttributesServiceClient + sdkconnect.AttributesServiceClient } -func (*mockAttributesClient) ListAttributes(_ context.Context, _ *attributes.ListAttributesRequest, _ ...grpc.CallOption) (*attributes.ListAttributesResponse, error) { +func (*mockAttributesClient) ListAttributes(_ context.Context, _ *attributes.ListAttributesRequest) (*attributes.ListAttributesResponse, error) { return &listAttributeResp, nil } -func (*mockAttributesClient) GetAttributeValuesByFqns(_ context.Context, req *attributes.GetAttributeValuesByFqnsRequest, _ ...grpc.CallOption) (*attributes.GetAttributeValuesByFqnsResponse, error) { +func (*mockAttributesClient) GetAttributeValuesByFqns(_ context.Context, req *attributes.GetAttributeValuesByFqnsRequest) (*attributes.GetAttributeValuesByFqnsResponse, error) { av := make(map[string]*attributes.GetAttributeValuesByFqnsResponse_AttributeAndValue) for _, v := range req.GetFqns() { vfqn, err := NewAttributeValueFQN(v) diff --git a/sdk/internal/codegen/README.md b/sdk/internal/codegen/README.md new file mode 100644 index 0000000000..c96c5d961a --- /dev/null +++ b/sdk/internal/codegen/README.md @@ -0,0 +1,27 @@ +# SDK Internal Codegen + +## Overview +This folder contains the code generation logic for the SDK's internal components. It automates the creation of ConnectRPC wrapper clients, to ensure consistency and reduce manual effort. These clients have similar interfaces to the GRPC proto generated clients allowing for ease of transition to ConnectRPC client-side. + +--- + +## What It Generates +The code generation in this folder focuses on: +1. ConnectRPC wrapper clients for various platform services +2. Interfaces for each wrapper client + +The clients generated are defined in `clientsToGenerateList` in `runner/generate.go`. + +--- + +## How to Run Code Generation +To generate the internal SDK code: + +```bash +go run ./sdk/internal/codegen +``` + +Or use the provided Makefile command +```bash +make connect-wrapper-generate +``` \ No newline at end of file diff --git a/sdk/internal/codegen/main.go b/sdk/internal/codegen/main.go new file mode 100644 index 0000000000..e8fb00c0e0 --- /dev/null +++ b/sdk/internal/codegen/main.go @@ -0,0 +1,13 @@ +package main + +import ( + "log" + + "github.com/opentdf/platform/sdk/internal/codegen/runner" +) + +func main() { + if err := runner.Generate(); err != nil { + log.Fatal(err) + } +} diff --git a/sdk/internal/codegen/runner/generate.go b/sdk/internal/codegen/runner/generate.go new file mode 100644 index 0000000000..1371a52486 --- /dev/null +++ b/sdk/internal/codegen/runner/generate.go @@ -0,0 +1,254 @@ +package runner + +import ( + "errors" + "fmt" + "go/ast" + "log/slog" + "os" + "path" + "path/filepath" + "runtime" + + "golang.org/x/tools/go/packages" +) + +type clientsToGenerate struct { + grpcClientInterface string + suffix string + packageNameOverride string + grpcPackagePath string +} + +var clientsToGenerateList = []clientsToGenerate{ + { + grpcClientInterface: "ActionServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/actions", + }, + { + grpcClientInterface: "AttributesServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/attributes", + }, + { + grpcClientInterface: "AuthorizationServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/authorization", + }, + { + grpcClientInterface: "AuthorizationServiceClient", + suffix: "V2", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/authorization/v2", + packageNameOverride: "authorizationv2", + }, + { + grpcClientInterface: "EntityResolutionServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/entityresolution", + }, + { + grpcClientInterface: "EntityResolutionServiceClient", + suffix: "V2", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/entityresolution/v2", + packageNameOverride: "entityresolutionv2", + }, + { + grpcClientInterface: "KeyAccessServerRegistryServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/kasregistry", + }, + { + grpcClientInterface: "KeyManagementServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/keymanagement", + }, + { + grpcClientInterface: "NamespaceServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/namespaces", + }, + { + grpcClientInterface: "RegisteredResourcesServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/registeredresources", + }, + { + grpcClientInterface: "ResourceMappingServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/resourcemapping", + }, + { + grpcClientInterface: "SubjectMappingServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/subjectmapping", + }, + { + grpcClientInterface: "UnsafeServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/unsafe", + }, + { + grpcClientInterface: "WellKnownServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/wellknownconfiguration", + }, +} + +func Generate() error { + for _, client := range clientsToGenerateList { + slog.Info("Generating wrapper for", "interface", client.grpcClientInterface, "package", client.grpcPackagePath) + // Load the Go package using the import path + cfg := &packages.Config{ + Mode: packages.NeedName | + packages.NeedTypes | + packages.NeedTypesInfo | + packages.NeedSyntax | + packages.NeedCompiledGoFiles, + } + pkgs, err := packages.Load(cfg, client.grpcPackagePath) + if err != nil { + return fmt.Errorf("failed to load package %s: %w", client.grpcPackagePath, err) + } + if packages.PrintErrors(pkgs) > 0 { + return fmt.Errorf("errors loading package %s", client.grpcPackagePath) + } + found := false + err = nil + // Loop through the package and its files + for _, p := range pkgs { + for _, file := range p.Syntax { + ast.Inspect(file, func(n ast.Node) bool { + if found { + return false // skip rest of traversal + } + ts, ok := n.(*ast.TypeSpec) + if !ok { + return true + } + iface, ok := ts.Type.(*ast.InterfaceType) + if !ok { + return true + } + if ts.Name.Name == client.grpcClientInterface { + packageName := path.Base(client.grpcPackagePath) + if client.packageNameOverride != "" { + packageName = client.packageNameOverride + } + code := generateWrapper(ts.Name.Name, iface, client.grpcPackagePath, packageName, client.suffix) + var currentDir string + currentDir, err = getCurrentFileDir() + outputPath := filepath.Join(currentDir, "..", "..", "..", "sdkconnect", packageName+".go") + err = os.WriteFile(outputPath, []byte(code), 0o644) //nolint:gosec // ignore G306 + if err != nil { + slog.Error("Error writing file", "file", outputPath, "error", err) + } + found = true + return false // stop traversal + } + return true + }) + if found { + break + } + } + if found { + break + } + } + if !found { + return fmt.Errorf("interface %q not found in package %s", client.grpcClientInterface, client.grpcPackagePath) + } + if err != nil { + return fmt.Errorf("error writing file: %w", err) + } + } + return nil +} + +func getCurrentFileDir() (string, error) { + _, filename, _, ok := runtime.Caller(0) + if !ok { + return "", errors.New("could not get caller file (generate.go) working directory") + } + return filepath.Dir(filename), nil +} + +// Helper function to get the method names of an interface +func getMethodNames(interfaceType *ast.InterfaceType) []string { + methodNames := []string{} + for _, method := range interfaceType.Methods.List { + if len(method.Names) > 0 { + methodNames = append(methodNames, method.Names[0].Name) + } + } + return methodNames +} + +// Generate wrapper code for the Connect RPC client interface +func generateWrapper(interfaceName string, interfaceType *ast.InterfaceType, packagePath string, packageName string, suffix string) string { + // Get method names dynamically from the interface + methods := getMethodNames(interfaceType) + connectPackageName := packageName + "connect" + + // Start generating the wrapper code + wrapperCode := fmt.Sprintf(`// Wrapper for %s (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "%s" + "%s" +) + +type %s%sConnectWrapper struct { + %s.%s +} + +func New%s%sConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *%s%sConnectWrapper { + return &%s%sConnectWrapper{%s: %s.New%s(httpClient, baseURL, opts...)} +} +`, + interfaceName, + packagePath, + packagePath+"/"+connectPackageName, + interfaceName, + suffix, + connectPackageName, + interfaceName, + interfaceName, + suffix, + interfaceName, + suffix, + interfaceName, + suffix, + interfaceName, + connectPackageName, + interfaceName) + + // Generate the interface type definition + wrapperCode += generateInterfaceType(interfaceName, methods, packageName, suffix) + // Now generate a wrapper function for each method in the interface + for _, method := range methods { + wrapperCode += generateWrapperMethod(interfaceName, method, packageName, suffix) + } + + // Output the generated wrapper code + return wrapperCode +} + +func generateInterfaceType(interfaceName string, methods []string, packageName string, suffix string) string { + // Generate the interface type definition + interfaceType := fmt.Sprintf(` +type %s%s interface { +`, interfaceName, suffix) + for _, method := range methods { + interfaceType += fmt.Sprintf(` %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error) +`, method, packageName, method, packageName, method) + } + interfaceType += "}\n" + return interfaceType +} + +// Generate the wrapper method for a specific method in the interface +func generateWrapperMethod(interfaceName, methodName, packageName string, suffix string) string { + return fmt.Sprintf(` +func (w *%s%sConnectWrapper) %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error) { + // Wrap Connect RPC client request + res, err := w.%s.%s(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} +`, interfaceName, suffix, methodName, packageName, methodName, packageName, methodName, interfaceName, methodName) +} diff --git a/sdk/kas_client.go b/sdk/kas_client.go index ae6dcde964..93bfdf1a27 100644 --- a/sdk/kas_client.go +++ b/sdk/kas_client.go @@ -6,17 +6,19 @@ import ( "errors" "fmt" "net" + "net/http" "net/url" "time" + "connectrpc.com/connect" "google.golang.org/protobuf/encoding/protojson" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/protocol/go/kas" + "github.com/opentdf/platform/protocol/go/kas/kasconnect" "github.com/opentdf/platform/sdk/auth" - "google.golang.org/grpc" ) const ( @@ -27,7 +29,8 @@ const ( type KASClient struct { accessTokenSource auth.AccessTokenSource - dialOptions []grpc.DialOption + httpClient *http.Client + connectOptions []connect.ClientOption sessionKey ocrypto.KeyPair // Set this to enable legacy, non-batch rewrap requests @@ -45,10 +48,11 @@ type decryptor interface { Decrypt(ctx context.Context, results []kaoResult) (int, error) } -func newKASClient(dialOptions []grpc.DialOption, accessTokenSource auth.AccessTokenSource, sessionKey ocrypto.KeyPair) *KASClient { +func newKASClient(httpClient *http.Client, options []connect.ClientOption, accessTokenSource auth.AccessTokenSource, sessionKey ocrypto.KeyPair) *KASClient { return &KASClient{ accessTokenSource: accessTokenSource, - dialOptions: dialOptions, + httpClient: httpClient, + connectOptions: options, sessionKey: sessionKey, supportSingleRewrapEndpoint: true, } @@ -60,27 +64,22 @@ func (k *KASClient) makeRewrapRequest(ctx context.Context, requests []*kas.Unsig if err != nil { return nil, err } - grpcAddress, err := getGRPCAddress(requests[0].GetKeyAccessObjects()[0].GetKeyAccessObject().GetKasUrl()) + kasURL := requests[0].GetKeyAccessObjects()[0].GetKeyAccessObject().GetKasUrl() + parsedURL, err := parseBaseURL(kasURL) if err != nil { - return nil, err - } - - conn, err := grpc.NewClient(grpcAddress, k.dialOptions...) - if err != nil { - return nil, fmt.Errorf("error connecting to sas: %w", err) + return nil, fmt.Errorf("cannot parse kas url(%s): %w", kasURL, err) } - defer conn.Close() - serviceClient := kas.NewAccessServiceClient(conn) + serviceClient := kasconnect.NewAccessServiceClient(k.httpClient, parsedURL, k.connectOptions...) - response, err := serviceClient.Rewrap(ctx, rewrapRequest) + response, err := serviceClient.Rewrap(ctx, connect.NewRequest(rewrapRequest)) if err != nil { return upgradeRewrapErrorV1(err, requests) } - upgradeRewrapResponseV1(response, requests) + upgradeRewrapResponseV1(response.Msg, requests) - return response, nil + return response.Msg, nil } // convert v1 responses to v2 @@ -313,24 +312,22 @@ func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecrypt return policyResults, nil } -func getGRPCAddress(kasURL string) (string, error) { - parsedURL, err := url.Parse(kasURL) +func parseBaseURL(rawURL string) (string, error) { + u, err := url.Parse(rawURL) if err != nil { - return "", fmt.Errorf("cannot parse kas url(%s): %w", kasURL, err) + return "", err } - // Needed to support buffconn for testing - if parsedURL.Host == "" && parsedURL.Port() == "" { - return "", nil - } + host := u.Hostname() + port := u.Port() - port := parsedURL.Port() - // if port is empty, default to 443. - if port == "" { - port = "443" + // Add port only if it's present + addr := host + if port != "" { + addr = net.JoinHostPort(host, port) } - return net.JoinHostPort(parsedURL.Hostname(), port), nil + return fmt.Sprintf("%s://%s", u.Scheme, addr), nil } func (k *KASClient) getRewrapRequest(reqs []*kas.UnsignedRewrapRequest_WithPolicyRequest, pubKey string) (*kas.RewrapRequest, error) { @@ -422,23 +419,18 @@ func (c *kasKeyCache) store(ki KASInfo) { c.c[cacheKey] = timeStampedKASInfo{ki, time.Now()} } -func (s SDK) getPublicKey(ctx context.Context, url, algorithm string) (*KASInfo, error) { +func (s SDK) getPublicKey(ctx context.Context, kasurl, algorithm string) (*KASInfo, error) { if s.kasKeyCache != nil { - if cachedValue := s.kasKeyCache.get(url, algorithm); nil != cachedValue { + if cachedValue := s.kasKeyCache.get(kasurl, algorithm); nil != cachedValue { return cachedValue, nil } } - grpcAddress, err := getGRPCAddress(url) - if err != nil { - return nil, err - } - conn, err := grpc.NewClient(grpcAddress, s.dialOptions...) + parsedURL, err := parseBaseURL(kasurl) if err != nil { - return nil, fmt.Errorf("error connecting to grpc service at %s: %w", url, err) + return nil, fmt.Errorf("cannot parse kas url(%s): %w", kasurl, err) } - defer conn.Close() - serviceClient := kas.NewAccessServiceClient(conn) + serviceClient := kasconnect.NewAccessServiceClient(s.conn.Client, parsedURL, s.conn.Options...) req := kas.PublicKeyRequest{ Algorithm: algorithm, @@ -446,21 +438,21 @@ func (s SDK) getPublicKey(ctx context.Context, url, algorithm string) (*KASInfo, if s.config.tdfFeatures.noKID { req.V = "1" } - resp, err := serviceClient.PublicKey(ctx, &req) + resp, err := serviceClient.PublicKey(ctx, connect.NewRequest(&req)) if err != nil { return nil, fmt.Errorf("error making request to KAS: %w", err) } - kid := resp.GetKid() + kid := resp.Msg.GetKid() if s.config.tdfFeatures.noKID { kid = "" } ki := KASInfo{ - URL: url, + URL: kasurl, Algorithm: algorithm, KID: kid, - PublicKey: resp.GetPublicKey(), + PublicKey: resp.Msg.GetPublicKey(), } if s.kasKeyCache != nil { s.kasKeyCache.store(ki) diff --git a/sdk/kas_client_test.go b/sdk/kas_client_test.go index 496c195f5d..d15d785755 100644 --- a/sdk/kas_client_test.go +++ b/sdk/kas_client_test.go @@ -5,6 +5,7 @@ import ( "net/http" "testing" + "connectrpc.com/connect" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" @@ -15,7 +16,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "google.golang.org/grpc" "google.golang.org/protobuf/encoding/protojson" ) @@ -58,12 +58,12 @@ func getTokenSource(t *testing.T) FakeAccessTokenSource { } func TestCreatingRequest(t *testing.T) { - var dialOption []grpc.DialOption + var options []connect.ClientOption tokenSource := getTokenSource(t) kasKey, err := ocrypto.NewRSAKeyPair(tdf3KeySize) require.NoError(t, err, "error creating RSA Key") - client := newKASClient(dialOption, tokenSource, &kasKey) + client := newKASClient(nil, options, tokenSource, &kasKey) require.NoError(t, err) keyAccess := []*kaspb.UnsignedRewrapRequest_WithPolicyRequest{ @@ -125,7 +125,7 @@ func TestCreatingRequest(t *testing.T) { } func Test_StoreKASKeys(t *testing.T) { - s, err := New("localhost:8080", + s, err := New("http://localhost:8080", WithPlatformConfiguration(PlatformConfiguration{ "idp": map[string]interface{}{ "issuer": "https://example.org", @@ -214,3 +214,49 @@ func (suite *TestUpgradeRewrapRequestV1Suite) TestUpgradeRewrapRequestV1_Empty() func TestUpgradeRewrapRequestV1(t *testing.T) { suite.Run(t, new(TestUpgradeRewrapRequestV1Suite)) } + +func TestParseBaseUrl(t *testing.T) { + tests := []struct { + name string + input string + expected string + expectError bool + }{ + { + name: "Valid URL with scheme and port", + input: "https://example.com:8080/path", + expected: "https://example.com:8080", + expectError: false, + }, + { + name: "Valid URL with scheme and no port", + input: "https://example.com/path", + expected: "https://example.com", + expectError: false, + }, + { + name: "Valid URL with default port", + input: "http://example.com", + expected: "http://example.com", + expectError: false, + }, + { + name: "Invalid URL with invalid characters", + input: "https://exa mple.com", + expected: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseBaseURL(tt.input) + if tt.expectError { + require.Error(t, err, "Expected an error for test case: %s", tt.name) + } else { + require.NoError(t, err, "Did not expect an error for test case: %s", tt.name) + assert.Equal(t, tt.expected, result, "Unexpected result for test case: %s", tt.name) + } + }) + } +} diff --git a/sdk/nanotdf.go b/sdk/nanotdf.go index f64f1fb87e..910dc3f729 100644 --- a/sdk/nanotdf.go +++ b/sdk/nanotdf.go @@ -1052,7 +1052,7 @@ func (s SDK) getNanoRewrapKey(ctx context.Context, decryptor *NanoTDFDecryptHand } } - client := newKASClient(s.dialOptions, s.tokenSource, nil) + client := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, nil) kasURL, err := decryptor.header.kasURL.GetURL() if err != nil { return nil, fmt.Errorf("nano header kasUrl: %w", err) diff --git a/sdk/nanotdf_test.go b/sdk/nanotdf_test.go index 6cb7990f41..2854548ec3 100644 --- a/sdk/nanotdf_test.go +++ b/sdk/nanotdf_test.go @@ -308,7 +308,7 @@ func TestCreateNanoTDF(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s, err := New("localhost:8080", WithPlatformConfiguration(PlatformConfiguration{})) + s, err := New("http://localhost:8080", WithPlatformConfiguration(PlatformConfiguration{})) require.NoError(t, err) _, err = s.CreateNanoTDF(tt.writer, tt.reader, tt.config) if tt.expectedError != "" { diff --git a/sdk/options.go b/sdk/options.go index 78d4f5ad7b..7ad993388d 100644 --- a/sdk/options.go +++ b/sdk/options.go @@ -5,29 +5,32 @@ import ( "crypto/tls" "net/http" + "connectrpc.com/connect" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/sdk/auth/oauth" "github.com/opentdf/platform/sdk/httputil" "golang.org/x/oauth2" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" ) type Option func(*config) +type ConnectRPCConnection struct { + Client *http.Client + Endpoint string + Options []connect.ClientOption +} + // Internal config struct for building SDK options. type config struct { // Platform configuration structure is subject to change. Consume via accessor methods. PlatformConfiguration PlatformConfiguration - dialOption grpc.DialOption + extraClientOptions []connect.ClientOption httpClient *http.Client clientCredentials *oauth.ClientCredentials tokenExchange *oauth.TokenExchangeInfo tokenEndpoint string scopes []string - extraDialOptions []grpc.DialOption certExchange *oauth.CertExchangeInfo kasSessionKey *ocrypto.RsaKeyPair dpopKey *ocrypto.RsaKeyPair @@ -36,8 +39,8 @@ type config struct { nanoFeatures nanoFeatures customAccessTokenSource auth.AccessTokenSource oauthAccessTokenSource oauth2.TokenSource - coreConn *grpc.ClientConn - entityResolutionConn *grpc.ClientConn + coreConn *ConnectRPCConnection + entityResolutionConn *ConnectRPCConnection collectionStore *collectionStore shouldValidatePlatformConnectivity bool } @@ -56,19 +59,13 @@ type nanoFeatures struct { type PlatformConfiguration map[string]interface{} -func (c *config) build() []grpc.DialOption { - return []grpc.DialOption{c.dialOption} -} - // WithInsecureSkipVerifyConn returns an Option that sets up HTTPS connection without verification. func WithInsecureSkipVerifyConn() Option { return func(c *config) { tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, InsecureSkipVerify: true, // #nosec G402 - } - c.dialOption = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) - // used by http client + } // used by http client c.httpClient = httputil.SafeHTTPClientWithTLSConfig(tlsConfig) } } @@ -83,7 +80,6 @@ func WithStoreCollectionHeaders() Option { // WithInsecurePlaintextConn returns an Option that sets up HTTP connection sent in the clear. func WithInsecurePlaintextConn() Option { return func(c *config) { - c.dialOption = grpc.WithTransportCredentials(insecure.NewCredentials()) // used by http client // FIXME anything to do here c.httpClient = httputil.SafeHTTPClient() @@ -126,20 +122,20 @@ func WithOAuthAccessTokenSource(t oauth2.TokenSource) Option { } // Deprecated: Use WithCustomCoreConnection instead -func WithCustomPolicyConnection(conn *grpc.ClientConn) Option { +func WithCustomPolicyConnection(conn *ConnectRPCConnection) Option { return func(c *config) { c.coreConn = conn } } // Deprecated: Use WithCustomCoreConnection instead -func WithCustomAuthorizationConnection(conn *grpc.ClientConn) Option { +func WithCustomAuthorizationConnection(conn *ConnectRPCConnection) Option { return func(c *config) { c.coreConn = conn } } -func WithCustomEntityResolutionConnection(conn *grpc.ClientConn) Option { +func WithCustomEntityResolutionConnection(conn *ConnectRPCConnection) Option { return func(c *config) { c.entityResolutionConn = conn } @@ -156,12 +152,6 @@ func WithTokenExchange(subjectToken string, audience []string) Option { } } -func WithExtraDialOptions(dialOptions ...grpc.DialOption) Option { - return func(c *config) { - c.extraDialOptions = dialOptions - } -} - // The session key pair is used to encrypt responses from KAS for a given session // and can be reused across an entire session. // Please use with caution. @@ -182,7 +172,7 @@ func WithSessionSignerRSA(key *rsa.PrivateKey) Option { } } -func WithCustomWellknownConnection(conn *grpc.ClientConn) Option { +func WithCustomWellknownConnection(conn *ConnectRPCConnection) Option { return func(c *config) { c.coreConn = conn } @@ -220,12 +210,19 @@ func WithNoKIDInKAO() Option { } // WithCoreConnection returns an Option that sets up a connection to the core platform -func WithCustomCoreConnection(conn *grpc.ClientConn) Option { +func WithCustomCoreConnection(conn *ConnectRPCConnection) Option { return func(c *config) { c.coreConn = conn } } +// WithExtraClientOptions returns an Option that adds extra connect rpc client options to the conect rpc clients +func WithExtraClientOptions(opts ...connect.ClientOption) Option { + return func(c *config) { + c.extraClientOptions = opts + } +} + // WithNoKIDInNano disables storing the KID in the KAS ResourceLocator. // This allows generating NanoTDF files that are compatible with legacy file formats (no KID). func WithNoKIDInNano() Option { diff --git a/sdk/sdk.go b/sdk/sdk.go index 9f435ee218..44f4afde01 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -9,34 +9,21 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/url" "strings" + "connectrpc.com/connect" "github.com/opentdf/platform/lib/ocrypto" - "github.com/opentdf/platform/protocol/go/authorization" - authorizationV2 "github.com/opentdf/platform/protocol/go/authorization/v2" - "github.com/opentdf/platform/protocol/go/entityresolution" - entityresolutionV2 "github.com/opentdf/platform/protocol/go/entityresolution/v2" "github.com/opentdf/platform/protocol/go/policy" - "github.com/opentdf/platform/protocol/go/policy/actions" - "github.com/opentdf/platform/protocol/go/policy/attributes" - "github.com/opentdf/platform/protocol/go/policy/kasregistry" - "github.com/opentdf/platform/protocol/go/policy/keymanagement" - "github.com/opentdf/platform/protocol/go/policy/namespaces" - "github.com/opentdf/platform/protocol/go/policy/registeredresources" - "github.com/opentdf/platform/protocol/go/policy/resourcemapping" - "github.com/opentdf/platform/protocol/go/policy/subjectmapping" - "github.com/opentdf/platform/protocol/go/policy/unsafe" "github.com/opentdf/platform/protocol/go/wellknownconfiguration" + "github.com/opentdf/platform/protocol/go/wellknownconfiguration/wellknownconfigurationconnect" "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/sdk/httputil" "github.com/opentdf/platform/sdk/internal/archive" + "github.com/opentdf/platform/sdk/sdkconnect" "github.com/xeipuuv/gojsonschema" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" healthpb "google.golang.org/grpc/health/grpc_health_v1" ) @@ -66,37 +53,36 @@ type SDK struct { config *kasKeyCache *collectionStore - conn *grpc.ClientConn - dialOptions []grpc.DialOption + conn *ConnectRPCConnection tokenSource auth.AccessTokenSource - Actions actions.ActionServiceClient - Attributes attributes.AttributesServiceClient - Authorization authorization.AuthorizationServiceClient - AuthorizationV2 authorizationV2.AuthorizationServiceClient - EntityResoution entityresolution.EntityResolutionServiceClient - EntityResolutionV2 entityresolutionV2.EntityResolutionServiceClient - KeyAccessServerRegistry kasregistry.KeyAccessServerRegistryServiceClient - Namespaces namespaces.NamespaceServiceClient - RegisteredResources registeredresources.RegisteredResourcesServiceClient - ResourceMapping resourcemapping.ResourceMappingServiceClient - SubjectMapping subjectmapping.SubjectMappingServiceClient - Unsafe unsafe.UnsafeServiceClient - KeyManagement keymanagement.KeyManagementServiceClient - wellknownConfiguration wellknownconfiguration.WellKnownServiceClient + Actions sdkconnect.ActionServiceClient + Attributes sdkconnect.AttributesServiceClient + Authorization sdkconnect.AuthorizationServiceClient + AuthorizationV2 sdkconnect.AuthorizationServiceClientV2 + EntityResoution sdkconnect.EntityResolutionServiceClient + EntityResolutionV2 sdkconnect.EntityResolutionServiceClientV2 + KeyAccessServerRegistry sdkconnect.KeyAccessServerRegistryServiceClient + Namespaces sdkconnect.NamespaceServiceClient + RegisteredResources sdkconnect.RegisteredResourcesServiceClient + ResourceMapping sdkconnect.ResourceMappingServiceClient + SubjectMapping sdkconnect.SubjectMappingServiceClient + Unsafe sdkconnect.UnsafeServiceClient + KeyManagement sdkconnect.KeyManagementServiceClient + wellknownConfiguration sdkconnect.WellKnownServiceClient } func New(platformEndpoint string, opts ...Option) (*SDK, error) { var ( - platformConn *grpc.ClientConn // Connection to the platform - ersConn *grpc.ClientConn // Connection to ERS (possibly remote) + platformConn *ConnectRPCConnection // Connection to the platform + ersConn *ConnectRPCConnection // Connection to ERS (possibly remote) err error ) // Set default options cfg := &config{ - dialOption: grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + httpClient: httputil.SafeHTTPClientWithTLSConfig(&tls.Config{ MinVersion: tls.VersionTLS12, - })), + }), } // Apply options @@ -118,25 +104,22 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { cfg.kasSessionKey = &key } - // once we change KAS to use standard DPoP we can put this all in the `build()` method - dialOptions := append([]grpc.DialOption{}, cfg.build()...) - // Add extra grpc dial options if provided. This is useful during tests. - if len(cfg.extraDialOptions) > 0 { - dialOptions = append(dialOptions, cfg.extraDialOptions...) - } - - unsanitizedPlatformEndpoint := platformEndpoint // IF IPC is disabled we build a validated healthy connection to the platform - if !cfg.ipc { - platformEndpoint, err = SanitizePlatformEndpoint(platformEndpoint) - if err != nil { - return nil, fmt.Errorf("%w [%v]: %w", ErrPlatformEndpointMalformed, platformEndpoint, err) + if !cfg.ipc { //nolint:nestif // Most of checks are for errors + if IsPlatformEndpointMalformed(platformEndpoint) { + return nil, fmt.Errorf("%w [%v]", ErrPlatformEndpointMalformed, platformEndpoint) } - if cfg.shouldValidatePlatformConnectivity { - err = ValidateHealthyPlatformConnection(platformEndpoint, dialOptions) - if err != nil { - return nil, err + if cfg.coreConn != nil { + err = validateHealthyPlatformConnection(cfg.coreConn.Endpoint, cfg.coreConn.Client, cfg.coreConn.Options) + if err != nil { + return nil, err + } + } else { + err = validateHealthyPlatformConnection(platformEndpoint, cfg.httpClient, cfg.extraClientOptions) + if err != nil { + return nil, err + } } } } @@ -152,7 +135,7 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { return nil, errors.Join(ErrPlatformConfigFailed, err) } } else { - pcfg, err = fetchPlatformConfiguration(platformEndpoint, dialOptions) + pcfg, err = getPlatformConfiguration(&ConnectRPCConnection{Endpoint: platformEndpoint, Client: cfg.httpClient, Options: cfg.extraClientOptions}) if err != nil { return nil, errors.Join(ErrPlatformConfigFailed, err) } @@ -166,13 +149,13 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { } } if cfg.PlatformConfiguration != nil { - cfg.PlatformConfiguration["platform_endpoint"] = unsanitizedPlatformEndpoint + cfg.PlatformConfiguration["platform_endpoint"] = platformEndpoint } - var uci []grpc.UnaryClientInterceptor + var uci []connect.Interceptor // Add request ID interceptor - uci = append(uci, audit.MetadataAddingClientInterceptor) + uci = append(uci, audit.MetadataAddingConnectInterceptor()) accessTokenSource, err := buildIDPTokenSource(cfg) if err != nil { @@ -180,19 +163,14 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { } if accessTokenSource != nil { interceptor := auth.NewTokenAddingInterceptorWithClient(accessTokenSource, cfg.httpClient) - uci = append(uci, interceptor.AddCredentials) + uci = append(uci, interceptor.AddCredentialsConnect()) } - dialOptions = append(dialOptions, grpc.WithChainUnaryInterceptor(uci...)) - // If coreConn is provided, use it as the platform connection if cfg.coreConn != nil { platformConn = cfg.coreConn } else { - platformConn, err = grpc.NewClient(platformEndpoint, dialOptions...) - if err != nil { - return nil, errors.Join(ErrGrpcDialFailed, err) - } + platformConn = &ConnectRPCConnection{Endpoint: platformEndpoint, Client: cfg.httpClient, Options: append(cfg.extraClientOptions, connect.WithInterceptors(uci...))} } if cfg.entityResolutionConn != nil { @@ -205,58 +183,31 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { config: *cfg, collectionStore: cfg.collectionStore, kasKeyCache: newKasKeyCache(), - conn: platformConn, - dialOptions: dialOptions, + conn: &ConnectRPCConnection{Client: platformConn.Client, Endpoint: platformConn.Endpoint, Options: platformConn.Options}, tokenSource: accessTokenSource, - Actions: actions.NewActionServiceClient(platformConn), - Attributes: attributes.NewAttributesServiceClient(platformConn), - Authorization: authorization.NewAuthorizationServiceClient(platformConn), - AuthorizationV2: authorizationV2.NewAuthorizationServiceClient(platformConn), - EntityResoution: entityresolution.NewEntityResolutionServiceClient(ersConn), - EntityResolutionV2: entityresolutionV2.NewEntityResolutionServiceClient(ersConn), - KeyAccessServerRegistry: kasregistry.NewKeyAccessServerRegistryServiceClient(platformConn), - KeyManagement: keymanagement.NewKeyManagementServiceClient(platformConn), - Namespaces: namespaces.NewNamespaceServiceClient(platformConn), - RegisteredResources: registeredresources.NewRegisteredResourcesServiceClient(platformConn), - ResourceMapping: resourcemapping.NewResourceMappingServiceClient(platformConn), - SubjectMapping: subjectmapping.NewSubjectMappingServiceClient(platformConn), - Unsafe: unsafe.NewUnsafeServiceClient(platformConn), - wellknownConfiguration: wellknownconfiguration.NewWellKnownServiceClient(platformConn), + Actions: sdkconnect.NewActionServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Attributes: sdkconnect.NewAttributesServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Namespaces: sdkconnect.NewNamespaceServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + RegisteredResources: sdkconnect.NewRegisteredResourcesServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + ResourceMapping: sdkconnect.NewResourceMappingServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + SubjectMapping: sdkconnect.NewSubjectMappingServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Unsafe: sdkconnect.NewUnsafeServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + KeyAccessServerRegistry: sdkconnect.NewKeyAccessServerRegistryServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Authorization: sdkconnect.NewAuthorizationServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + AuthorizationV2: sdkconnect.NewAuthorizationServiceClientV2ConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + EntityResoution: sdkconnect.NewEntityResolutionServiceClientConnectWrapper(ersConn.Client, ersConn.Endpoint, ersConn.Options...), + EntityResolutionV2: sdkconnect.NewEntityResolutionServiceClientV2ConnectWrapper(ersConn.Client, ersConn.Endpoint, ersConn.Options...), + KeyManagement: sdkconnect.NewKeyManagementServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + wellknownConfiguration: sdkconnect.NewWellKnownServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), }, nil } -func SanitizePlatformEndpoint(e string) (string, error) { - // check if there's a scheme, if not, add https +func IsPlatformEndpointMalformed(e string) bool { u, err := url.ParseRequestURI(e) - if err != nil { - return "", errors.Join(fmt.Errorf("cannot parse platform endpoint [%s]", e), err) - } - if u.Host == "" { - // if the schema is missing add https. when the schema is missing the host is parsed as the scheme - newE := "https://" + e - u, err = url.ParseRequestURI(newE) - if err != nil { - return "", errors.Join(fmt.Errorf("cannot parse platform endpoint [%s]", newE), err) - } - if u.Host == "" { - return "", fmt.Errorf("invalid URL [%s], got empty hostname", newE) - } + if err != nil || u.Hostname() == "" || strings.Contains(u.Hostname(), ":") { + return true } - - if strings.Contains(u.Hostname(), ":") { - return "", fmt.Errorf("invalid hostname [%s]. IPv6 addresses are not supported", u.Hostname()) - } - - p := u.Port() - if p == "" { - if u.Scheme == "http" { - p = "80" - } else { - p = "443" - } - } - - return net.JoinHostPort(u.Hostname(), p), nil + return false } func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { @@ -309,23 +260,15 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { return ts, err } -// Close closes the underlying grpc.ClientConn. func (s SDK) Close() error { if s.collectionStore != nil { s.collectionStore.close() } - - if s.conn == nil { - return nil - } - if err := s.conn.Close(); err != nil { - return errors.Join(ErrShutdownFailed, err) - } return nil } -// Conn returns the underlying grpc.ClientConn. -func (s SDK) Conn() *grpc.ClientConn { +// Conn returns the underlying http connection +func (s SDK) Conn() *ConnectRPCConnection { return s.conn } @@ -450,43 +393,34 @@ func IsValidNanoTdf(reader io.ReadSeeker) (bool, error) { return err == nil, err } -func fetchPlatformConfiguration(platformEndpoint string, dialOptions []grpc.DialOption) (PlatformConfiguration, error) { - conn, err := grpc.NewClient(platformEndpoint, dialOptions...) - if err != nil { - return nil, errors.Join(ErrGrpcDialFailed, err) - } - defer conn.Close() - - return getPlatformConfiguration(conn) -} - // Test connectability to the platform and validate a healthy status -func ValidateHealthyPlatformConnection(platformEndpoint string, dialOptions []grpc.DialOption) error { - conn, err := grpc.NewClient(platformEndpoint, dialOptions...) - if err != nil { - return errors.Join(ErrGrpcDialFailed, err) - } - defer conn.Close() - - req := healthpb.HealthCheckRequest{} - healthService := healthpb.NewHealthClient(conn) - resp, err := healthService.Check(context.Background(), &req) - if err != nil || resp.GetStatus() != healthpb.HealthCheckResponse_SERVING { +func validateHealthyPlatformConnection(platformEndpoint string, httpClient *http.Client, options []connect.ClientOption) error { + healthClient := connect.NewClient[healthpb.HealthCheckRequest, healthpb.HealthCheckResponse]( + httpClient, + platformEndpoint+"/grpc.health.v1.Health/Check", + options..., + ) + res, err := healthClient.CallUnary( + context.Background(), + connect.NewRequest(&healthpb.HealthCheckRequest{}), + ) + if err != nil || res.Msg.GetStatus() != healthpb.HealthCheckResponse_SERVING { return errors.Join(ErrPlatformUnreachable, err) } + return nil } -func getPlatformConfiguration(conn *grpc.ClientConn) (PlatformConfiguration, error) { +func getPlatformConfiguration(conn *ConnectRPCConnection) (PlatformConfiguration, error) { req := wellknownconfiguration.GetWellKnownConfigurationRequest{} - wellKnownConfig := wellknownconfiguration.NewWellKnownServiceClient(conn) + wellKnownConfig := wellknownconfigurationconnect.NewWellKnownServiceClient(conn.Client, conn.Endpoint) - response, err := wellKnownConfig.GetWellKnownConfiguration(context.Background(), &req) + response, err := wellKnownConfig.GetWellKnownConfiguration(context.Background(), connect.NewRequest(&req)) if err != nil { return nil, errors.Join(errors.New("unable to retrieve config information, and none was provided"), err) } // Get token endpoint - configuration := response.GetConfiguration() + configuration := response.Msg.GetConfiguration() return configuration.AsMap(), nil } diff --git a/sdk/sdk_test.go b/sdk/sdk_test.go index 12b1d431f7..12c0627624 100644 --- a/sdk/sdk_test.go +++ b/sdk/sdk_test.go @@ -6,18 +6,18 @@ import ( "reflect" "testing" - "github.com/opentdf/platform/protocol/go/policy/attributes" - "github.com/opentdf/platform/protocol/go/policy/kasregistry" - "github.com/opentdf/platform/protocol/go/policy/resourcemapping" - "github.com/opentdf/platform/protocol/go/policy/subjectmapping" + "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" + "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" + "github.com/opentdf/platform/protocol/go/policy/resourcemapping/resourcemappingconnect" + "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" "github.com/opentdf/platform/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) var ( - goodPlatformEndpoint = "localhost:8080" - badPlatformEndpoint = "localhost:9999" + goodPlatformEndpoint = "http://localhost:8080" + badPlatformEndpoint = "http://localhost:9999" ) func GetMethods(i interface{}) []string { @@ -122,18 +122,6 @@ func Test_ShouldCreateNewSDK_NoCredentials(t *testing.T) { assert.NotNil(t, s) } -func TestNew_ShouldCloseConnections(t *testing.T) { - s, err := sdk.New(goodPlatformEndpoint, - sdk.WithPlatformConfiguration(sdk.PlatformConfiguration{ - "platform_issuer": "https://example.org", - }), - sdk.WithClientCredentials("myid", "mysecret", nil), - sdk.WithTokenEndpoint("https://example.org/token"), - ) - require.NoError(t, err) - require.NoError(t, s.Close()) -} - func TestNew_ShouldValidateGoodNanoTdf(t *testing.T) { goodNanoTdfStr := "TDFMABJsb2NhbGhvc3Q6ODA4MC9rYXOAAQIA2qvjMRfg7b27lT2kf9SwHRkDIg8ZXtfRoiIvdMUHq/gL5AUMfmv4Di8sKCyLkmUm/WITVj5hDeV/z4JmQ0JL7ZxqSmgZoK6TAHvkKhUly4zMEWMRXH8IktKhFKy1+fD+3qwDopqWAO5Nm2nYQqi75atEFckstulpNKg3N+Ul22OHr/ZuR127oPObBDYNRfktBdzoZbEQcPlr8q1B57q6y5SPZFjEzL9weK+uS5bUJWkF3nsHASo2bZw7IPhTZxoFVmCDjwvj6MbxNa7zG6aClHJ162zKxLLnD9TtIHuZ59R7LgiSieipXeExj+ky9OgIw5DfwyUuxsQLtKpMIAFPmLY9Hy2naUJxke0MT1EUBgastCq+YtFGslV9LJo/A8FtrRqludwtM0O+Z9FlAkZ1oNL7M7uOkLrh7eRrv+C1AAAX6FaBQoOtqnmyu6Jp+VzkxDddEeLRUyI=" goodDecodedData, err := base64.StdEncoding.DecodeString(goodNanoTdfStr) @@ -228,22 +216,22 @@ func TestNew_ShouldHaveSameMethods(t *testing.T) { }{ { name: "Attributes", - expected: GetMethods(reflect.TypeOf(attributes.NewAttributesServiceClient(s.Conn()))), + expected: GetMethods(reflect.TypeOf(attributesconnect.NewAttributesServiceClient(s.Conn().Client, s.Conn().Endpoint))), actual: GetMethods(reflect.TypeOf(s.Attributes)), }, { name: "ResourceEncoding", - expected: GetMethods(reflect.TypeOf(resourcemapping.NewResourceMappingServiceClient(s.Conn()))), + expected: GetMethods(reflect.TypeOf(resourcemappingconnect.NewResourceMappingServiceClient(s.Conn().Client, s.Conn().Endpoint))), actual: GetMethods(reflect.TypeOf(s.ResourceMapping)), }, { name: "SubjectEncoding", - expected: GetMethods(reflect.TypeOf(subjectmapping.NewSubjectMappingServiceClient(s.Conn()))), + expected: GetMethods(reflect.TypeOf(subjectmappingconnect.NewSubjectMappingServiceClient(s.Conn().Client, s.Conn().Endpoint))), actual: GetMethods(reflect.TypeOf(s.SubjectMapping)), }, { name: "KeyAccessGrants", - expected: GetMethods(reflect.TypeOf(kasregistry.NewKeyAccessServerRegistryServiceClient(s.Conn()))), + expected: GetMethods(reflect.TypeOf(kasregistryconnect.NewKeyAccessServerRegistryServiceClient(s.Conn().Client, s.Conn().Endpoint))), actual: GetMethods(reflect.TypeOf(s.KeyAccessServerRegistry)), }, } @@ -273,93 +261,67 @@ func Test_New_ShouldFailWithDisconnectedPlatform(t *testing.T) { assert.Nil(t, s) } -func Test_ShouldSanitizePlatformEndpoint(t *testing.T) { +func TestIsPlatformEndpointMalformed(t *testing.T) { tests := []struct { - name string - endpoint string - expected string + name string + input string + expected bool + description string }{ { - name: "No scheme", - endpoint: "localhost:8080", - expected: "localhost:8080", - }, - { - name: "HTTP scheme with port", - endpoint: "http://localhost:8080", - expected: "localhost:8080", - }, - { - name: "HTTPS scheme with port", - endpoint: "https://localhost:8080", - expected: "localhost:8080", - }, - { - name: "HTTP scheme no port", - endpoint: "http://localhost", - expected: "localhost:80", - }, - { - name: "HTTPS scheme no port", - endpoint: "https://localhost", - expected: "localhost:443", - }, - { - name: "HTTPS scheme port (IP)", - endpoint: "https://192.168.1.1:8080", - expected: "192.168.1.1:8080", - }, - { - name: "HTTPS scheme no port (IP)", - endpoint: "https://192.168.1.1", - expected: "192.168.1.1:443", + name: "Valid URL with scheme and host", + input: "https://example.com", + expected: false, + description: "A valid URL with scheme and host should not be considered malformed.", }, { - name: "Malformed url", - endpoint: "http://localhost:8080:8080", - expected: "", + name: "Valid URL with scheme, host, and port", + input: "https://example.com:8080", + expected: false, + description: "A valid URL with scheme, host, and port should not be considered malformed.", }, { - name: "Malformed url", - endpoint: "http://localhost:8080:", - expected: "", + name: "Valid URL with path", + input: "https://example.com/path", + expected: false, + description: "A valid URL with a path should not be considered malformed.", }, { - name: "Malformed url", - endpoint: "http//localhost:8080:", - expected: "", + name: "Invalid URL with missing host", + input: "https://:8080", + expected: true, + description: "A URL with a missing host should be considered malformed.", }, { - name: "Malformed url", - endpoint: "//localhost", - expected: "", + name: "Invalid URL with missing scheme", + input: "example.com", + expected: true, + description: "A URL without a scheme should be considered malformed.", }, { - name: "Malformed url", - endpoint: "://localhost", - expected: "", + name: "Invalid URL with invalid characters", + input: "https://exa mple.com", + expected: true, + description: "A URL with invalid characters should be considered malformed.", }, { - name: "Malformed url", - endpoint: "http/localhost", - expected: "", + name: "Invalid URL with colon in hostname", + input: "https://example:com", + expected: true, + description: "A URL with a colon in the hostname should be considered malformed.", }, { - name: "Malformed url", - endpoint: "http:localhost", - expected: "", + name: "Empty input", + input: "", + expected: true, + description: "An empty input should be considered malformed.", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - actual, err := sdk.SanitizePlatformEndpoint(tt.endpoint) - if tt.expected == "" { - require.Error(t, err) - } else { - require.NoError(t, err) - assert.Equal(t, tt.expected, actual) - } + result := sdk.IsPlatformEndpointMalformed(tt.input) + assert.Equal(t, tt.expected, result, tt.description) }) } } diff --git a/sdk/sdkconnect/actions.go b/sdk/sdkconnect/actions.go new file mode 100644 index 0000000000..1a8177c8f7 --- /dev/null +++ b/sdk/sdkconnect/actions.go @@ -0,0 +1,70 @@ +// Wrapper for ActionServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/actions" + "github.com/opentdf/platform/protocol/go/policy/actions/actionsconnect" +) + +type ActionServiceClientConnectWrapper struct { + actionsconnect.ActionServiceClient +} + +func NewActionServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *ActionServiceClientConnectWrapper { + return &ActionServiceClientConnectWrapper{ActionServiceClient: actionsconnect.NewActionServiceClient(httpClient, baseURL, opts...)} +} + +type ActionServiceClient interface { + GetAction(ctx context.Context, req *actions.GetActionRequest) (*actions.GetActionResponse, error) + ListActions(ctx context.Context, req *actions.ListActionsRequest) (*actions.ListActionsResponse, error) + CreateAction(ctx context.Context, req *actions.CreateActionRequest) (*actions.CreateActionResponse, error) + UpdateAction(ctx context.Context, req *actions.UpdateActionRequest) (*actions.UpdateActionResponse, error) + DeleteAction(ctx context.Context, req *actions.DeleteActionRequest) (*actions.DeleteActionResponse, error) +} + +func (w *ActionServiceClientConnectWrapper) GetAction(ctx context.Context, req *actions.GetActionRequest) (*actions.GetActionResponse, error) { + // Wrap Connect RPC client request + res, err := w.ActionServiceClient.GetAction(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ActionServiceClientConnectWrapper) ListActions(ctx context.Context, req *actions.ListActionsRequest) (*actions.ListActionsResponse, error) { + // Wrap Connect RPC client request + res, err := w.ActionServiceClient.ListActions(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ActionServiceClientConnectWrapper) CreateAction(ctx context.Context, req *actions.CreateActionRequest) (*actions.CreateActionResponse, error) { + // Wrap Connect RPC client request + res, err := w.ActionServiceClient.CreateAction(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ActionServiceClientConnectWrapper) UpdateAction(ctx context.Context, req *actions.UpdateActionRequest) (*actions.UpdateActionResponse, error) { + // Wrap Connect RPC client request + res, err := w.ActionServiceClient.UpdateAction(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ActionServiceClientConnectWrapper) DeleteAction(ctx context.Context, req *actions.DeleteActionRequest) (*actions.DeleteActionResponse, error) { + // Wrap Connect RPC client request + res, err := w.ActionServiceClient.DeleteAction(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/attributes.go b/sdk/sdkconnect/attributes.go new file mode 100644 index 0000000000..080ef4df83 --- /dev/null +++ b/sdk/sdkconnect/attributes.go @@ -0,0 +1,210 @@ +// Wrapper for AttributesServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" +) + +type AttributesServiceClientConnectWrapper struct { + attributesconnect.AttributesServiceClient +} + +func NewAttributesServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *AttributesServiceClientConnectWrapper { + return &AttributesServiceClientConnectWrapper{AttributesServiceClient: attributesconnect.NewAttributesServiceClient(httpClient, baseURL, opts...)} +} + +type AttributesServiceClient interface { + ListAttributes(ctx context.Context, req *attributes.ListAttributesRequest) (*attributes.ListAttributesResponse, error) + ListAttributeValues(ctx context.Context, req *attributes.ListAttributeValuesRequest) (*attributes.ListAttributeValuesResponse, error) + GetAttribute(ctx context.Context, req *attributes.GetAttributeRequest) (*attributes.GetAttributeResponse, error) + GetAttributeValuesByFqns(ctx context.Context, req *attributes.GetAttributeValuesByFqnsRequest) (*attributes.GetAttributeValuesByFqnsResponse, error) + CreateAttribute(ctx context.Context, req *attributes.CreateAttributeRequest) (*attributes.CreateAttributeResponse, error) + UpdateAttribute(ctx context.Context, req *attributes.UpdateAttributeRequest) (*attributes.UpdateAttributeResponse, error) + DeactivateAttribute(ctx context.Context, req *attributes.DeactivateAttributeRequest) (*attributes.DeactivateAttributeResponse, error) + GetAttributeValue(ctx context.Context, req *attributes.GetAttributeValueRequest) (*attributes.GetAttributeValueResponse, error) + CreateAttributeValue(ctx context.Context, req *attributes.CreateAttributeValueRequest) (*attributes.CreateAttributeValueResponse, error) + UpdateAttributeValue(ctx context.Context, req *attributes.UpdateAttributeValueRequest) (*attributes.UpdateAttributeValueResponse, error) + DeactivateAttributeValue(ctx context.Context, req *attributes.DeactivateAttributeValueRequest) (*attributes.DeactivateAttributeValueResponse, error) + AssignKeyAccessServerToAttribute(ctx context.Context, req *attributes.AssignKeyAccessServerToAttributeRequest) (*attributes.AssignKeyAccessServerToAttributeResponse, error) + RemoveKeyAccessServerFromAttribute(ctx context.Context, req *attributes.RemoveKeyAccessServerFromAttributeRequest) (*attributes.RemoveKeyAccessServerFromAttributeResponse, error) + AssignKeyAccessServerToValue(ctx context.Context, req *attributes.AssignKeyAccessServerToValueRequest) (*attributes.AssignKeyAccessServerToValueResponse, error) + RemoveKeyAccessServerFromValue(ctx context.Context, req *attributes.RemoveKeyAccessServerFromValueRequest) (*attributes.RemoveKeyAccessServerFromValueResponse, error) + AssignPublicKeyToAttribute(ctx context.Context, req *attributes.AssignPublicKeyToAttributeRequest) (*attributes.AssignPublicKeyToAttributeResponse, error) + RemovePublicKeyFromAttribute(ctx context.Context, req *attributes.RemovePublicKeyFromAttributeRequest) (*attributes.RemovePublicKeyFromAttributeResponse, error) + AssignPublicKeyToValue(ctx context.Context, req *attributes.AssignPublicKeyToValueRequest) (*attributes.AssignPublicKeyToValueResponse, error) + RemovePublicKeyFromValue(ctx context.Context, req *attributes.RemovePublicKeyFromValueRequest) (*attributes.RemovePublicKeyFromValueResponse, error) +} + +func (w *AttributesServiceClientConnectWrapper) ListAttributes(ctx context.Context, req *attributes.ListAttributesRequest) (*attributes.ListAttributesResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.ListAttributes(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) ListAttributeValues(ctx context.Context, req *attributes.ListAttributeValuesRequest) (*attributes.ListAttributeValuesResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.ListAttributeValues(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) GetAttribute(ctx context.Context, req *attributes.GetAttributeRequest) (*attributes.GetAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.GetAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) GetAttributeValuesByFqns(ctx context.Context, req *attributes.GetAttributeValuesByFqnsRequest) (*attributes.GetAttributeValuesByFqnsResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.GetAttributeValuesByFqns(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) CreateAttribute(ctx context.Context, req *attributes.CreateAttributeRequest) (*attributes.CreateAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.CreateAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) UpdateAttribute(ctx context.Context, req *attributes.UpdateAttributeRequest) (*attributes.UpdateAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.UpdateAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) DeactivateAttribute(ctx context.Context, req *attributes.DeactivateAttributeRequest) (*attributes.DeactivateAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.DeactivateAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) GetAttributeValue(ctx context.Context, req *attributes.GetAttributeValueRequest) (*attributes.GetAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.GetAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) CreateAttributeValue(ctx context.Context, req *attributes.CreateAttributeValueRequest) (*attributes.CreateAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.CreateAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) UpdateAttributeValue(ctx context.Context, req *attributes.UpdateAttributeValueRequest) (*attributes.UpdateAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.UpdateAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) DeactivateAttributeValue(ctx context.Context, req *attributes.DeactivateAttributeValueRequest) (*attributes.DeactivateAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.DeactivateAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) AssignKeyAccessServerToAttribute(ctx context.Context, req *attributes.AssignKeyAccessServerToAttributeRequest) (*attributes.AssignKeyAccessServerToAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.AssignKeyAccessServerToAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) RemoveKeyAccessServerFromAttribute(ctx context.Context, req *attributes.RemoveKeyAccessServerFromAttributeRequest) (*attributes.RemoveKeyAccessServerFromAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.RemoveKeyAccessServerFromAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) AssignKeyAccessServerToValue(ctx context.Context, req *attributes.AssignKeyAccessServerToValueRequest) (*attributes.AssignKeyAccessServerToValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.AssignKeyAccessServerToValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) RemoveKeyAccessServerFromValue(ctx context.Context, req *attributes.RemoveKeyAccessServerFromValueRequest) (*attributes.RemoveKeyAccessServerFromValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.RemoveKeyAccessServerFromValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) AssignPublicKeyToAttribute(ctx context.Context, req *attributes.AssignPublicKeyToAttributeRequest) (*attributes.AssignPublicKeyToAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.AssignPublicKeyToAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) RemovePublicKeyFromAttribute(ctx context.Context, req *attributes.RemovePublicKeyFromAttributeRequest) (*attributes.RemovePublicKeyFromAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.RemovePublicKeyFromAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) AssignPublicKeyToValue(ctx context.Context, req *attributes.AssignPublicKeyToValueRequest) (*attributes.AssignPublicKeyToValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.AssignPublicKeyToValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) RemovePublicKeyFromValue(ctx context.Context, req *attributes.RemovePublicKeyFromValueRequest) (*attributes.RemovePublicKeyFromValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.RemovePublicKeyFromValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/authorization.go b/sdk/sdkconnect/authorization.go new file mode 100644 index 0000000000..a912aea95e --- /dev/null +++ b/sdk/sdkconnect/authorization.go @@ -0,0 +1,50 @@ +// Wrapper for AuthorizationServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/authorization" + "github.com/opentdf/platform/protocol/go/authorization/authorizationconnect" +) + +type AuthorizationServiceClientConnectWrapper struct { + authorizationconnect.AuthorizationServiceClient +} + +func NewAuthorizationServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *AuthorizationServiceClientConnectWrapper { + return &AuthorizationServiceClientConnectWrapper{AuthorizationServiceClient: authorizationconnect.NewAuthorizationServiceClient(httpClient, baseURL, opts...)} +} + +type AuthorizationServiceClient interface { + GetDecisions(ctx context.Context, req *authorization.GetDecisionsRequest) (*authorization.GetDecisionsResponse, error) + GetDecisionsByToken(ctx context.Context, req *authorization.GetDecisionsByTokenRequest) (*authorization.GetDecisionsByTokenResponse, error) + GetEntitlements(ctx context.Context, req *authorization.GetEntitlementsRequest) (*authorization.GetEntitlementsResponse, error) +} + +func (w *AuthorizationServiceClientConnectWrapper) GetDecisions(ctx context.Context, req *authorization.GetDecisionsRequest) (*authorization.GetDecisionsResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetDecisions(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AuthorizationServiceClientConnectWrapper) GetDecisionsByToken(ctx context.Context, req *authorization.GetDecisionsByTokenRequest) (*authorization.GetDecisionsByTokenResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetDecisionsByToken(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AuthorizationServiceClientConnectWrapper) GetEntitlements(ctx context.Context, req *authorization.GetEntitlementsRequest) (*authorization.GetEntitlementsResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetEntitlements(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/authorizationv2.go b/sdk/sdkconnect/authorizationv2.go new file mode 100644 index 0000000000..80e5b1bce3 --- /dev/null +++ b/sdk/sdkconnect/authorizationv2.go @@ -0,0 +1,60 @@ +// Wrapper for AuthorizationServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/authorization/v2" + "github.com/opentdf/platform/protocol/go/authorization/v2/authorizationv2connect" +) + +type AuthorizationServiceClientV2ConnectWrapper struct { + authorizationv2connect.AuthorizationServiceClient +} + +func NewAuthorizationServiceClientV2ConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *AuthorizationServiceClientV2ConnectWrapper { + return &AuthorizationServiceClientV2ConnectWrapper{AuthorizationServiceClient: authorizationv2connect.NewAuthorizationServiceClient(httpClient, baseURL, opts...)} +} + +type AuthorizationServiceClientV2 interface { + GetDecision(ctx context.Context, req *authorizationv2.GetDecisionRequest) (*authorizationv2.GetDecisionResponse, error) + GetDecisionMultiResource(ctx context.Context, req *authorizationv2.GetDecisionMultiResourceRequest) (*authorizationv2.GetDecisionMultiResourceResponse, error) + GetDecisionBulk(ctx context.Context, req *authorizationv2.GetDecisionBulkRequest) (*authorizationv2.GetDecisionBulkResponse, error) + GetEntitlements(ctx context.Context, req *authorizationv2.GetEntitlementsRequest) (*authorizationv2.GetEntitlementsResponse, error) +} + +func (w *AuthorizationServiceClientV2ConnectWrapper) GetDecision(ctx context.Context, req *authorizationv2.GetDecisionRequest) (*authorizationv2.GetDecisionResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetDecision(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AuthorizationServiceClientV2ConnectWrapper) GetDecisionMultiResource(ctx context.Context, req *authorizationv2.GetDecisionMultiResourceRequest) (*authorizationv2.GetDecisionMultiResourceResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetDecisionMultiResource(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AuthorizationServiceClientV2ConnectWrapper) GetDecisionBulk(ctx context.Context, req *authorizationv2.GetDecisionBulkRequest) (*authorizationv2.GetDecisionBulkResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetDecisionBulk(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AuthorizationServiceClientV2ConnectWrapper) GetEntitlements(ctx context.Context, req *authorizationv2.GetEntitlementsRequest) (*authorizationv2.GetEntitlementsResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetEntitlements(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/entityresolution.go b/sdk/sdkconnect/entityresolution.go new file mode 100644 index 0000000000..71c331a477 --- /dev/null +++ b/sdk/sdkconnect/entityresolution.go @@ -0,0 +1,40 @@ +// Wrapper for EntityResolutionServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/entityresolution" + "github.com/opentdf/platform/protocol/go/entityresolution/entityresolutionconnect" +) + +type EntityResolutionServiceClientConnectWrapper struct { + entityresolutionconnect.EntityResolutionServiceClient +} + +func NewEntityResolutionServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *EntityResolutionServiceClientConnectWrapper { + return &EntityResolutionServiceClientConnectWrapper{EntityResolutionServiceClient: entityresolutionconnect.NewEntityResolutionServiceClient(httpClient, baseURL, opts...)} +} + +type EntityResolutionServiceClient interface { + ResolveEntities(ctx context.Context, req *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) + CreateEntityChainFromJwt(ctx context.Context, req *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) +} + +func (w *EntityResolutionServiceClientConnectWrapper) ResolveEntities(ctx context.Context, req *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) { + // Wrap Connect RPC client request + res, err := w.EntityResolutionServiceClient.ResolveEntities(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *EntityResolutionServiceClientConnectWrapper) CreateEntityChainFromJwt(ctx context.Context, req *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) { + // Wrap Connect RPC client request + res, err := w.EntityResolutionServiceClient.CreateEntityChainFromJwt(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/entityresolutionv2.go b/sdk/sdkconnect/entityresolutionv2.go new file mode 100644 index 0000000000..b0173972bc --- /dev/null +++ b/sdk/sdkconnect/entityresolutionv2.go @@ -0,0 +1,40 @@ +// Wrapper for EntityResolutionServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/entityresolution/v2" + "github.com/opentdf/platform/protocol/go/entityresolution/v2/entityresolutionv2connect" +) + +type EntityResolutionServiceClientV2ConnectWrapper struct { + entityresolutionv2connect.EntityResolutionServiceClient +} + +func NewEntityResolutionServiceClientV2ConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *EntityResolutionServiceClientV2ConnectWrapper { + return &EntityResolutionServiceClientV2ConnectWrapper{EntityResolutionServiceClient: entityresolutionv2connect.NewEntityResolutionServiceClient(httpClient, baseURL, opts...)} +} + +type EntityResolutionServiceClientV2 interface { + ResolveEntities(ctx context.Context, req *entityresolutionv2.ResolveEntitiesRequest) (*entityresolutionv2.ResolveEntitiesResponse, error) + CreateEntityChainsFromTokens(ctx context.Context, req *entityresolutionv2.CreateEntityChainsFromTokensRequest) (*entityresolutionv2.CreateEntityChainsFromTokensResponse, error) +} + +func (w *EntityResolutionServiceClientV2ConnectWrapper) ResolveEntities(ctx context.Context, req *entityresolutionv2.ResolveEntitiesRequest) (*entityresolutionv2.ResolveEntitiesResponse, error) { + // Wrap Connect RPC client request + res, err := w.EntityResolutionServiceClient.ResolveEntities(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *EntityResolutionServiceClientV2ConnectWrapper) CreateEntityChainsFromTokens(ctx context.Context, req *entityresolutionv2.CreateEntityChainsFromTokensRequest) (*entityresolutionv2.CreateEntityChainsFromTokensResponse, error) { + // Wrap Connect RPC client request + res, err := w.EntityResolutionServiceClient.CreateEntityChainsFromTokens(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/kasregistry.go b/sdk/sdkconnect/kasregistry.go new file mode 100644 index 0000000000..addd61099a --- /dev/null +++ b/sdk/sdkconnect/kasregistry.go @@ -0,0 +1,130 @@ +// Wrapper for KeyAccessServerRegistryServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/kasregistry" + "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" +) + +type KeyAccessServerRegistryServiceClientConnectWrapper struct { + kasregistryconnect.KeyAccessServerRegistryServiceClient +} + +func NewKeyAccessServerRegistryServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *KeyAccessServerRegistryServiceClientConnectWrapper { + return &KeyAccessServerRegistryServiceClientConnectWrapper{KeyAccessServerRegistryServiceClient: kasregistryconnect.NewKeyAccessServerRegistryServiceClient(httpClient, baseURL, opts...)} +} + +type KeyAccessServerRegistryServiceClient interface { + ListKeyAccessServers(ctx context.Context, req *kasregistry.ListKeyAccessServersRequest) (*kasregistry.ListKeyAccessServersResponse, error) + GetKeyAccessServer(ctx context.Context, req *kasregistry.GetKeyAccessServerRequest) (*kasregistry.GetKeyAccessServerResponse, error) + CreateKeyAccessServer(ctx context.Context, req *kasregistry.CreateKeyAccessServerRequest) (*kasregistry.CreateKeyAccessServerResponse, error) + UpdateKeyAccessServer(ctx context.Context, req *kasregistry.UpdateKeyAccessServerRequest) (*kasregistry.UpdateKeyAccessServerResponse, error) + DeleteKeyAccessServer(ctx context.Context, req *kasregistry.DeleteKeyAccessServerRequest) (*kasregistry.DeleteKeyAccessServerResponse, error) + ListKeyAccessServerGrants(ctx context.Context, req *kasregistry.ListKeyAccessServerGrantsRequest) (*kasregistry.ListKeyAccessServerGrantsResponse, error) + CreateKey(ctx context.Context, req *kasregistry.CreateKeyRequest) (*kasregistry.CreateKeyResponse, error) + GetKey(ctx context.Context, req *kasregistry.GetKeyRequest) (*kasregistry.GetKeyResponse, error) + ListKeys(ctx context.Context, req *kasregistry.ListKeysRequest) (*kasregistry.ListKeysResponse, error) + UpdateKey(ctx context.Context, req *kasregistry.UpdateKeyRequest) (*kasregistry.UpdateKeyResponse, error) + RotateKey(ctx context.Context, req *kasregistry.RotateKeyRequest) (*kasregistry.RotateKeyResponse, error) +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeyAccessServers(ctx context.Context, req *kasregistry.ListKeyAccessServersRequest) (*kasregistry.ListKeyAccessServersResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.ListKeyAccessServers(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) GetKeyAccessServer(ctx context.Context, req *kasregistry.GetKeyAccessServerRequest) (*kasregistry.GetKeyAccessServerResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.GetKeyAccessServer(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) CreateKeyAccessServer(ctx context.Context, req *kasregistry.CreateKeyAccessServerRequest) (*kasregistry.CreateKeyAccessServerResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.CreateKeyAccessServer(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) UpdateKeyAccessServer(ctx context.Context, req *kasregistry.UpdateKeyAccessServerRequest) (*kasregistry.UpdateKeyAccessServerResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.UpdateKeyAccessServer(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) DeleteKeyAccessServer(ctx context.Context, req *kasregistry.DeleteKeyAccessServerRequest) (*kasregistry.DeleteKeyAccessServerResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.DeleteKeyAccessServer(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeyAccessServerGrants(ctx context.Context, req *kasregistry.ListKeyAccessServerGrantsRequest) (*kasregistry.ListKeyAccessServerGrantsResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.ListKeyAccessServerGrants(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) CreateKey(ctx context.Context, req *kasregistry.CreateKeyRequest) (*kasregistry.CreateKeyResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.CreateKey(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) GetKey(ctx context.Context, req *kasregistry.GetKeyRequest) (*kasregistry.GetKeyResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.GetKey(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeys(ctx context.Context, req *kasregistry.ListKeysRequest) (*kasregistry.ListKeysResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.ListKeys(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) UpdateKey(ctx context.Context, req *kasregistry.UpdateKeyRequest) (*kasregistry.UpdateKeyResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.UpdateKey(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) RotateKey(ctx context.Context, req *kasregistry.RotateKeyRequest) (*kasregistry.RotateKeyResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.RotateKey(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/keymanagement.go b/sdk/sdkconnect/keymanagement.go new file mode 100644 index 0000000000..c563ff9144 --- /dev/null +++ b/sdk/sdkconnect/keymanagement.go @@ -0,0 +1,70 @@ +// Wrapper for KeyManagementServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/keymanagement" + "github.com/opentdf/platform/protocol/go/policy/keymanagement/keymanagementconnect" +) + +type KeyManagementServiceClientConnectWrapper struct { + keymanagementconnect.KeyManagementServiceClient +} + +func NewKeyManagementServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *KeyManagementServiceClientConnectWrapper { + return &KeyManagementServiceClientConnectWrapper{KeyManagementServiceClient: keymanagementconnect.NewKeyManagementServiceClient(httpClient, baseURL, opts...)} +} + +type KeyManagementServiceClient interface { + CreateProviderConfig(ctx context.Context, req *keymanagement.CreateProviderConfigRequest) (*keymanagement.CreateProviderConfigResponse, error) + GetProviderConfig(ctx context.Context, req *keymanagement.GetProviderConfigRequest) (*keymanagement.GetProviderConfigResponse, error) + ListProviderConfigs(ctx context.Context, req *keymanagement.ListProviderConfigsRequest) (*keymanagement.ListProviderConfigsResponse, error) + UpdateProviderConfig(ctx context.Context, req *keymanagement.UpdateProviderConfigRequest) (*keymanagement.UpdateProviderConfigResponse, error) + DeleteProviderConfig(ctx context.Context, req *keymanagement.DeleteProviderConfigRequest) (*keymanagement.DeleteProviderConfigResponse, error) +} + +func (w *KeyManagementServiceClientConnectWrapper) CreateProviderConfig(ctx context.Context, req *keymanagement.CreateProviderConfigRequest) (*keymanagement.CreateProviderConfigResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyManagementServiceClient.CreateProviderConfig(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyManagementServiceClientConnectWrapper) GetProviderConfig(ctx context.Context, req *keymanagement.GetProviderConfigRequest) (*keymanagement.GetProviderConfigResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyManagementServiceClient.GetProviderConfig(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyManagementServiceClientConnectWrapper) ListProviderConfigs(ctx context.Context, req *keymanagement.ListProviderConfigsRequest) (*keymanagement.ListProviderConfigsResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyManagementServiceClient.ListProviderConfigs(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyManagementServiceClientConnectWrapper) UpdateProviderConfig(ctx context.Context, req *keymanagement.UpdateProviderConfigRequest) (*keymanagement.UpdateProviderConfigResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyManagementServiceClient.UpdateProviderConfig(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyManagementServiceClientConnectWrapper) DeleteProviderConfig(ctx context.Context, req *keymanagement.DeleteProviderConfigRequest) (*keymanagement.DeleteProviderConfigResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyManagementServiceClient.DeleteProviderConfig(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/namespaces.go b/sdk/sdkconnect/namespaces.go new file mode 100644 index 0000000000..b1d02386b8 --- /dev/null +++ b/sdk/sdkconnect/namespaces.go @@ -0,0 +1,110 @@ +// Wrapper for NamespaceServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/namespaces" + "github.com/opentdf/platform/protocol/go/policy/namespaces/namespacesconnect" +) + +type NamespaceServiceClientConnectWrapper struct { + namespacesconnect.NamespaceServiceClient +} + +func NewNamespaceServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *NamespaceServiceClientConnectWrapper { + return &NamespaceServiceClientConnectWrapper{NamespaceServiceClient: namespacesconnect.NewNamespaceServiceClient(httpClient, baseURL, opts...)} +} + +type NamespaceServiceClient interface { + GetNamespace(ctx context.Context, req *namespaces.GetNamespaceRequest) (*namespaces.GetNamespaceResponse, error) + ListNamespaces(ctx context.Context, req *namespaces.ListNamespacesRequest) (*namespaces.ListNamespacesResponse, error) + CreateNamespace(ctx context.Context, req *namespaces.CreateNamespaceRequest) (*namespaces.CreateNamespaceResponse, error) + UpdateNamespace(ctx context.Context, req *namespaces.UpdateNamespaceRequest) (*namespaces.UpdateNamespaceResponse, error) + DeactivateNamespace(ctx context.Context, req *namespaces.DeactivateNamespaceRequest) (*namespaces.DeactivateNamespaceResponse, error) + AssignKeyAccessServerToNamespace(ctx context.Context, req *namespaces.AssignKeyAccessServerToNamespaceRequest) (*namespaces.AssignKeyAccessServerToNamespaceResponse, error) + RemoveKeyAccessServerFromNamespace(ctx context.Context, req *namespaces.RemoveKeyAccessServerFromNamespaceRequest) (*namespaces.RemoveKeyAccessServerFromNamespaceResponse, error) + AssignPublicKeyToNamespace(ctx context.Context, req *namespaces.AssignPublicKeyToNamespaceRequest) (*namespaces.AssignPublicKeyToNamespaceResponse, error) + RemovePublicKeyFromNamespace(ctx context.Context, req *namespaces.RemovePublicKeyFromNamespaceRequest) (*namespaces.RemovePublicKeyFromNamespaceResponse, error) +} + +func (w *NamespaceServiceClientConnectWrapper) GetNamespace(ctx context.Context, req *namespaces.GetNamespaceRequest) (*namespaces.GetNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.GetNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) ListNamespaces(ctx context.Context, req *namespaces.ListNamespacesRequest) (*namespaces.ListNamespacesResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.ListNamespaces(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) CreateNamespace(ctx context.Context, req *namespaces.CreateNamespaceRequest) (*namespaces.CreateNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.CreateNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) UpdateNamespace(ctx context.Context, req *namespaces.UpdateNamespaceRequest) (*namespaces.UpdateNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.UpdateNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) DeactivateNamespace(ctx context.Context, req *namespaces.DeactivateNamespaceRequest) (*namespaces.DeactivateNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.DeactivateNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) AssignKeyAccessServerToNamespace(ctx context.Context, req *namespaces.AssignKeyAccessServerToNamespaceRequest) (*namespaces.AssignKeyAccessServerToNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.AssignKeyAccessServerToNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) RemoveKeyAccessServerFromNamespace(ctx context.Context, req *namespaces.RemoveKeyAccessServerFromNamespaceRequest) (*namespaces.RemoveKeyAccessServerFromNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.RemoveKeyAccessServerFromNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) AssignPublicKeyToNamespace(ctx context.Context, req *namespaces.AssignPublicKeyToNamespaceRequest) (*namespaces.AssignPublicKeyToNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.AssignPublicKeyToNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) RemovePublicKeyFromNamespace(ctx context.Context, req *namespaces.RemovePublicKeyFromNamespaceRequest) (*namespaces.RemovePublicKeyFromNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.RemovePublicKeyFromNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/registeredresources.go b/sdk/sdkconnect/registeredresources.go new file mode 100644 index 0000000000..4f8274d568 --- /dev/null +++ b/sdk/sdkconnect/registeredresources.go @@ -0,0 +1,130 @@ +// Wrapper for RegisteredResourcesServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/registeredresources" + "github.com/opentdf/platform/protocol/go/policy/registeredresources/registeredresourcesconnect" +) + +type RegisteredResourcesServiceClientConnectWrapper struct { + registeredresourcesconnect.RegisteredResourcesServiceClient +} + +func NewRegisteredResourcesServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *RegisteredResourcesServiceClientConnectWrapper { + return &RegisteredResourcesServiceClientConnectWrapper{RegisteredResourcesServiceClient: registeredresourcesconnect.NewRegisteredResourcesServiceClient(httpClient, baseURL, opts...)} +} + +type RegisteredResourcesServiceClient interface { + CreateRegisteredResource(ctx context.Context, req *registeredresources.CreateRegisteredResourceRequest) (*registeredresources.CreateRegisteredResourceResponse, error) + GetRegisteredResource(ctx context.Context, req *registeredresources.GetRegisteredResourceRequest) (*registeredresources.GetRegisteredResourceResponse, error) + ListRegisteredResources(ctx context.Context, req *registeredresources.ListRegisteredResourcesRequest) (*registeredresources.ListRegisteredResourcesResponse, error) + UpdateRegisteredResource(ctx context.Context, req *registeredresources.UpdateRegisteredResourceRequest) (*registeredresources.UpdateRegisteredResourceResponse, error) + DeleteRegisteredResource(ctx context.Context, req *registeredresources.DeleteRegisteredResourceRequest) (*registeredresources.DeleteRegisteredResourceResponse, error) + CreateRegisteredResourceValue(ctx context.Context, req *registeredresources.CreateRegisteredResourceValueRequest) (*registeredresources.CreateRegisteredResourceValueResponse, error) + GetRegisteredResourceValue(ctx context.Context, req *registeredresources.GetRegisteredResourceValueRequest) (*registeredresources.GetRegisteredResourceValueResponse, error) + GetRegisteredResourceValuesByFQNs(ctx context.Context, req *registeredresources.GetRegisteredResourceValuesByFQNsRequest) (*registeredresources.GetRegisteredResourceValuesByFQNsResponse, error) + ListRegisteredResourceValues(ctx context.Context, req *registeredresources.ListRegisteredResourceValuesRequest) (*registeredresources.ListRegisteredResourceValuesResponse, error) + UpdateRegisteredResourceValue(ctx context.Context, req *registeredresources.UpdateRegisteredResourceValueRequest) (*registeredresources.UpdateRegisteredResourceValueResponse, error) + DeleteRegisteredResourceValue(ctx context.Context, req *registeredresources.DeleteRegisteredResourceValueRequest) (*registeredresources.DeleteRegisteredResourceValueResponse, error) +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) CreateRegisteredResource(ctx context.Context, req *registeredresources.CreateRegisteredResourceRequest) (*registeredresources.CreateRegisteredResourceResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.CreateRegisteredResource(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResource(ctx context.Context, req *registeredresources.GetRegisteredResourceRequest) (*registeredresources.GetRegisteredResourceResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.GetRegisteredResource(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) ListRegisteredResources(ctx context.Context, req *registeredresources.ListRegisteredResourcesRequest) (*registeredresources.ListRegisteredResourcesResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.ListRegisteredResources(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) UpdateRegisteredResource(ctx context.Context, req *registeredresources.UpdateRegisteredResourceRequest) (*registeredresources.UpdateRegisteredResourceResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.UpdateRegisteredResource(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) DeleteRegisteredResource(ctx context.Context, req *registeredresources.DeleteRegisteredResourceRequest) (*registeredresources.DeleteRegisteredResourceResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.DeleteRegisteredResource(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) CreateRegisteredResourceValue(ctx context.Context, req *registeredresources.CreateRegisteredResourceValueRequest) (*registeredresources.CreateRegisteredResourceValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.CreateRegisteredResourceValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResourceValue(ctx context.Context, req *registeredresources.GetRegisteredResourceValueRequest) (*registeredresources.GetRegisteredResourceValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.GetRegisteredResourceValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResourceValuesByFQNs(ctx context.Context, req *registeredresources.GetRegisteredResourceValuesByFQNsRequest) (*registeredresources.GetRegisteredResourceValuesByFQNsResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.GetRegisteredResourceValuesByFQNs(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) ListRegisteredResourceValues(ctx context.Context, req *registeredresources.ListRegisteredResourceValuesRequest) (*registeredresources.ListRegisteredResourceValuesResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.ListRegisteredResourceValues(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) UpdateRegisteredResourceValue(ctx context.Context, req *registeredresources.UpdateRegisteredResourceValueRequest) (*registeredresources.UpdateRegisteredResourceValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.UpdateRegisteredResourceValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) DeleteRegisteredResourceValue(ctx context.Context, req *registeredresources.DeleteRegisteredResourceValueRequest) (*registeredresources.DeleteRegisteredResourceValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.DeleteRegisteredResourceValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/resourcemapping.go b/sdk/sdkconnect/resourcemapping.go new file mode 100644 index 0000000000..047168bf48 --- /dev/null +++ b/sdk/sdkconnect/resourcemapping.go @@ -0,0 +1,130 @@ +// Wrapper for ResourceMappingServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/resourcemapping" + "github.com/opentdf/platform/protocol/go/policy/resourcemapping/resourcemappingconnect" +) + +type ResourceMappingServiceClientConnectWrapper struct { + resourcemappingconnect.ResourceMappingServiceClient +} + +func NewResourceMappingServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *ResourceMappingServiceClientConnectWrapper { + return &ResourceMappingServiceClientConnectWrapper{ResourceMappingServiceClient: resourcemappingconnect.NewResourceMappingServiceClient(httpClient, baseURL, opts...)} +} + +type ResourceMappingServiceClient interface { + ListResourceMappingGroups(ctx context.Context, req *resourcemapping.ListResourceMappingGroupsRequest) (*resourcemapping.ListResourceMappingGroupsResponse, error) + GetResourceMappingGroup(ctx context.Context, req *resourcemapping.GetResourceMappingGroupRequest) (*resourcemapping.GetResourceMappingGroupResponse, error) + CreateResourceMappingGroup(ctx context.Context, req *resourcemapping.CreateResourceMappingGroupRequest) (*resourcemapping.CreateResourceMappingGroupResponse, error) + UpdateResourceMappingGroup(ctx context.Context, req *resourcemapping.UpdateResourceMappingGroupRequest) (*resourcemapping.UpdateResourceMappingGroupResponse, error) + DeleteResourceMappingGroup(ctx context.Context, req *resourcemapping.DeleteResourceMappingGroupRequest) (*resourcemapping.DeleteResourceMappingGroupResponse, error) + ListResourceMappings(ctx context.Context, req *resourcemapping.ListResourceMappingsRequest) (*resourcemapping.ListResourceMappingsResponse, error) + ListResourceMappingsByGroupFqns(ctx context.Context, req *resourcemapping.ListResourceMappingsByGroupFqnsRequest) (*resourcemapping.ListResourceMappingsByGroupFqnsResponse, error) + GetResourceMapping(ctx context.Context, req *resourcemapping.GetResourceMappingRequest) (*resourcemapping.GetResourceMappingResponse, error) + CreateResourceMapping(ctx context.Context, req *resourcemapping.CreateResourceMappingRequest) (*resourcemapping.CreateResourceMappingResponse, error) + UpdateResourceMapping(ctx context.Context, req *resourcemapping.UpdateResourceMappingRequest) (*resourcemapping.UpdateResourceMappingResponse, error) + DeleteResourceMapping(ctx context.Context, req *resourcemapping.DeleteResourceMappingRequest) (*resourcemapping.DeleteResourceMappingResponse, error) +} + +func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappingGroups(ctx context.Context, req *resourcemapping.ListResourceMappingGroupsRequest) (*resourcemapping.ListResourceMappingGroupsResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.ListResourceMappingGroups(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) GetResourceMappingGroup(ctx context.Context, req *resourcemapping.GetResourceMappingGroupRequest) (*resourcemapping.GetResourceMappingGroupResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.GetResourceMappingGroup(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) CreateResourceMappingGroup(ctx context.Context, req *resourcemapping.CreateResourceMappingGroupRequest) (*resourcemapping.CreateResourceMappingGroupResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.CreateResourceMappingGroup(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) UpdateResourceMappingGroup(ctx context.Context, req *resourcemapping.UpdateResourceMappingGroupRequest) (*resourcemapping.UpdateResourceMappingGroupResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.UpdateResourceMappingGroup(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) DeleteResourceMappingGroup(ctx context.Context, req *resourcemapping.DeleteResourceMappingGroupRequest) (*resourcemapping.DeleteResourceMappingGroupResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.DeleteResourceMappingGroup(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappings(ctx context.Context, req *resourcemapping.ListResourceMappingsRequest) (*resourcemapping.ListResourceMappingsResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.ListResourceMappings(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappingsByGroupFqns(ctx context.Context, req *resourcemapping.ListResourceMappingsByGroupFqnsRequest) (*resourcemapping.ListResourceMappingsByGroupFqnsResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.ListResourceMappingsByGroupFqns(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) GetResourceMapping(ctx context.Context, req *resourcemapping.GetResourceMappingRequest) (*resourcemapping.GetResourceMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.GetResourceMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) CreateResourceMapping(ctx context.Context, req *resourcemapping.CreateResourceMappingRequest) (*resourcemapping.CreateResourceMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.CreateResourceMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) UpdateResourceMapping(ctx context.Context, req *resourcemapping.UpdateResourceMappingRequest) (*resourcemapping.UpdateResourceMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.UpdateResourceMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) DeleteResourceMapping(ctx context.Context, req *resourcemapping.DeleteResourceMappingRequest) (*resourcemapping.DeleteResourceMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.DeleteResourceMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/subjectmapping.go b/sdk/sdkconnect/subjectmapping.go new file mode 100644 index 0000000000..90640a1d72 --- /dev/null +++ b/sdk/sdkconnect/subjectmapping.go @@ -0,0 +1,140 @@ +// Wrapper for SubjectMappingServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/subjectmapping" + "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" +) + +type SubjectMappingServiceClientConnectWrapper struct { + subjectmappingconnect.SubjectMappingServiceClient +} + +func NewSubjectMappingServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *SubjectMappingServiceClientConnectWrapper { + return &SubjectMappingServiceClientConnectWrapper{SubjectMappingServiceClient: subjectmappingconnect.NewSubjectMappingServiceClient(httpClient, baseURL, opts...)} +} + +type SubjectMappingServiceClient interface { + MatchSubjectMappings(ctx context.Context, req *subjectmapping.MatchSubjectMappingsRequest) (*subjectmapping.MatchSubjectMappingsResponse, error) + ListSubjectMappings(ctx context.Context, req *subjectmapping.ListSubjectMappingsRequest) (*subjectmapping.ListSubjectMappingsResponse, error) + GetSubjectMapping(ctx context.Context, req *subjectmapping.GetSubjectMappingRequest) (*subjectmapping.GetSubjectMappingResponse, error) + CreateSubjectMapping(ctx context.Context, req *subjectmapping.CreateSubjectMappingRequest) (*subjectmapping.CreateSubjectMappingResponse, error) + UpdateSubjectMapping(ctx context.Context, req *subjectmapping.UpdateSubjectMappingRequest) (*subjectmapping.UpdateSubjectMappingResponse, error) + DeleteSubjectMapping(ctx context.Context, req *subjectmapping.DeleteSubjectMappingRequest) (*subjectmapping.DeleteSubjectMappingResponse, error) + ListSubjectConditionSets(ctx context.Context, req *subjectmapping.ListSubjectConditionSetsRequest) (*subjectmapping.ListSubjectConditionSetsResponse, error) + GetSubjectConditionSet(ctx context.Context, req *subjectmapping.GetSubjectConditionSetRequest) (*subjectmapping.GetSubjectConditionSetResponse, error) + CreateSubjectConditionSet(ctx context.Context, req *subjectmapping.CreateSubjectConditionSetRequest) (*subjectmapping.CreateSubjectConditionSetResponse, error) + UpdateSubjectConditionSet(ctx context.Context, req *subjectmapping.UpdateSubjectConditionSetRequest) (*subjectmapping.UpdateSubjectConditionSetResponse, error) + DeleteSubjectConditionSet(ctx context.Context, req *subjectmapping.DeleteSubjectConditionSetRequest) (*subjectmapping.DeleteSubjectConditionSetResponse, error) + DeleteAllUnmappedSubjectConditionSets(ctx context.Context, req *subjectmapping.DeleteAllUnmappedSubjectConditionSetsRequest) (*subjectmapping.DeleteAllUnmappedSubjectConditionSetsResponse, error) +} + +func (w *SubjectMappingServiceClientConnectWrapper) MatchSubjectMappings(ctx context.Context, req *subjectmapping.MatchSubjectMappingsRequest) (*subjectmapping.MatchSubjectMappingsResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.MatchSubjectMappings(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) ListSubjectMappings(ctx context.Context, req *subjectmapping.ListSubjectMappingsRequest) (*subjectmapping.ListSubjectMappingsResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.ListSubjectMappings(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) GetSubjectMapping(ctx context.Context, req *subjectmapping.GetSubjectMappingRequest) (*subjectmapping.GetSubjectMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.GetSubjectMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) CreateSubjectMapping(ctx context.Context, req *subjectmapping.CreateSubjectMappingRequest) (*subjectmapping.CreateSubjectMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.CreateSubjectMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) UpdateSubjectMapping(ctx context.Context, req *subjectmapping.UpdateSubjectMappingRequest) (*subjectmapping.UpdateSubjectMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.UpdateSubjectMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) DeleteSubjectMapping(ctx context.Context, req *subjectmapping.DeleteSubjectMappingRequest) (*subjectmapping.DeleteSubjectMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.DeleteSubjectMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) ListSubjectConditionSets(ctx context.Context, req *subjectmapping.ListSubjectConditionSetsRequest) (*subjectmapping.ListSubjectConditionSetsResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.ListSubjectConditionSets(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) GetSubjectConditionSet(ctx context.Context, req *subjectmapping.GetSubjectConditionSetRequest) (*subjectmapping.GetSubjectConditionSetResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.GetSubjectConditionSet(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) CreateSubjectConditionSet(ctx context.Context, req *subjectmapping.CreateSubjectConditionSetRequest) (*subjectmapping.CreateSubjectConditionSetResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.CreateSubjectConditionSet(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) UpdateSubjectConditionSet(ctx context.Context, req *subjectmapping.UpdateSubjectConditionSetRequest) (*subjectmapping.UpdateSubjectConditionSetResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.UpdateSubjectConditionSet(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) DeleteSubjectConditionSet(ctx context.Context, req *subjectmapping.DeleteSubjectConditionSetRequest) (*subjectmapping.DeleteSubjectConditionSetResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.DeleteSubjectConditionSet(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) DeleteAllUnmappedSubjectConditionSets(ctx context.Context, req *subjectmapping.DeleteAllUnmappedSubjectConditionSetsRequest) (*subjectmapping.DeleteAllUnmappedSubjectConditionSetsResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.DeleteAllUnmappedSubjectConditionSets(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/unsafe.go b/sdk/sdkconnect/unsafe.go new file mode 100644 index 0000000000..b01b4344d5 --- /dev/null +++ b/sdk/sdkconnect/unsafe.go @@ -0,0 +1,120 @@ +// Wrapper for UnsafeServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/unsafe" + "github.com/opentdf/platform/protocol/go/policy/unsafe/unsafeconnect" +) + +type UnsafeServiceClientConnectWrapper struct { + unsafeconnect.UnsafeServiceClient +} + +func NewUnsafeServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *UnsafeServiceClientConnectWrapper { + return &UnsafeServiceClientConnectWrapper{UnsafeServiceClient: unsafeconnect.NewUnsafeServiceClient(httpClient, baseURL, opts...)} +} + +type UnsafeServiceClient interface { + UnsafeUpdateNamespace(ctx context.Context, req *unsafe.UnsafeUpdateNamespaceRequest) (*unsafe.UnsafeUpdateNamespaceResponse, error) + UnsafeReactivateNamespace(ctx context.Context, req *unsafe.UnsafeReactivateNamespaceRequest) (*unsafe.UnsafeReactivateNamespaceResponse, error) + UnsafeDeleteNamespace(ctx context.Context, req *unsafe.UnsafeDeleteNamespaceRequest) (*unsafe.UnsafeDeleteNamespaceResponse, error) + UnsafeUpdateAttribute(ctx context.Context, req *unsafe.UnsafeUpdateAttributeRequest) (*unsafe.UnsafeUpdateAttributeResponse, error) + UnsafeReactivateAttribute(ctx context.Context, req *unsafe.UnsafeReactivateAttributeRequest) (*unsafe.UnsafeReactivateAttributeResponse, error) + UnsafeDeleteAttribute(ctx context.Context, req *unsafe.UnsafeDeleteAttributeRequest) (*unsafe.UnsafeDeleteAttributeResponse, error) + UnsafeUpdateAttributeValue(ctx context.Context, req *unsafe.UnsafeUpdateAttributeValueRequest) (*unsafe.UnsafeUpdateAttributeValueResponse, error) + UnsafeReactivateAttributeValue(ctx context.Context, req *unsafe.UnsafeReactivateAttributeValueRequest) (*unsafe.UnsafeReactivateAttributeValueResponse, error) + UnsafeDeleteAttributeValue(ctx context.Context, req *unsafe.UnsafeDeleteAttributeValueRequest) (*unsafe.UnsafeDeleteAttributeValueResponse, error) + UnsafeDeleteKasKey(ctx context.Context, req *unsafe.UnsafeDeleteKasKeyRequest) (*unsafe.UnsafeDeleteKasKeyResponse, error) +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateNamespace(ctx context.Context, req *unsafe.UnsafeUpdateNamespaceRequest) (*unsafe.UnsafeUpdateNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeUpdateNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateNamespace(ctx context.Context, req *unsafe.UnsafeReactivateNamespaceRequest) (*unsafe.UnsafeReactivateNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeReactivateNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteNamespace(ctx context.Context, req *unsafe.UnsafeDeleteNamespaceRequest) (*unsafe.UnsafeDeleteNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeDeleteNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateAttribute(ctx context.Context, req *unsafe.UnsafeUpdateAttributeRequest) (*unsafe.UnsafeUpdateAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeUpdateAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateAttribute(ctx context.Context, req *unsafe.UnsafeReactivateAttributeRequest) (*unsafe.UnsafeReactivateAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeReactivateAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteAttribute(ctx context.Context, req *unsafe.UnsafeDeleteAttributeRequest) (*unsafe.UnsafeDeleteAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeDeleteAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateAttributeValue(ctx context.Context, req *unsafe.UnsafeUpdateAttributeValueRequest) (*unsafe.UnsafeUpdateAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeUpdateAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateAttributeValue(ctx context.Context, req *unsafe.UnsafeReactivateAttributeValueRequest) (*unsafe.UnsafeReactivateAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeReactivateAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteAttributeValue(ctx context.Context, req *unsafe.UnsafeDeleteAttributeValueRequest) (*unsafe.UnsafeDeleteAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeDeleteAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteKasKey(ctx context.Context, req *unsafe.UnsafeDeleteKasKeyRequest) (*unsafe.UnsafeDeleteKasKeyResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeDeleteKasKey(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/wellknownconfiguration.go b/sdk/sdkconnect/wellknownconfiguration.go new file mode 100644 index 0000000000..e635d1e6e0 --- /dev/null +++ b/sdk/sdkconnect/wellknownconfiguration.go @@ -0,0 +1,30 @@ +// Wrapper for WellKnownServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/wellknownconfiguration" + "github.com/opentdf/platform/protocol/go/wellknownconfiguration/wellknownconfigurationconnect" +) + +type WellKnownServiceClientConnectWrapper struct { + wellknownconfigurationconnect.WellKnownServiceClient +} + +func NewWellKnownServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *WellKnownServiceClientConnectWrapper { + return &WellKnownServiceClientConnectWrapper{WellKnownServiceClient: wellknownconfigurationconnect.NewWellKnownServiceClient(httpClient, baseURL, opts...)} +} + +type WellKnownServiceClient interface { + GetWellKnownConfiguration(ctx context.Context, req *wellknownconfiguration.GetWellKnownConfigurationRequest) (*wellknownconfiguration.GetWellKnownConfigurationResponse, error) +} + +func (w *WellKnownServiceClientConnectWrapper) GetWellKnownConfiguration(ctx context.Context, req *wellknownconfiguration.GetWellKnownConfigurationRequest) (*wellknownconfiguration.GetWellKnownConfigurationResponse, error) { + // Wrap Connect RPC client request + res, err := w.WellKnownServiceClient.GetWellKnownConfiguration(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/tdf.go b/sdk/tdf.go index 8aa8e8cbe0..890ca551a1 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -11,9 +11,11 @@ import ( "io" "log/slog" "math" + "net/http" "strconv" "strings" + "connectrpc.com/connect" "github.com/Masterminds/semver/v3" "github.com/opentdf/platform/protocol/go/kas" "github.com/opentdf/platform/protocol/go/policy/kasregistry" @@ -22,7 +24,7 @@ import ( "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/sdk/internal/archive" - "google.golang.org/grpc" + "github.com/opentdf/platform/sdk/sdkconnect" "google.golang.org/grpc/codes" ) @@ -63,7 +65,8 @@ const ( // Loads and reads ZTDF files type Reader struct { tokenSource auth.AccessTokenSource - dialOptions []grpc.DialOption + httpClient *http.Client + connectOptions []connect.ClientOption manifest Manifest unencryptedMetadata []byte tdfReader archive.TDFReader @@ -654,7 +657,7 @@ func createPolicyObject(attributes []AttributeValueFQN) (PolicyObject, error) { return policyObj, nil } -func allowListFromKASRegistry(ctx context.Context, kasRegistryClient kasregistry.KeyAccessServerRegistryServiceClient, platformURL string) (AllowList, error) { +func allowListFromKASRegistry(ctx context.Context, kasRegistryClient sdkconnect.KeyAccessServerRegistryServiceClient, platformURL string) (AllowList, error) { kases, err := kasRegistryClient.ListKeyAccessServers(ctx, &kasregistry.ListKeyAccessServersRequest{}) if err != nil { return nil, fmt.Errorf("kasregistry.ListKeyAccessServers failed: %w", err) @@ -732,12 +735,13 @@ func (s SDK) LoadTDF(reader io.ReadSeeker, opts ...TDFReaderOption) (*Reader, er } return &Reader{ - tokenSource: s.tokenSource, - dialOptions: s.dialOptions, - tdfReader: tdfReader, - manifest: *manifestObj, - kasSessionKey: config.kasSessionKey, - config: *config, + tokenSource: s.tokenSource, + httpClient: s.conn.Client, + connectOptions: s.conn.Options, + tdfReader: tdfReader, + manifest: *manifestObj, + kasSessionKey: config.kasSessionKey, + config: *config, }, nil } @@ -1248,7 +1252,7 @@ func (r *Reader) buildKey(_ context.Context, results []kaoResult) error { // Unwraps the payload key, if possible, using the access service func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocognit // Better readability keeping it as is - kasClient := newKASClient(r.dialOptions, r.tokenSource, r.kasSessionKey) + kasClient := newKASClient(r.httpClient, r.connectOptions, r.tokenSource, r.kasSessionKey) var kaoResults []kaoResult reqFail := func(err error, req *kas.UnsignedRewrapRequest_WithPolicyRequest) { diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index d83e69cf72..333bd0f9d5 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -16,8 +16,8 @@ import ( "fmt" "io" "log/slog" - "net" - "net/url" + "net/http" + "net/http/httptest" "os" "path/filepath" "strconv" @@ -25,20 +25,23 @@ import ( "testing" "time" + "connectrpc.com/connect" "google.golang.org/protobuf/encoding/protojson" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/lib/ocrypto" kaspb "github.com/opentdf/platform/protocol/go/kas" + "github.com/opentdf/platform/protocol/go/kas/kasconnect" "github.com/opentdf/platform/protocol/go/policy" attributespb "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" "github.com/opentdf/platform/protocol/go/policy/kasregistry" + "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" wellknownpb "github.com/opentdf/platform/protocol/go/wellknownconfiguration" - "google.golang.org/grpc" + wellknownconnect "github.com/opentdf/platform/protocol/go/wellknownconfiguration/wellknownconfigurationconnect" "google.golang.org/grpc/codes" "google.golang.org/grpc/resolver" "google.golang.org/grpc/status" - "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/types/known/structpb" "github.com/stretchr/testify/assert" @@ -284,8 +287,9 @@ type keyInfo struct { type TDFSuite struct { suite.Suite - sdk *SDK - kases []FakeKas + sdk *SDK + kases []FakeKas + kasTestURLLookup map[string]string } func (s *TDFSuite) SetupSuite() { @@ -324,31 +328,35 @@ func (s *TDFSuite) Test_SimpleTDF() { { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), WithDataAttributes(attributes...), }, - tdfReadOptions: []TDFReaderOption{}, + tdfReadOptions: []TDFReaderOption{ + WithKasAllowlist([]string{s.kasTestURLLookup["https://a.kas/"]}), + }, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), WithDataAttributes(attributes...), WithTargetMode("0.0.0"), }, - tdfReadOptions: []TDFReaderOption{}, - useHex: true, + tdfReadOptions: []TDFReaderOption{ + WithKasAllowlist([]string{s.kasTestURLLookup["https://a.kas/"]}), + }, + useHex: true, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://d.kas/", + URL: s.kasTestURLLookup["https://d.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -357,12 +365,13 @@ func (s *TDFSuite) Test_SimpleTDF() { }, tdfReadOptions: []TDFReaderOption{ WithSessionKeyType(ocrypto.EC256Key), + WithKasAllowlist([]string{s.kasTestURLLookup["https://d.kas/"]}), }, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://d.kas/", + URL: s.kasTestURLLookup["https://d.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -372,6 +381,7 @@ func (s *TDFSuite) Test_SimpleTDF() { }, tdfReadOptions: []TDFReaderOption{ WithSessionKeyType(ocrypto.EC256Key), + WithKasAllowlist([]string{s.kasTestURLLookup["https://d.kas/"]}), }, useHex: true, }, @@ -491,20 +501,20 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), WithDataAttributes(attributes...), }, tdfReadOptions: []TDFReaderOption{ - WithKasAllowlist([]string{"https://a.kas/"}), + WithKasAllowlist([]string{s.kasTestURLLookup["https://a.kas/"]}), }, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -513,12 +523,12 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { tdfReadOptions: []TDFReaderOption{ WithKasAllowlist([]string{"https://nope-not-a-kas.com/kas"}), }, - expectedError: "KasAllowlist: kas url https://a.kas/ is not allowed", + expectedError: "KasAllowlist: kas url " + s.kasTestURLLookup["https://a.kas/"] + " is not allowed", }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -527,12 +537,12 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { tdfReadOptions: []TDFReaderOption{ withKasAllowlist(AllowList{"nope-not-a-kas.com": true}), }, - expectedError: "KasAllowlist: kas url https://a.kas/ is not allowed", + expectedError: "KasAllowlist: kas url " + s.kasTestURLLookup["https://a.kas/"] + " is not allowed", }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -546,7 +556,7 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -822,7 +832,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() { { kasURLs := []KASInfo{ { - URL: "https://a.kas/", + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }, } @@ -866,11 +876,12 @@ func (s *TDFSuite) Test_TDFWithAssertion() { var r *Reader if test.verifiers == nil { - r, err = s.sdk.LoadTDF(readSeeker, WithDisableAssertionVerification(test.disableAssertionVerification)) + r, err = s.sdk.LoadTDF(readSeeker, WithDisableAssertionVerification(test.disableAssertionVerification), WithKasAllowlist([]string{s.kasTestURLLookup["https://a.kas/"]})) } else { r, err = s.sdk.LoadTDF(readSeeker, WithAssertionVerificationKeys(*test.verifiers), - WithDisableAssertionVerification(test.disableAssertionVerification)) + WithDisableAssertionVerification(test.disableAssertionVerification), + WithKasAllowlist([]string{s.kasTestURLLookup["https://a.kas/"]})) } s.Require().NoError(err) @@ -1140,7 +1151,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() { { kasURLs := []KASInfo{ { - URL: "https://a.kas/", + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }, } @@ -1199,11 +1210,11 @@ func (s *TDFSuite) Test_TDFReader() { //nolint:gocognit // requires for testing payload: payload, // len: 62 kasInfoList: []KASInfo{ { - URL: "http://localhost:65432/api/kas", + URL: s.kasTestURLLookup["http://localhost:65432/"], PublicKey: mockRSAPublicKey1, }, { - URL: "http://localhost:65432/api/kas", + URL: s.kasTestURLLookup["http://localhost:65432/"], PublicKey: mockRSAPublicKey1, }, }, @@ -1299,11 +1310,11 @@ func (s *TDFSuite) Test_TDFReader() { //nolint:gocognit // requires for testing func (s *TDFSuite) Test_TDFReaderFail() { kasInfoList := []KASInfo{ { - URL: "http://localhost:65432/api/kas", + URL: s.kasTestURLLookup["http://localhost:65432/api/kas"], PublicKey: mockRSAPublicKey1, }, { - URL: "http://localhost:65432/api/kas", + URL: s.kasTestURLLookup["http://localhost:65432/api/kas"], PublicKey: mockRSAPublicKey1, }, } @@ -1632,9 +1643,9 @@ func (s *TDFSuite) Test_KeySplits() { tdfFileSize: 2759, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ - {KAS: "https://a.kas/", SplitID: "a"}, - {KAS: "https://b.kas/", SplitID: "a"}, - {KAS: `https://c.kas/`, SplitID: "a"}, + {KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"}, + {KAS: s.kasTestURLLookup["https://b.kas/"], SplitID: "a"}, + {KAS: s.kasTestURLLookup[`https://c.kas/`], SplitID: "a"}, }, }, { @@ -1643,9 +1654,9 @@ func (s *TDFSuite) Test_KeySplits() { tdfFileSize: 2759, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ - {KAS: "https://a.kas/", SplitID: "a"}, - {KAS: "https://b.kas/", SplitID: "b"}, - {KAS: "https://c.kas/", SplitID: "c"}, + {KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"}, + {KAS: s.kasTestURLLookup["https://b.kas/"], SplitID: "b"}, + {KAS: s.kasTestURLLookup["https://c.kas/"], SplitID: "c"}, }, }, { @@ -1654,10 +1665,10 @@ func (s *TDFSuite) Test_KeySplits() { tdfFileSize: 3351, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ - {KAS: "https://a.kas/", SplitID: "a"}, - {KAS: "https://b.kas/", SplitID: "a"}, - {KAS: "https://b.kas/", SplitID: "b"}, - {KAS: "https://c.kas/", SplitID: "b"}, + {KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"}, + {KAS: s.kasTestURLLookup["https://b.kas/"], SplitID: "a"}, + {KAS: s.kasTestURLLookup["https://b.kas/"], SplitID: "b"}, + {KAS: s.kasTestURLLookup["https://c.kas/"], SplitID: "b"}, }, }, } { @@ -1881,7 +1892,7 @@ func (s *TDFSuite) startBackend() { } fwk := &FakeWellKnown{v: wellknownCfg} - fa := &FakeAttributes{} + fa := &FakeAttributes{s: s} kasesToMake := []struct { url, private, public string }{ @@ -1898,137 +1909,129 @@ func (s *TDFSuite) startBackend() { {kasNz, mockRSAPrivateKey3, mockRSAPublicKey3}, {kasUs, mockRSAPrivateKey1, mockRSAPublicKey1}, } - fkar := &FakeKASRegistry{kases: kasesToMake} - - listeners := make(map[string]*bufconn.Listener) - dialer := func(ctx context.Context, host string) (net.Conn, error) { - l, ok := listeners[host] - if !ok { - slog.ErrorContext(ctx, "bufconn: unable to dial host!", "ctx", ctx, "host", host) - return nil, fmt.Errorf("unknown host [%s]", host) - } - slog.InfoContext(ctx, "bufconn: dialing (local grpc)", "ctx", ctx, "host", host) - return l.Dial() - } + fkar := &FakeKASRegistry{kases: kasesToMake, s: s} s.kases = make([]FakeKas, 12) + s.kasTestURLLookup = make(map[string]string, 12) + + var sdkPlatformURL string + for i, ki := range kasesToMake { - grpcListener := bufconn.Listen(1024 * 1024) - url, err := url.Parse(ki.url) - s.Require().NoError(err) - var origin string - switch { - case url.Port() == "80": - origin = url.Host - case url.Scheme == "https": - origin = url.Host + ":443" - case url.Port() != "": - origin = url.Host - default: - origin = url.Hostname() - } - listeners[origin] = grpcListener + mux := http.NewServeMux() - grpcServer := grpc.NewServer() s.kases[i] = FakeKas{ s: s, privateKey: ki.private, KASInfo: KASInfo{ URL: ki.url, PublicKey: ki.public, KID: "r1", Algorithm: "rsa:2048", }, legakeys: map[string]keyInfo{}, } - attributespb.RegisterAttributesServiceServer(grpcServer, fa) - kaspb.RegisterAccessServiceServer(grpcServer, &s.kases[i]) - wellknownpb.RegisterWellKnownServiceServer(grpcServer, fwk) - kasregistry.RegisterKeyAccessServerRegistryServiceServer(grpcServer, fkar) - - go func() { - err := grpcServer.Serve(grpcListener) - s.NoError(err) - }() + path, handler := attributesconnect.NewAttributesServiceHandler(fa) + mux.Handle(path, handler) + kasPath, kasHandler := kasconnect.NewAccessServiceHandler(&s.kases[i]) + mux.Handle(kasPath, kasHandler) + path, handler = wellknownconnect.NewWellKnownServiceHandler(fwk) + mux.Handle(path, handler) + path, handler = kasregistryconnect.NewKeyAccessServerRegistryServiceHandler(fkar) + mux.Handle(path, handler) + + server := httptest.NewServer(mux) + + // add to lookup reg + s.kasTestURLLookup[s.kases[i].KASInfo.URL] = server.URL + // replace kasinfo url with httptest server url + s.kases[i].KASInfo.URL = server.URL + + if i == 0 { + sdkPlatformURL = server.URL + } } ats := getTokenSource(s.T()) - sdk, err := New("localhost:65432", + sdk, err := New(sdkPlatformURL, WithClientCredentials("test", "test", nil), withCustomAccessTokenSource(&ats), WithTokenEndpoint("http://localhost:65432/auth/token"), - WithInsecurePlaintextConn(), - WithExtraDialOptions(grpc.WithContextDialer(dialer))) + WithInsecurePlaintextConn()) s.Require().NoError(err) s.sdk = sdk } type FakeWellKnown struct { - wellknownpb.UnimplementedWellKnownServiceServer + wellknownconnect.UnimplementedWellKnownServiceHandler v map[string]interface{} } -func (f *FakeWellKnown) GetWellKnownConfiguration(_ context.Context, _ *wellknownpb.GetWellKnownConfigurationRequest) (*wellknownpb.GetWellKnownConfigurationResponse, error) { +func (f *FakeWellKnown) GetWellKnownConfiguration(_ context.Context, _ *connect.Request[wellknownpb.GetWellKnownConfigurationRequest]) (*connect.Response[wellknownpb.GetWellKnownConfigurationResponse], error) { cfg, err := structpb.NewStruct(f.v) if err != nil { return nil, err } - return &wellknownpb.GetWellKnownConfigurationResponse{ + return connect.NewResponse(&wellknownpb.GetWellKnownConfigurationResponse{ Configuration: cfg, - }, nil + }), nil } type FakeAttributes struct { - attributespb.UnimplementedAttributesServiceServer + attributesconnect.UnimplementedAttributesServiceHandler + s *TDFSuite } -func (f *FakeAttributes) GetAttributeValuesByFqns(_ context.Context, in *attributespb.GetAttributeValuesByFqnsRequest) (*attributespb.GetAttributeValuesByFqnsResponse, error) { +func (f *FakeAttributes) GetAttributeValuesByFqns(_ context.Context, in *connect.Request[attributespb.GetAttributeValuesByFqnsRequest]) (*connect.Response[attributespb.GetAttributeValuesByFqnsResponse], error) { r := make(map[string]*attributespb.GetAttributeValuesByFqnsResponse_AttributeAndValue) - for _, fqn := range in.GetFqns() { + for _, fqn := range in.Msg.GetFqns() { av, err := NewAttributeValueFQN(fqn) if err != nil { slog.Error("invalid fqn", "notfqn", fqn, "error", err) return nil, status.New(codes.InvalidArgument, fmt.Sprintf("invalid attribute fqn [%s]", fqn)).Err() } v := mockValueFor(av) + for i := range v.GetGrants() { + v.Grants[i].Uri = f.s.kasTestURLLookup[v.GetGrants()[i].GetUri()] + } r[fqn] = &attributespb.GetAttributeValuesByFqnsResponse_AttributeAndValue{ Attribute: v.GetAttribute(), Value: v, } } - return &attributespb.GetAttributeValuesByFqnsResponse{FqnAttributeValues: r}, nil + return connect.NewResponse(&attributespb.GetAttributeValuesByFqnsResponse{FqnAttributeValues: r}), nil } type FakeKASRegistry struct { - kasregistry.UnimplementedKeyAccessServerRegistryServiceServer + kasregistryconnect.UnimplementedKeyAccessServerRegistryServiceHandler + s *TDFSuite kases []struct { url, private, public string } } -func (f *FakeKASRegistry) ListKeyAccessServers(_ context.Context, _ *kasregistry.ListKeyAccessServersRequest) (*kasregistry.ListKeyAccessServersResponse, error) { +func (f *FakeKASRegistry) ListKeyAccessServers(_ context.Context, _ *connect.Request[kasregistry.ListKeyAccessServersRequest]) (*connect.Response[kasregistry.ListKeyAccessServersResponse], error) { resp := &kasregistry.ListKeyAccessServersResponse{ KeyAccessServers: make([]*policy.KeyAccessServer, 0, len(f.kases)), } for _, k := range f.kases { kas := &policy.KeyAccessServer{ - Uri: k.url, + Uri: f.s.kasTestURLLookup[k.url], } resp.KeyAccessServers = append(resp.KeyAccessServers, kas) } - return resp, nil + return connect.NewResponse(resp), nil } type FakeKas struct { - kaspb.UnimplementedAccessServiceServer + kasconnect.UnimplementedAccessServiceHandler KASInfo privateKey string s *TDFSuite legakeys map[string]keyInfo } -func (f *FakeKas) Rewrap(_ context.Context, in *kaspb.RewrapRequest) (*kaspb.RewrapResponse, error) { - signedRequestToken := in.GetSignedRequestToken() +func (f *FakeKas) Rewrap(_ context.Context, in *connect.Request[kaspb.RewrapRequest]) (*connect.Response[kaspb.RewrapResponse], error) { + signedRequestToken := in.Msg.GetSignedRequestToken() token, err := jwt.ParseInsecure([]byte(signedRequestToken)) if err != nil { @@ -2046,11 +2049,11 @@ func (f *FakeKas) Rewrap(_ context.Context, in *kaspb.RewrapRequest) (*kaspb.Rew } result := f.getRewrapResponse(requestBodyStr) - return result, nil + return connect.NewResponse(result), nil } -func (f *FakeKas) PublicKey(_ context.Context, _ *kaspb.PublicKeyRequest) (*kaspb.PublicKeyResponse, error) { - return &kaspb.PublicKeyResponse{PublicKey: f.KASInfo.PublicKey, Kid: f.KID}, nil +func (f *FakeKas) PublicKey(_ context.Context, _ *connect.Request[kaspb.PublicKeyRequest]) (*connect.Response[kaspb.PublicKeyResponse], error) { + return connect.NewResponse(&kaspb.PublicKeyResponse{PublicKey: f.KASInfo.PublicKey, Kid: f.KID}), nil } func (f *FakeKas) getRewrapResponse(rewrapRequest string) *kaspb.RewrapResponse { diff --git a/service/authorization/authorization_test.go b/service/authorization/authorization_test.go index 5b69bdc72d..082448b0ae 100644 --- a/service/authorization/authorization_test.go +++ b/service/authorization/authorization_test.go @@ -1,9 +1,7 @@ package authorization import ( - "context" "errors" - "fmt" "log/slog" "strings" "testing" @@ -16,118 +14,16 @@ import ( "github.com/opentdf/platform/protocol/go/entityresolution" "github.com/opentdf/platform/protocol/go/policy" attr "github.com/opentdf/platform/protocol/go/policy/attributes" - sm "github.com/opentdf/platform/protocol/go/policy/subjectmapping" otdf "github.com/opentdf/platform/sdk" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/pkg/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/structpb" ) -var ( - getAttributesByValueFqnsResponse attr.GetAttributeValuesByFqnsResponse - errGetAttributesByValueFqns error - listAttributeResp attr.ListAttributesResponse - errListAttributes error - listSubjectMappings sm.ListSubjectMappingsResponse - createEntityChainResp entityresolution.CreateEntityChainFromJwtResponse - resolveEntitiesResp entityresolution.ResolveEntitiesResponse - mockNamespace = "www.example.org" - mockAttrName = "foo" - mockAttrValue1 = "value1" - mockAttrValue2 = "value2" - mockAttrValue3 = "value3" - mockFqn1 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue1) - mockFqn2 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue2) - mockFqn3 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue3) -) - -type myAttributesClient struct { - attr.AttributesServiceClient -} - -func (*myAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest, _ ...grpc.CallOption) (*attr.ListAttributesResponse, error) { - return &listAttributeResp, errListAttributes -} - -func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.GetAttributeValuesByFqnsRequest, _ ...grpc.CallOption) (*attr.GetAttributeValuesByFqnsResponse, error) { - return &getAttributesByValueFqnsResponse, errGetAttributesByValueFqns -} - -type myERSClient struct { - entityresolution.EntityResolutionServiceClient -} - -type mySubjectMappingClient struct { - sm.SubjectMappingServiceClient -} - -type paginatedMockSubjectMappingClient struct { - sm.SubjectMappingServiceClient -} - -func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*sm.ListSubjectMappingsResponse, error) { - return &listSubjectMappings, nil -} - -func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *entityresolution.CreateEntityChainFromJwtRequest, _ ...grpc.CallOption) (*entityresolution.CreateEntityChainFromJwtResponse, error) { - return &createEntityChainResp, nil -} - -func (*myERSClient) ResolveEntities(_ context.Context, _ *entityresolution.ResolveEntitiesRequest, _ ...grpc.CallOption) (*entityresolution.ResolveEntitiesResponse, error) { - return &resolveEntitiesResp, nil -} - -var ( - smPaginationOffset = 3 - smListCallCount = 0 -) - -func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*sm.ListSubjectMappingsResponse, error) { - smListCallCount++ - // simulate paginated list and policy LIST behavior - if smPaginationOffset > 0 { - rsp := &sm.ListSubjectMappingsResponse{ - SubjectMappings: nil, - Pagination: &policy.PageResponse{ - NextOffset: int32(smPaginationOffset), - }, - } - smPaginationOffset = 0 - return rsp, nil - } - return &listSubjectMappings, nil -} - -type paginatedMockAttributesClient struct { - attr.AttributesServiceClient -} - -var ( - attrPaginationOffset = 3 - attrListCallCount = 0 -) - -func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest, _ ...grpc.CallOption) (*attr.ListAttributesResponse, error) { - attrListCallCount++ - // simulate paginated list and policy LIST behavior - if attrPaginationOffset > 0 { - rsp := &attr.ListAttributesResponse{ - Attributes: nil, - Pagination: &policy.PageResponse{ - NextOffset: int32(attrPaginationOffset), - }, - } - attrPaginationOffset = 0 - return rsp, nil - } - return &listAttributeResp, nil -} - func TestGetComprehensiveHierarchy(t *testing.T) { as := &AuthorizationService{ logger: logger.CreateTestLogger(), @@ -448,9 +344,11 @@ func Test_GetDecisions_AllOf_Fail(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -546,9 +444,11 @@ func Test_GetDecisionsAllOfWithEnvironmental_Pass(t *testing.T) { } as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -641,9 +541,11 @@ func Test_GetDecisionsAllOfWithEnvironmental_Fail(t *testing.T) { } as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -714,9 +616,11 @@ func Test_GetEntitlementsSimple(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -786,9 +690,11 @@ func Test_GetEntitlementsFqnCasing(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -863,7 +769,8 @@ func Test_GetEntitlements_HandlesPagination(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ + logger: logger, + sdk: &otdf.SDK{ SubjectMapping: &paginatedMockSubjectMappingClient{}, Attributes: &paginatedMockAttributesClient{}, EntityResoution: &myERSClient{}, @@ -954,9 +861,11 @@ func Test_GetEntitlementsWithComprehensiveHierarchy(t *testing.T) { prepared, err := rego.PrepareForEval(t.Context()) require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -1194,9 +1103,11 @@ func Test_GetDecisions_RA_FQN_Edge_Cases(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -1400,9 +1311,11 @@ func Test_GetDecisionsAllOf_Pass_EC_RA_Length_Mismatch(t *testing.T) { } as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -1677,9 +1590,11 @@ func Test_GetDecisions_Empty_EC_RA(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), diff --git a/service/authorization/authorization_test_structures.go b/service/authorization/authorization_test_structures.go new file mode 100644 index 0000000000..e9f4d764f5 --- /dev/null +++ b/service/authorization/authorization_test_structures.go @@ -0,0 +1,334 @@ +package authorization + +import ( + "context" + "fmt" + + "github.com/opentdf/platform/protocol/go/entityresolution" + "github.com/opentdf/platform/protocol/go/policy" + attr "github.com/opentdf/platform/protocol/go/policy/attributes" + sm "github.com/opentdf/platform/protocol/go/policy/subjectmapping" +) + +var ( + getAttributesByValueFqnsResponse attr.GetAttributeValuesByFqnsResponse + errGetAttributesByValueFqns error + listAttributeResp attr.ListAttributesResponse + errListAttributes error + listSubjectMappings sm.ListSubjectMappingsResponse + createEntityChainResp entityresolution.CreateEntityChainFromJwtResponse + resolveEntitiesResp entityresolution.ResolveEntitiesResponse + mockNamespace = "www.example.org" + mockAttrName = "foo" + mockAttrValue1 = "value1" + mockAttrValue2 = "value2" + mockAttrValue3 = "value3" + mockFqn1 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue1) + mockFqn2 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue2) + mockFqn3 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue3) +) + +// //// Mock attributes client for testing ///// +type myAttributesClient struct{} + +func (*myAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest) (*attr.ListAttributesResponse, error) { + return &listAttributeResp, errListAttributes +} + +func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.GetAttributeValuesByFqnsRequest) (*attr.GetAttributeValuesByFqnsResponse, error) { + return &getAttributesByValueFqnsResponse, errGetAttributesByValueFqns +} + +func (*myAttributesClient) ListAttributeValues(_ context.Context, _ *attr.ListAttributeValuesRequest) (*attr.ListAttributeValuesResponse, error) { + return &attr.ListAttributeValuesResponse{}, nil +} + +func (*myAttributesClient) GetAttribute(_ context.Context, _ *attr.GetAttributeRequest) (*attr.GetAttributeResponse, error) { + return &attr.GetAttributeResponse{}, nil +} + +func (*myAttributesClient) GetAttributeValue(_ context.Context, _ *attr.GetAttributeValueRequest) (*attr.GetAttributeValueResponse, error) { + return &attr.GetAttributeValueResponse{}, nil +} + +func (*myAttributesClient) CreateAttribute(_ context.Context, _ *attr.CreateAttributeRequest) (*attr.CreateAttributeResponse, error) { + return &attr.CreateAttributeResponse{}, nil +} + +func (*myAttributesClient) UpdateAttribute(_ context.Context, _ *attr.UpdateAttributeRequest) (*attr.UpdateAttributeResponse, error) { + return &attr.UpdateAttributeResponse{}, nil +} + +func (*myAttributesClient) DeactivateAttribute(_ context.Context, _ *attr.DeactivateAttributeRequest) (*attr.DeactivateAttributeResponse, error) { + return &attr.DeactivateAttributeResponse{}, nil +} + +func (*myAttributesClient) CreateAttributeValue(_ context.Context, _ *attr.CreateAttributeValueRequest) (*attr.CreateAttributeValueResponse, error) { + return &attr.CreateAttributeValueResponse{}, nil +} + +func (*myAttributesClient) UpdateAttributeValue(_ context.Context, _ *attr.UpdateAttributeValueRequest) (*attr.UpdateAttributeValueResponse, error) { + return &attr.UpdateAttributeValueResponse{}, nil +} + +func (*myAttributesClient) DeactivateAttributeValue(_ context.Context, _ *attr.DeactivateAttributeValueRequest) (*attr.DeactivateAttributeValueResponse, error) { + return &attr.DeactivateAttributeValueResponse{}, nil +} + +func (*myAttributesClient) AssignKeyAccessServerToAttribute(_ context.Context, _ *attr.AssignKeyAccessServerToAttributeRequest) (*attr.AssignKeyAccessServerToAttributeResponse, error) { + return &attr.AssignKeyAccessServerToAttributeResponse{}, nil +} + +func (*myAttributesClient) RemoveKeyAccessServerFromAttribute(_ context.Context, _ *attr.RemoveKeyAccessServerFromAttributeRequest) (*attr.RemoveKeyAccessServerFromAttributeResponse, error) { + return &attr.RemoveKeyAccessServerFromAttributeResponse{}, nil +} + +func (*myAttributesClient) AssignKeyAccessServerToValue(_ context.Context, _ *attr.AssignKeyAccessServerToValueRequest) (*attr.AssignKeyAccessServerToValueResponse, error) { + return &attr.AssignKeyAccessServerToValueResponse{}, nil +} + +func (*myAttributesClient) RemoveKeyAccessServerFromValue(_ context.Context, _ *attr.RemoveKeyAccessServerFromValueRequest) (*attr.RemoveKeyAccessServerFromValueResponse, error) { + return &attr.RemoveKeyAccessServerFromValueResponse{}, nil +} + +func (*myAttributesClient) AssignPublicKeyToAttribute(_ context.Context, _ *attr.AssignPublicKeyToAttributeRequest) (*attr.AssignPublicKeyToAttributeResponse, error) { + return &attr.AssignPublicKeyToAttributeResponse{}, nil +} + +func (*myAttributesClient) RemovePublicKeyFromAttribute(_ context.Context, _ *attr.RemovePublicKeyFromAttributeRequest) (*attr.RemovePublicKeyFromAttributeResponse, error) { + return &attr.RemovePublicKeyFromAttributeResponse{}, nil +} + +func (*myAttributesClient) AssignPublicKeyToValue(_ context.Context, _ *attr.AssignPublicKeyToValueRequest) (*attr.AssignPublicKeyToValueResponse, error) { + return &attr.AssignPublicKeyToValueResponse{}, nil +} + +func (*myAttributesClient) RemovePublicKeyFromValue(_ context.Context, _ *attr.RemovePublicKeyFromValueRequest) (*attr.RemovePublicKeyFromValueResponse, error) { + return &attr.RemovePublicKeyFromValueResponse{}, nil +} + +// // Mock ERS Client for testing ///// +type myERSClient struct{} + +func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) { + return &createEntityChainResp, nil +} + +func (*myERSClient) ResolveEntities(_ context.Context, _ *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) { + return &resolveEntitiesResp, nil +} + +// // Mock Subject Mapping Client for testing ///// +type mySubjectMappingClient struct{} + +func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest) (*sm.ListSubjectMappingsResponse, error) { + return &listSubjectMappings, nil +} + +func (*mySubjectMappingClient) MatchSubjectMappings(_ context.Context, _ *sm.MatchSubjectMappingsRequest) (*sm.MatchSubjectMappingsResponse, error) { + return &sm.MatchSubjectMappingsResponse{}, nil +} + +func (*mySubjectMappingClient) GetSubjectMapping(_ context.Context, _ *sm.GetSubjectMappingRequest) (*sm.GetSubjectMappingResponse, error) { + return &sm.GetSubjectMappingResponse{}, nil +} + +func (*mySubjectMappingClient) CreateSubjectMapping(_ context.Context, _ *sm.CreateSubjectMappingRequest) (*sm.CreateSubjectMappingResponse, error) { + return &sm.CreateSubjectMappingResponse{}, nil +} + +func (*mySubjectMappingClient) UpdateSubjectMapping(_ context.Context, _ *sm.UpdateSubjectMappingRequest) (*sm.UpdateSubjectMappingResponse, error) { + return &sm.UpdateSubjectMappingResponse{}, nil +} + +func (*mySubjectMappingClient) DeleteSubjectMapping(_ context.Context, _ *sm.DeleteSubjectMappingRequest) (*sm.DeleteSubjectMappingResponse, error) { + return &sm.DeleteSubjectMappingResponse{}, nil +} + +func (*mySubjectMappingClient) ListSubjectConditionSets(_ context.Context, _ *sm.ListSubjectConditionSetsRequest) (*sm.ListSubjectConditionSetsResponse, error) { + return &sm.ListSubjectConditionSetsResponse{}, nil +} + +func (*mySubjectMappingClient) GetSubjectConditionSet(_ context.Context, _ *sm.GetSubjectConditionSetRequest) (*sm.GetSubjectConditionSetResponse, error) { + return &sm.GetSubjectConditionSetResponse{}, nil +} + +func (*mySubjectMappingClient) CreateSubjectConditionSet(_ context.Context, _ *sm.CreateSubjectConditionSetRequest) (*sm.CreateSubjectConditionSetResponse, error) { + return &sm.CreateSubjectConditionSetResponse{}, nil +} + +func (*mySubjectMappingClient) UpdateSubjectConditionSet(_ context.Context, _ *sm.UpdateSubjectConditionSetRequest) (*sm.UpdateSubjectConditionSetResponse, error) { + return &sm.UpdateSubjectConditionSetResponse{}, nil +} + +func (*mySubjectMappingClient) DeleteSubjectConditionSet(_ context.Context, _ *sm.DeleteSubjectConditionSetRequest) (*sm.DeleteSubjectConditionSetResponse, error) { + return &sm.DeleteSubjectConditionSetResponse{}, nil +} + +func (*mySubjectMappingClient) DeleteAllUnmappedSubjectConditionSets(_ context.Context, _ *sm.DeleteAllUnmappedSubjectConditionSetsRequest) (*sm.DeleteAllUnmappedSubjectConditionSetsResponse, error) { + return &sm.DeleteAllUnmappedSubjectConditionSetsResponse{}, nil +} + +// // Mock paginated Subject Mapping Client for testing ///// +type paginatedMockSubjectMappingClient struct{} + +var ( + smPaginationOffset = 3 + smListCallCount = 0 +) + +func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest) (*sm.ListSubjectMappingsResponse, error) { + smListCallCount++ + // simulate paginated list and policy LIST behavior + if smPaginationOffset > 0 { + rsp := &sm.ListSubjectMappingsResponse{ + SubjectMappings: nil, + Pagination: &policy.PageResponse{ + NextOffset: int32(smPaginationOffset), + }, + } + smPaginationOffset = 0 + return rsp, nil + } + return &listSubjectMappings, nil +} + +func (*paginatedMockSubjectMappingClient) MatchSubjectMappings(_ context.Context, _ *sm.MatchSubjectMappingsRequest) (*sm.MatchSubjectMappingsResponse, error) { + return &sm.MatchSubjectMappingsResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) GetSubjectMapping(_ context.Context, _ *sm.GetSubjectMappingRequest) (*sm.GetSubjectMappingResponse, error) { + return &sm.GetSubjectMappingResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) CreateSubjectMapping(_ context.Context, _ *sm.CreateSubjectMappingRequest) (*sm.CreateSubjectMappingResponse, error) { + return &sm.CreateSubjectMappingResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) UpdateSubjectMapping(_ context.Context, _ *sm.UpdateSubjectMappingRequest) (*sm.UpdateSubjectMappingResponse, error) { + return &sm.UpdateSubjectMappingResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) DeleteSubjectMapping(_ context.Context, _ *sm.DeleteSubjectMappingRequest) (*sm.DeleteSubjectMappingResponse, error) { + return &sm.DeleteSubjectMappingResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) ListSubjectConditionSets(_ context.Context, _ *sm.ListSubjectConditionSetsRequest) (*sm.ListSubjectConditionSetsResponse, error) { + return &sm.ListSubjectConditionSetsResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) GetSubjectConditionSet(_ context.Context, _ *sm.GetSubjectConditionSetRequest) (*sm.GetSubjectConditionSetResponse, error) { + return &sm.GetSubjectConditionSetResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) CreateSubjectConditionSet(_ context.Context, _ *sm.CreateSubjectConditionSetRequest) (*sm.CreateSubjectConditionSetResponse, error) { + return &sm.CreateSubjectConditionSetResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) UpdateSubjectConditionSet(_ context.Context, _ *sm.UpdateSubjectConditionSetRequest) (*sm.UpdateSubjectConditionSetResponse, error) { + return &sm.UpdateSubjectConditionSetResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) DeleteSubjectConditionSet(_ context.Context, _ *sm.DeleteSubjectConditionSetRequest) (*sm.DeleteSubjectConditionSetResponse, error) { + return &sm.DeleteSubjectConditionSetResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) DeleteAllUnmappedSubjectConditionSets(_ context.Context, _ *sm.DeleteAllUnmappedSubjectConditionSetsRequest) (*sm.DeleteAllUnmappedSubjectConditionSetsResponse, error) { + return &sm.DeleteAllUnmappedSubjectConditionSetsResponse{}, nil +} + +// // Mock paginated attributs client for testing //// +type paginatedMockAttributesClient struct{} + +var ( + attrPaginationOffset = 3 + attrListCallCount = 0 +) + +func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest) (*attr.ListAttributesResponse, error) { + attrListCallCount++ + // simulate paginated list and policy LIST behavior + if attrPaginationOffset > 0 { + rsp := &attr.ListAttributesResponse{ + Attributes: nil, + Pagination: &policy.PageResponse{ + NextOffset: int32(attrPaginationOffset), + }, + } + attrPaginationOffset = 0 + return rsp, nil + } + return &listAttributeResp, nil +} + +func (*paginatedMockAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.GetAttributeValuesByFqnsRequest) (*attr.GetAttributeValuesByFqnsResponse, error) { + return &attr.GetAttributeValuesByFqnsResponse{}, nil +} + +func (*paginatedMockAttributesClient) ListAttributeValues(_ context.Context, _ *attr.ListAttributeValuesRequest) (*attr.ListAttributeValuesResponse, error) { + return &attr.ListAttributeValuesResponse{}, nil +} + +func (*paginatedMockAttributesClient) GetAttribute(_ context.Context, _ *attr.GetAttributeRequest) (*attr.GetAttributeResponse, error) { + return &attr.GetAttributeResponse{}, nil +} + +func (*paginatedMockAttributesClient) GetAttributeValue(_ context.Context, _ *attr.GetAttributeValueRequest) (*attr.GetAttributeValueResponse, error) { + return &attr.GetAttributeValueResponse{}, nil +} + +func (*paginatedMockAttributesClient) CreateAttribute(_ context.Context, _ *attr.CreateAttributeRequest) (*attr.CreateAttributeResponse, error) { + return &attr.CreateAttributeResponse{}, nil +} + +func (*paginatedMockAttributesClient) UpdateAttribute(_ context.Context, _ *attr.UpdateAttributeRequest) (*attr.UpdateAttributeResponse, error) { + return &attr.UpdateAttributeResponse{}, nil +} + +func (*paginatedMockAttributesClient) DeactivateAttribute(_ context.Context, _ *attr.DeactivateAttributeRequest) (*attr.DeactivateAttributeResponse, error) { + return &attr.DeactivateAttributeResponse{}, nil +} + +func (*paginatedMockAttributesClient) CreateAttributeValue(_ context.Context, _ *attr.CreateAttributeValueRequest) (*attr.CreateAttributeValueResponse, error) { + return &attr.CreateAttributeValueResponse{}, nil +} + +func (*paginatedMockAttributesClient) UpdateAttributeValue(_ context.Context, _ *attr.UpdateAttributeValueRequest) (*attr.UpdateAttributeValueResponse, error) { + return &attr.UpdateAttributeValueResponse{}, nil +} + +func (*paginatedMockAttributesClient) DeactivateAttributeValue(_ context.Context, _ *attr.DeactivateAttributeValueRequest) (*attr.DeactivateAttributeValueResponse, error) { + return &attr.DeactivateAttributeValueResponse{}, nil +} + +func (*paginatedMockAttributesClient) AssignKeyAccessServerToAttribute(_ context.Context, _ *attr.AssignKeyAccessServerToAttributeRequest) (*attr.AssignKeyAccessServerToAttributeResponse, error) { + return &attr.AssignKeyAccessServerToAttributeResponse{}, nil +} + +func (*paginatedMockAttributesClient) RemoveKeyAccessServerFromAttribute(_ context.Context, _ *attr.RemoveKeyAccessServerFromAttributeRequest) (*attr.RemoveKeyAccessServerFromAttributeResponse, error) { + return &attr.RemoveKeyAccessServerFromAttributeResponse{}, nil +} + +func (*paginatedMockAttributesClient) AssignKeyAccessServerToValue(_ context.Context, _ *attr.AssignKeyAccessServerToValueRequest) (*attr.AssignKeyAccessServerToValueResponse, error) { + return &attr.AssignKeyAccessServerToValueResponse{}, nil +} + +func (*paginatedMockAttributesClient) RemoveKeyAccessServerFromValue(_ context.Context, _ *attr.RemoveKeyAccessServerFromValueRequest) (*attr.RemoveKeyAccessServerFromValueResponse, error) { + return &attr.RemoveKeyAccessServerFromValueResponse{}, nil +} + +func (*paginatedMockAttributesClient) AssignPublicKeyToAttribute(_ context.Context, _ *attr.AssignPublicKeyToAttributeRequest) (*attr.AssignPublicKeyToAttributeResponse, error) { + return &attr.AssignPublicKeyToAttributeResponse{}, nil +} + +func (*paginatedMockAttributesClient) RemovePublicKeyFromAttribute(_ context.Context, _ *attr.RemovePublicKeyFromAttributeRequest) (*attr.RemovePublicKeyFromAttributeResponse, error) { + return &attr.RemovePublicKeyFromAttributeResponse{}, nil +} + +func (*paginatedMockAttributesClient) AssignPublicKeyToValue(_ context.Context, _ *attr.AssignPublicKeyToValueRequest) (*attr.AssignPublicKeyToValueResponse, error) { + return &attr.AssignPublicKeyToValueResponse{}, nil +} + +func (*paginatedMockAttributesClient) RemovePublicKeyFromValue(_ context.Context, _ *attr.RemovePublicKeyFromValueRequest) (*attr.RemovePublicKeyFromValueResponse, error) { + return &attr.RemovePublicKeyFromValueResponse{}, nil +} diff --git a/service/internal/server/memhttp/listener.go b/service/internal/server/memhttp/listener.go index 1ec303ddba..cf1d381647 100644 --- a/service/internal/server/memhttp/listener.go +++ b/service/internal/server/memhttp/listener.go @@ -56,4 +56,4 @@ func (*memoryAddr) Network() string { return "memory" } // String implements io.Stringer, returning a value that matches the // certificates used by net/http/httptest. -func (*memoryAddr) String() string { return "opentdf.io" } +func (*memoryAddr) String() string { return "http://opentdf.io" } diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 816764f288..46748b529b 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -19,6 +19,7 @@ import ( "connectrpc.com/validate" "github.com/go-chi/cors" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/opentdf/platform/sdk" sdkAudit "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/service/internal/auth" "github.com/opentdf/platform/service/internal/security" @@ -134,6 +135,7 @@ type OpenTDFServer struct { // To Deprecate: Use the TrustKeyIndex and TrustKeyManager instead CryptoProvider *security.StandardCrypto + Listener net.Listener logger *logger.Logger } @@ -440,11 +442,13 @@ func (s OpenTDFServer) Start() error { s.ConnectRPCInProcess.Mux.Handle(grpcreflect.NewHandlerV1(reflector)) s.ConnectRPCInProcess.Mux.Handle(grpcreflect.NewHandlerV1Alpha(reflector)) - // Start Http Server ln, err := s.openHTTPServerPort() if err != nil { return err } + s.Listener = ln + + // Start Http Server go s.startHTTPServer(ln) return nil @@ -458,6 +462,10 @@ func (s OpenTDFServer) Stop() { s.logger.Error("failed to shutdown http server", slog.String("error", err.Error())) return } + // Close the listener + if s.Listener != nil { + s.Listener.Close() + } s.logger.Info("shutting down in process grpc server") if err := s.ConnectRPCInProcess.srv.Shutdown(ctx); err != nil { @@ -468,7 +476,25 @@ func (s OpenTDFServer) Stop() { s.logger.Info("shutdown complete") } -func (s inProcessServer) Conn() *grpc.ClientConn { +func (s inProcessServer) Conn() *sdk.ConnectRPCConnection { + var clientInterceptors []connect.Interceptor + + // Add audit interceptor + clientInterceptors = append(clientInterceptors, sdkAudit.MetadataAddingConnectInterceptor()) + + conn := sdk.ConnectRPCConnection{ + Client: s.srv.Client(), + Endpoint: s.srv.Listener.Addr().String(), + Options: []connect.ClientOption{ + connect.WithInterceptors(clientInterceptors...), + connect.WithReadMaxBytes(s.maxCallRecvMsgSize), + connect.WithSendMaxBytes(s.maxCallSendMsgSize), + }, + } + return &conn +} + +func (s inProcessServer) GrpcConn() *grpc.ClientConn { var clientInterceptors []grpc.UnaryClientInterceptor // Add audit interceptor diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index e6be96c890..e8afac538f 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -120,7 +120,9 @@ func registerCoreServices(reg serviceregistry.Registry, mode []string) ([]string // based on the configuration and namespace mode. It creates a new service logger // and a database client if required. It registers the services with the gRPC server, // in-process gRPC server, and gRPC gateway. Finally, it logs the status of each service. -func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDFServer, client *sdk.SDK, logger *logging.Logger, reg serviceregistry.Registry) error { +func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDFServer, client *sdk.SDK, logger *logging.Logger, reg serviceregistry.Registry) (func(), error) { + var gatewayCleanup func() + // Iterate through the registered namespaces for ns, namespace := range reg { // modeEnabled checks if the mode is enabled based on the configuration and namespace mode. @@ -165,7 +167,7 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF var err error svcDBClient, err = newServiceDBClient(ctx, cfg.Logger, cfg.DB, tracer, ns, svc.DBMigrations()) if err != nil { - return err + return func() {}, err } } if svc.GetVersion() != "" { @@ -183,11 +185,11 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF Tracer: tracer, }) if err != nil { - return err + return func() {}, err } if err := svc.RegisterConfigUpdateHook(ctx, cfg.AddOnConfigChangeHook); err != nil { - return fmt.Errorf("failed to register config update hook: %w", err) + return func() {}, fmt.Errorf("failed to register config update hook: %w", err) } // Register Connect RPC Services @@ -201,8 +203,18 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF } // Register GRPC Gateway Handler using the in-process connect rpc - if err := svc.RegisterGRPCGatewayHandler(ctx, otdf.GRPCGatewayMux, otdf.ConnectRPCInProcess.Conn()); err != nil { + grpcConn := otdf.ConnectRPCInProcess.GrpcConn() + err := svc.RegisterGRPCGatewayHandler(ctx, otdf.GRPCGatewayMux, otdf.ConnectRPCInProcess.GrpcConn()) + if err != nil { logger.Info("service did not register a grpc gateway handler", slog.String("namespace", ns)) + } else if gatewayCleanup == nil { + gatewayCleanup = func() { + slog.Debug("executing cleanup") + if grpcConn != nil { + grpcConn.Close() + } + slog.Info("cleanup complete") + } } // Register Extra Handlers @@ -222,7 +234,10 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF } } - return nil + if gatewayCleanup == nil { + gatewayCleanup = func() {} + } + return gatewayCleanup, nil } func extractServiceLoggerConfig(cfg config.ServiceConfig) (string, error) { diff --git a/service/pkg/server/services_test.go b/service/pkg/server/services_test.go index cb9c687067..e7c19cd08f 100644 --- a/service/pkg/server/services_test.go +++ b/service/pkg/server/services_test.go @@ -257,7 +257,7 @@ func (suite *ServiceTestSuite) TestStartServicesWithVariousCases() { newLogger, err := logger.NewLogger(logger.Config{Output: "stdout", Level: "info", Type: "json"}) suite.Require().NoError(err) - err = startServices(ctx, &config.Config{ + cleanup, err := startServices(ctx, &config.Config{ Mode: []string{"test"}, Logger: logger.Config{Output: "stdout", Level: "info", Type: "json"}, // DB: db.Config{ @@ -274,6 +274,10 @@ func (suite *ServiceTestSuite) TestStartServicesWithVariousCases() { "foobar": {}, }, }, otdf, nil, newLogger, registry) + + // call cleanup function + defer cleanup() + suite.Require().NoError(err) // require.NotNil(t, cF) // assert.Lenf(t, services, 2, "expected 2 services enabled, got %d", len(services)) diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index cea8f3c33f..7e36b44399 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -6,13 +6,12 @@ import ( "errors" "fmt" "log/slog" - "net" - "net/url" "os" "os/signal" "slices" "syscall" + "connectrpc.com/connect" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/sdk" sdkauth "github.com/opentdf/platform/sdk/auth" @@ -25,9 +24,6 @@ import ( "github.com/opentdf/platform/service/pkg/serviceregistry" "github.com/opentdf/platform/service/tracing" wellknown "github.com/opentdf/platform/service/wellknownconfiguration" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" ) const devModeMessage = ` @@ -231,18 +227,19 @@ func Start(f ...StartOptions) error { return errors.New("entityresolution endpoint must be provided in core mode") } - ersDialOptions := []grpc.DialOption{} + ersConnectRPCConn := sdk.ConnectRPCConnection{} + var tlsConfig *tls.Config if cfg.SDKConfig.EntityResolutionConnection.Insecure { tlsConfig = &tls.Config{ MinVersion: tls.VersionTLS12, InsecureSkipVerify: true, // #nosec G402 } - ersDialOptions = append(ersDialOptions, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + ersConnectRPCConn.Client = httputil.SafeHTTPClientWithTLSConfig(tlsConfig) } if cfg.SDKConfig.EntityResolutionConnection.Plaintext { tlsConfig = &tls.Config{} - ersDialOptions = append(ersDialOptions, grpc.WithTransportCredentials(insecure.NewCredentials())) + ersConnectRPCConn.Client = httputil.SafeHTTPClient() } if cfg.SDKConfig.ClientID != "" && cfg.SDKConfig.ClientSecret != "" { @@ -268,30 +265,16 @@ func Start(f ...StartOptions) error { interceptor := sdkauth.NewTokenAddingInterceptorWithClient(ts, httputil.SafeHTTPClientWithTLSConfig(tlsConfig)) - ersDialOptions = append(ersDialOptions, grpc.WithChainUnaryInterceptor(interceptor.AddCredentials)) + ersConnectRPCConn.Options = append(ersConnectRPCConn.Options, connect.WithInterceptors(interceptor.AddCredentialsConnect())) } - parsedURL, err := url.Parse(cfg.SDKConfig.EntityResolutionConnection.Endpoint) - if err != nil { - return fmt.Errorf("cannot parse ers url(%s): %w", cfg.SDKConfig.EntityResolutionConnection.Endpoint, err) - } - // Needed to support buffconn for testing - if parsedURL.Host == "" { - return errors.New("ERS host is empty when parsing") + if sdk.IsPlatformEndpointMalformed(cfg.SDKConfig.EntityResolutionConnection.Endpoint) { + return fmt.Errorf("entityresolution endpoint is malformed: %s", cfg.SDKConfig.EntityResolutionConnection.Endpoint) } - port := parsedURL.Port() - // if port is empty, default to 443. - if port == "" { - port = "443" - } - ersGRPCEndpoint := net.JoinHostPort(parsedURL.Hostname(), port) + ersConnectRPCConn.Endpoint = cfg.SDKConfig.EntityResolutionConnection.Endpoint - conn, err := grpc.NewClient(ersGRPCEndpoint, ersDialOptions...) - if err != nil { - return fmt.Errorf("could not connect to ERS: %w", err) - } - sdkOptions = append(sdkOptions, sdk.WithCustomEntityResolutionConnection(conn)) - logger.Info("added with custom ers connection for ", "", ersGRPCEndpoint) + sdkOptions = append(sdkOptions, sdk.WithCustomEntityResolutionConnection(&ersConnectRPCConn)) + logger.Info("added with custom ers connection for ", "", ersConnectRPCConn.Endpoint) } client, err = sdk.New("", sdkOptions...) @@ -317,11 +300,12 @@ func Start(f ...StartOptions) error { defer client.Close() logger.Info("starting services") - err = startServices(ctx, cfg, otdf, client, logger, svcRegistry) + gatewayCleanup, err := startServices(ctx, cfg, otdf, client, logger, svcRegistry) if err != nil { logger.Error("issue starting services", slog.String("error", err.Error())) return fmt.Errorf("issue starting services: %w", err) } + defer gatewayCleanup() // Start watching the configuration for changes with registered config change service hooks if err := cfg.Watch(ctx); err != nil { diff --git a/service/pkg/server/start_test.go b/service/pkg/server/start_test.go index 67c1b8f92b..8db33168f2 100644 --- a/service/pkg/server/start_test.go +++ b/service/pkg/server/start_test.go @@ -256,13 +256,14 @@ func (suite *StartTestSuite) Test_Start_When_Extra_Service_Registered() { suite.Require().NoError(err) // Start services with test service - err = startServices(context.Background(), &config.Config{ + cleanup, err := startServices(context.Background(), &config.Config{ Mode: tc.mode, Services: map[string]config.ServiceConfig{ "test": {}, }, }, s, nil, logger, registry) require.NoError(t, err) + defer cleanup() require.NoError(t, s.Start()) defer s.Stop() diff --git a/service/rttests/rt_test.go b/service/rttests/rt_test.go index 1f2af56e80..64ef90331c 100644 --- a/service/rttests/rt_test.go +++ b/service/rttests/rt_test.go @@ -28,10 +28,11 @@ import ( // then those will need to be updated. type TestConfig struct { - PlatformEndpoint string - TokenEndpoint string - ClientID string - ClientSecret string + PlatformEndpoint string + PlatformEndpointWithScheme string + TokenEndpoint string + ClientID string + ClientSecret string } var attributesToMap = []string{ @@ -111,11 +112,14 @@ func (s *RoundtripSuite) SetupSuite() { opts := []sdk.Option{} if os.Getenv("TLS_ENABLED") == "" { opts = append(opts, sdk.WithInsecurePlaintextConn()) + s.TestConfig.PlatformEndpointWithScheme = "http://" + s.TestConfig.PlatformEndpoint + } else { + s.TestConfig.PlatformEndpointWithScheme = "https://" + s.TestConfig.PlatformEndpoint } opts = append(opts, sdk.WithClientCredentials(s.TestConfig.ClientID, s.TestConfig.ClientSecret, nil)) - sdk, err := sdk.New("http://"+s.TestConfig.PlatformEndpoint, opts...) + sdk, err := sdk.New(s.TestConfig.PlatformEndpointWithScheme, opts...) s.Require().NoError(err) s.client = sdk @@ -342,8 +346,7 @@ func encrypt(client *sdk.SDK, testConfig TestConfig, plaintext string, attribute sdk.WithDataAttributes(attributes...), sdk.WithKasInformation( sdk.KASInfo{ - // examples assume insecure http - URL: "http://" + testConfig.PlatformEndpoint, + URL: testConfig.PlatformEndpointWithScheme, PublicKey: "", })) if err != nil { diff --git a/test/tdf-roundtrips.bats b/test/tdf-roundtrips.bats index 149557ed86..efee0c0b49 100755 --- a/test/tdf-roundtrips.bats +++ b/test/tdf-roundtrips.bats @@ -6,10 +6,10 @@ @test "examples: roundtrip Z-TDF" { # TODO: add subject mapping here to remove reliance on `provision fixtures` echo "[INFO] configure attribute with grant for local kas" - go run ./examples --creds opentdf:secret kas add --kas http://localhost:8080 --algorithm "rsa:2048" --kid r1 --public-key "$(<${BATS_TEST_DIRNAME}/../kas-cert.pem)" + go run ./examples --creds opentdf:secret kas add --kas https://localhost:8080 --algorithm "rsa:2048" --kid r1 --public-key "$(<${BATS_TEST_DIRNAME}/../kas-cert.pem)" go run ./examples --creds opentdf:secret attributes unassign -a https://example.com/attr/attr1 -v value1 go run ./examples --creds opentdf:secret attributes unassign -a https://example.com/attr/attr1 - go run ./examples --creds opentdf:secret attributes assign -a https://example.com/attr/attr1 -v value1 -k http://localhost:8080 + go run ./examples --creds opentdf:secret attributes assign -a https://example.com/attr/attr1 -v value1 -k https://localhost:8080 echo "[INFO] create a tdf3 format file" run go run ./examples encrypt "Hello Zero Trust" @@ -58,11 +58,11 @@ @test "examples: roundtrip Z-TDF with extra unnecessary, invalid kas" { # TODO: add subject mapping here to remove reliance on `provision fixtures` echo "[INFO] configure attribute with grant for local kas" - go run ./examples --creds opentdf:secret kas add --kas http://localhost:8080 --algorithm "rsa:2048" --kid r1 --public-key "$(<${BATS_TEST_DIRNAME}/../kas-cert.pem)" + go run ./examples --creds opentdf:secret kas add --kas https://localhost:8080 --algorithm "rsa:2048" --kid r1 --public-key "$(<${BATS_TEST_DIRNAME}/../kas-cert.pem)" go run ./examples --creds opentdf:secret kas add --kas http://localhost:9090 --algorithm "rsa:2048" --kid r2 --public-key "$(<${BATS_TEST_DIRNAME}/../kas-cert.pem)" go run ./examples --creds opentdf:secret attributes unassign -a https://example.com/attr/attr1 -v value1 go run ./examples --creds opentdf:secret attributes unassign -a https://example.com/attr/attr1 - go run ./examples --creds opentdf:secret attributes assign -a https://example.com/attr/attr1 -v value1 -k "http://localhost:8080,http://localhost:9090" + go run ./examples --creds opentdf:secret attributes assign -a https://example.com/attr/attr1 -v value1 -k "https://localhost:8080,http://localhost:9090" echo "[INFO] create a tdf3 format file" run go run ./examples encrypt "Hello multikao split"