From 07b3fafe9c48d1bb5355975fc0e8d5cf57a768d8 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 9 May 2025 15:20:18 -0400 Subject: [PATCH 01/31] connect rpc client side --- examples/cmd/attributes.go | 91 ++++---- examples/cmd/authorization.go | 8 +- examples/cmd/benchmark_decision.go | 7 +- examples/cmd/decrypt.go | 1 - examples/cmd/kas.go | 28 ++- examples/go.mod | 1 + examples/go.sum | 2 + sdk/audit/metadata_adding_interceptor.go | 31 +++ sdk/auth/token_adding_interceptor.go | 33 +++ sdk/bulk.go | 2 +- sdk/go.mod | 1 + sdk/go.sum | 2 + sdk/granter.go | 10 +- sdk/granter_test.go | 17 +- sdk/kas_client.go | 54 +++-- sdk/kas_client_test.go | 8 +- sdk/nanotdf.go | 2 +- sdk/nanotdf_test.go | 2 +- sdk/options.go | 67 +++--- sdk/sdk.go | 224 +++++++------------- sdk/sdk_test.go | 137 +++++------- sdk/tdf.go | 28 +-- sdk/tdf_test.go | 200 ++++++++--------- service/authorization/authorization.go | 30 +-- service/authorization/authorization_test.go | 43 ++-- service/internal/server/memhttp/listener.go | 2 +- service/internal/server/server.go | 34 ++- service/kas/access/accessPdp.go | 5 +- service/pkg/server/services.go | 4 +- service/pkg/server/start.go | 63 ++---- service/rttests/rt_test.go | 43 ++-- 31 files changed, 582 insertions(+), 598 deletions(-) diff --git a/examples/cmd/attributes.go b/examples/cmd/attributes.go index 7720d4ab7c..ca78eb8ca0 100644 --- a/examples/cmd/attributes.go +++ b/examples/cmd/attributes.go @@ -9,6 +9,7 @@ import ( "regexp" "strings" + "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/common" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" @@ -105,19 +106,18 @@ func listAttributes(cmd *cobra.Command) error { slog.Error("could not connect", slog.Any("error", err)) return err } - defer s.Close() ctx := cmd.Context() var nsuris []string if ns == "" { slog.Info("listing namespaces") - listResp, err := s.Namespaces.ListNamespaces(ctx, &namespaces.ListNamespacesRequest{}) + listResp, err := s.Namespaces.ListNamespaces(ctx, connect.NewRequest(&namespaces.ListNamespacesRequest{})) if err != nil { return err } - slog.Info(fmt.Sprintf("found %d namespaces", len(listResp.Namespaces))) - for _, n := range listResp.GetNamespaces() { + slog.Info(fmt.Sprintf("found %d namespaces", len(listResp.Msg.Namespaces))) + for _, n := range listResp.Msg.GetNamespaces() { nsuris = append(nsuris, n.GetFqn()) } } else { @@ -128,15 +128,15 @@ func listAttributes(cmd *cobra.Command) error { if err != nil { return err } - lsr, err := s.Attributes.ListAttributes(ctx, &attributes.ListAttributesRequest{ + lsr, err := s.Attributes.ListAttributes(ctx, connect.NewRequest(&attributes.ListAttributesRequest{ // namespace here must be the namespace name Namespace: u.Host, - }) + })) if err != nil { return err } - slog.Info(fmt.Sprintf("found %d attributes in namespace", len(lsr.GetAttributes())), "ns", n) - for _, a := range lsr.GetAttributes() { + slog.Info(fmt.Sprintf("found %d attributes in namespace", len(lsr.Msg.GetAttributes())), "ns", n) + for _, a := range lsr.Msg.GetAttributes() { if longformat { fmt.Printf("%s\t%s\n", a.GetFqn(), a.GetId()) } else { @@ -160,12 +160,12 @@ func nsuuid(ctx context.Context, s *sdk.SDK, u string) (string, error) { slog.Error("namespace url.Parse", "err", err, "url", u) return "", errors.Join(err, ErrInvalidArgument) } - listResp, err := s.Namespaces.ListNamespaces(ctx, &namespaces.ListNamespacesRequest{}) + listResp, err := s.Namespaces.ListNamespaces(ctx, connect.NewRequest(&namespaces.ListNamespacesRequest{})) if err != nil { slog.Error("ListNamespaces", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - for _, n := range listResp.GetNamespaces() { + for _, n := range listResp.Msg.GetNamespaces() { if n.GetName() == url.Hostname() { return n.GetId(), nil } @@ -174,15 +174,15 @@ func nsuuid(ctx context.Context, s *sdk.SDK, u string) (string, error) { } func attruuid(ctx context.Context, s *sdk.SDK, nsu, fqn string) (string, error) { - resp, err := s.Attributes.ListAttributes(ctx, &attributes.ListAttributesRequest{ + resp, err := s.Attributes.ListAttributes(ctx, connect.NewRequest(&attributes.ListAttributesRequest{ Namespace: nsu, State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, - }) + })) if err != nil { slog.Error("ListAttributes", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - for _, a := range resp.GetAttributes() { + for _, a := range resp.Msg.GetAttributes() { if strings.ToLower(a.GetFqn()) == strings.ToLower(fqn) { return a.GetId(), nil } @@ -191,12 +191,12 @@ func attruuid(ctx context.Context, s *sdk.SDK, nsu, fqn string) (string, error) } func avuuid(ctx context.Context, s *sdk.SDK, auuid, vs string) (string, error) { - resp, err := s.Attributes.GetAttribute(ctx, &attributes.GetAttributeRequest{Id: auuid}) + resp, err := s.Attributes.GetAttribute(ctx, connect.NewRequest(&attributes.GetAttributeRequest{Id: auuid})) if err != nil { slog.Error("GetAttribute", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - for _, v := range resp.GetAttribute().GetValues() { + for _, v := range resp.Msg.GetAttribute().GetValues() { if strings.ToLower(v.GetValue()) == strings.ToLower(vs) { return v.GetId(), nil } @@ -210,12 +210,12 @@ func addNamespace(ctx context.Context, s *sdk.SDK, u string) (string, error) { slog.Error("url.Parse", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - resp, err := s.Namespaces.CreateNamespace(ctx, &namespaces.CreateNamespaceRequest{Name: url.Hostname()}) + resp, err := s.Namespaces.CreateNamespace(ctx, connect.NewRequest(&namespaces.CreateNamespaceRequest{Name: url.Hostname()})) if err != nil { slog.Error("CreateNamespace", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - return resp.Namespace.GetId(), nil + return resp.Msg.Namespace.GetId(), nil } func addAttribute(cmd *cobra.Command) error { @@ -224,7 +224,6 @@ func addAttribute(cmd *cobra.Command) error { slog.Error("newSDK", slog.Any("error", err)) return err } - defer s.Close() are := regexp.MustCompile(`^(https?://[\w./]+)/attr/([^/\s]*)$`) m := are.FindStringSubmatch(attr) @@ -259,7 +258,6 @@ func removeAttribute(cmd *cobra.Command) error { slog.Error("could not connect", "err", err) return err } - defer s.Close() are := regexp.MustCompile(`^(https?://[\w./]+)/attr/([^/\s]*)$`) m := are.FindStringSubmatch(attr) @@ -277,10 +275,10 @@ func removeAttribute(cmd *cobra.Command) error { } if len(values) == 0 { if unsafeBool { - resp, err := s.Unsafe.UnsafeDeleteAttribute(cmd.Context(), &unsafe.UnsafeDeleteAttributeRequest{ + resp, err := s.Unsafe.UnsafeDeleteAttribute(cmd.Context(), connect.NewRequest(&unsafe.UnsafeDeleteAttributeRequest{ Id: auuid, Fqn: strings.ToLower(attr), - }) + })) if err != nil { slog.Error("UnsafeDeleteAttribute", "err", err, "id", auuid) return err @@ -288,9 +286,9 @@ func removeAttribute(cmd *cobra.Command) error { slog.Info("deleted attribute", "attr", attr, "resp", resp) return nil } - resp, err := s.Attributes.DeactivateAttribute(cmd.Context(), &attributes.DeactivateAttributeRequest{ + resp, err := s.Attributes.DeactivateAttribute(cmd.Context(), connect.NewRequest(&attributes.DeactivateAttributeRequest{ Id: auuid, - }) + })) if err != nil { slog.Error("DeactivateAttribute", "err", err, "id", auuid) return err @@ -304,19 +302,19 @@ func removeAttribute(cmd *cobra.Command) error { return err } if unsafeBool { - r, err := s.Unsafe.UnsafeDeleteAttributeValue(cmd.Context(), &unsafe.UnsafeDeleteAttributeValueRequest{ + r, err := s.Unsafe.UnsafeDeleteAttributeValue(cmd.Context(), connect.NewRequest(&unsafe.UnsafeDeleteAttributeValueRequest{ Id: avu, Fqn: strings.ToLower(attr + "/value/" + url.PathEscape(v)), - }) + })) if err != nil { slog.Error("UnsafeDeleteAttributeValue", "err", err, "id", avu) return err } slog.Info("deactivated attribute value", "attr", attr, "value", v, "resp", r) } else { - r, err := s.Attributes.DeactivateAttributeValue(cmd.Context(), &attributes.DeactivateAttributeValueRequest{ + r, err := s.Attributes.DeactivateAttributeValue(cmd.Context(), connect.NewRequest(&attributes.DeactivateAttributeValueRequest{ Id: avu, - }) + })) if err != nil { slog.Error("DeactivateAttributeValue", "err", err, "id", avu) return err @@ -334,7 +332,6 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { slog.Error("could not connect", "err", err) return err } - defer s.Close() are := regexp.MustCompile(`^(https?://[\w./]+)/attr/([^/\s]*)$`) m := are.FindStringSubmatch(attr) @@ -372,11 +369,11 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { return fmt.Errorf("assign must take a `--kas` parameter") case len(values) == 0: // look up all kasids associated with the attribute - ar, err := s.Attributes.GetAttribute(cmd.Context(), &attributes.GetAttributeRequest{Id: auuid}) + ar, err := s.Attributes.GetAttribute(cmd.Context(), connect.NewRequest(&attributes.GetAttributeRequest{Id: auuid})) if err != nil { return err } - for _, b := range ar.Attribute.GetGrants() { + for _, b := range ar.Msg.Attribute.GetGrants() { kasids = append(kasids, b.GetId()) kasById[b.GetId()] = b.GetUri() } @@ -388,11 +385,11 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { if err != nil { return err } - ar, err := s.Attributes.GetAttributeValue(cmd.Context(), &attributes.GetAttributeValueRequest{Id: avu}) + ar, err := s.Attributes.GetAttributeValue(cmd.Context(), connect.NewRequest(&attributes.GetAttributeValueRequest{Id: avu})) if err != nil { return err } - for _, b := range ar.GetValue().GetGrants() { + for _, b := range ar.Msg.GetValue().GetGrants() { kasids = append(kasids, b.GetId()) kasById[b.GetId()] = b.GetUri() } @@ -401,27 +398,27 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { for _, kasid := range kasids { if len(values) == 0 { if assign { - r, err := s.Attributes.AssignKeyAccessServerToAttribute(cmd.Context(), &attributes.AssignKeyAccessServerToAttributeRequest{ + r, err := s.Attributes.AssignKeyAccessServerToAttribute(cmd.Context(), connect.NewRequest(&attributes.AssignKeyAccessServerToAttributeRequest{ AttributeKeyAccessServer: &attributes.AttributeKeyAccessServer{ AttributeId: auuid, KeyAccessServerId: kasid, }, - }) + })) if err != nil { return err } - cmd.Printf("successfully assigned all of [%s] to [%s] (binding [%v])\n", attr, kasById[kasid], *r.GetAttributeKeyAccessServer()) + cmd.Printf("successfully assigned all of [%s] to [%s] (binding [%v])\n", attr, kasById[kasid], *r.Msg.GetAttributeKeyAccessServer()) } else { - r, err := s.Attributes.RemoveKeyAccessServerFromAttribute(cmd.Context(), &attributes.RemoveKeyAccessServerFromAttributeRequest{ + r, err := s.Attributes.RemoveKeyAccessServerFromAttribute(cmd.Context(), connect.NewRequest(&attributes.RemoveKeyAccessServerFromAttributeRequest{ AttributeKeyAccessServer: &attributes.AttributeKeyAccessServer{ AttributeId: auuid, KeyAccessServerId: kasid, }, - }) + })) if err != nil { return err } - cmd.Printf("successfully unassigned [%s] from [%s] (binding %v)\n", attr, kasById[kasid], *r.GetAttributeKeyAccessServer()) + cmd.Printf("successfully unassigned [%s] from [%s] (binding %v)\n", attr, kasById[kasid], *r.Msg.GetAttributeKeyAccessServer()) } } else { for _, v := range values { @@ -430,27 +427,27 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { return err } if assign { - r, err := s.Attributes.AssignKeyAccessServerToValue(cmd.Context(), &attributes.AssignKeyAccessServerToValueRequest{ + r, err := s.Attributes.AssignKeyAccessServerToValue(cmd.Context(), connect.NewRequest(&attributes.AssignKeyAccessServerToValueRequest{ ValueKeyAccessServer: &attributes.ValueKeyAccessServer{ ValueId: avu, KeyAccessServerId: kasid, }, - }) + })) if err != nil { return err } - cmd.Printf("successfully assigned [%s] to [%s] (binding [%v])\n", attr, kasById[kasid], *r.GetValueKeyAccessServer()) + cmd.Printf("successfully assigned [%s] to [%s] (binding [%v])\n", attr, kasById[kasid], *r.Msg.GetValueKeyAccessServer()) } else { - r, err := s.Attributes.RemoveKeyAccessServerFromValue(cmd.Context(), &attributes.RemoveKeyAccessServerFromValueRequest{ + r, err := s.Attributes.RemoveKeyAccessServerFromValue(cmd.Context(), connect.NewRequest(&attributes.RemoveKeyAccessServerFromValueRequest{ ValueKeyAccessServer: &attributes.ValueKeyAccessServer{ ValueId: avu, KeyAccessServerId: kasid, }, - }) + })) if err != nil { return err } - cmd.Printf("successfully unassigned [%s] from [%s] (binding [%v])\n", attr, kasById[kasid], *r.GetValueKeyAccessServer()) + cmd.Printf("successfully unassigned [%s] from [%s] (binding [%v])\n", attr, kasById[kasid], *r.Msg.GetValueKeyAccessServer()) } } } @@ -473,15 +470,15 @@ func ruler() policy.AttributeRuleTypeEnum { func upsertAttr(ctx context.Context, s *sdk.SDK, auth, name string, values []string) (string, error) { av, err := - s.Attributes.CreateAttribute(ctx, &attributes.CreateAttributeRequest{ + s.Attributes.CreateAttribute(ctx, connect.NewRequest(&attributes.CreateAttributeRequest{ NamespaceId: auth, Name: name, Rule: ruler(), Values: values, - }) + })) if err != nil { slog.Error("CreateAttribute", "err", err, "auth", auth, "name", name, "values", values, "rule", ruler()) return "", err } - return av.Attribute.GetId(), nil + return av.Msg.Attribute.GetId(), nil } diff --git a/examples/cmd/authorization.go b/examples/cmd/authorization.go index 9d3364e21b..afb3d54bcb 100644 --- a/examples/cmd/authorization.go +++ b/examples/cmd/authorization.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" + "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/authorization" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/sdk" @@ -26,7 +27,6 @@ func authorizationExamples() error { slog.Error("could not connect", slog.Any("error", err)) return err } - defer s.Close() // request decision on "TRANSMIT" Action actions := []*policy.Action{{ @@ -62,15 +62,15 @@ func authorizationExamples() error { decisionRequest := &authorization.GetDecisionsRequest{DecisionRequests: drs} slog.Info(fmt.Sprintf("Submitting decision request: %s", protojson.Format(decisionRequest))) - decisionResponse, err := s.Authorization.GetDecisions(context.Background(), decisionRequest) + decisionResponse, err := s.Authorization.GetDecisions(context.Background(), connect.NewRequest(decisionRequest)) if err != nil { return err } - slog.Info(fmt.Sprintf("Received decision response: %s", protojson.Format(decisionResponse))) + slog.Info(fmt.Sprintf("Received decision response: %s", protojson.Format(decisionResponse.Msg))) // map response back to entity chain id decisionsByEntityChain := make(map[string]*authorization.DecisionResponse) - for _, dr := range decisionResponse.DecisionResponses { + for _, dr := range decisionResponse.Msg.DecisionResponses { decisionsByEntityChain[dr.EntityChainId] = dr } diff --git a/examples/cmd/benchmark_decision.go b/examples/cmd/benchmark_decision.go index d34998bc23..ea5a66f8cc 100644 --- a/examples/cmd/benchmark_decision.go +++ b/examples/cmd/benchmark_decision.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/authorization" "github.com/opentdf/platform/protocol/go/policy" "github.com/spf13/cobra" @@ -34,7 +35,7 @@ func runDecisionBenchmark(cmd *cobra.Command, args []string) error { } start := time.Now() - res, err := client.Authorization.GetDecisions(context.Background(), &authorization.GetDecisionsRequest{ + res, err := client.Authorization.GetDecisions(context.Background(), connect.NewRequest(&authorization.GetDecisionsRequest{ DecisionRequests: []*authorization.DecisionRequest{ { Actions: []*policy.Action{{Value: &policy.Action_Standard{ @@ -48,14 +49,14 @@ func runDecisionBenchmark(cmd *cobra.Command, args []string) error { ResourceAttributes: ras, }, }, - }) + })) end := time.Now() totalTime := end.Sub(start) numberApproved := 0 numberDenied := 0 if err == nil { - for _, dr := range res.GetDecisionResponses() { + for _, dr := range res.Msg.GetDecisionResponses() { if dr.Decision == authorization.DecisionResponse_DECISION_PERMIT { numberApproved += 1 } else { diff --git a/examples/cmd/decrypt.go b/examples/cmd/decrypt.go index f8040cc5ad..33884edb7a 100644 --- a/examples/cmd/decrypt.go +++ b/examples/cmd/decrypt.go @@ -55,7 +55,6 @@ func decrypt(cmd *cobra.Command, args []string) error { } } } - client.Close() return nil } diff --git a/examples/cmd/kas.go b/examples/cmd/kas.go index cd7972dd87..0939a0dca6 100644 --- a/examples/cmd/kas.go +++ b/examples/cmd/kas.go @@ -6,6 +6,7 @@ import ( "log/slog" "strings" + "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/kasregistry" "github.com/opentdf/platform/sdk" @@ -69,9 +70,8 @@ func listKases(cmd *cobra.Command) error { slog.Error("could not connect", "err", err) return err } - defer s.Close() - r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(cmd.Context(), &kasregistry.ListKeyAccessServersRequest{}) + r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(cmd.Context(), connect.NewRequest(&kasregistry.ListKeyAccessServersRequest{})) if err != nil { slog.Error("ListKeyAccessServers", "error", err) return err @@ -79,12 +79,12 @@ func listKases(cmd *cobra.Command) error { slog.Info("listing kas registry") - if len(r.GetKeyAccessServers()) == 0 { + if len(r.Msg.GetKeyAccessServers()) == 0 { cmd.Println("no key access servers registered") return nil } - for _, k := range r.GetKeyAccessServers() { + for _, k := range r.Msg.GetKeyAccessServers() { if longformat { fmt.Printf("%s\t%s\t%s\n", k.GetUri(), k.GetId(), k.GetPublicKey()) } else { @@ -95,12 +95,12 @@ func listKases(cmd *cobra.Command) error { } func upsertKasRegistration(ctx context.Context, s *sdk.SDK, uri string, pk *policy.PublicKey) (string, error) { - r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(ctx, &kasregistry.ListKeyAccessServersRequest{}) + r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(ctx, connect.NewRequest(&kasregistry.ListKeyAccessServersRequest{})) if err != nil { slog.Error("ListKeyAccessServers", "err", err) return "", err } - for _, ki := range r.GetKeyAccessServers() { + for _, ki := range r.Msg.GetKeyAccessServers() { if strings.ToLower(uri) == strings.ToLower(ki.GetUri()) { oldpk := ki.GetPublicKey() recreate := false @@ -114,7 +114,7 @@ func upsertKasRegistration(ctx context.Context, s *sdk.SDK, uri string, pk *poli if !recreate { return ki.GetId(), nil } - _, err := s.KeyAccessServerRegistry.DeleteKeyAccessServer(ctx, &kasregistry.DeleteKeyAccessServerRequest{Id: ki.GetId()}) + _, err := s.KeyAccessServerRegistry.DeleteKeyAccessServer(ctx, connect.NewRequest(&kasregistry.DeleteKeyAccessServerRequest{Id: ki.GetId()})) if err != nil { slog.Error("DeleteKeyAccessServer", "err", err) return "", err @@ -130,15 +130,15 @@ func upsertKasRegistration(ctx context.Context, s *sdk.SDK, uri string, pk *poli Remote: uri + "/v2/kas_public_key", } } - ur, err := s.KeyAccessServerRegistry.CreateKeyAccessServer(ctx, &kasregistry.CreateKeyAccessServerRequest{ + ur, err := s.KeyAccessServerRegistry.CreateKeyAccessServer(ctx, connect.NewRequest(&kasregistry.CreateKeyAccessServerRequest{ Uri: uri, PublicKey: pk, - }) + })) if err != nil { slog.Error("CreateKeyAccessServer", "uri", uri, "publicKey", uri+"/v2/kas_public_key") return "", err } - return ur.KeyAccessServer.GetId(), nil + return ur.Msg.KeyAccessServer.GetId(), nil } func algString2Proto(a string) policy.KasPublicKeyAlgEnum { @@ -157,7 +157,6 @@ func updateKas(cmd *cobra.Command) error { slog.Error("could not connect", "err", err) return err } - defer s.Close() var pk *policy.PublicKey switch { @@ -206,17 +205,16 @@ func removeKas(cmd *cobra.Command) error { slog.Error("could not connect", "err", err) return err } - defer s.Close() - r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(cmd.Context(), &kasregistry.ListKeyAccessServersRequest{}) + r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(cmd.Context(), connect.NewRequest(&kasregistry.ListKeyAccessServersRequest{})) if err != nil { slog.Error("ListKeyAccessServers", "err", err) return err } deletedSomething := false - for _, ki := range r.GetKeyAccessServers() { + for _, ki := range r.Msg.GetKeyAccessServers() { if strings.ToLower(kas) == strings.ToLower(ki.GetUri()) { - _, err := s.KeyAccessServerRegistry.DeleteKeyAccessServer(cmd.Context(), &kasregistry.DeleteKeyAccessServerRequest{Id: ki.GetId()}) + _, err := s.KeyAccessServerRegistry.DeleteKeyAccessServer(cmd.Context(), connect.NewRequest(&kasregistry.DeleteKeyAccessServerRequest{Id: ki.GetId()})) if err != nil { slog.Error("DeleteKeyAccessServer", "err", err) return err diff --git a/examples/go.mod b/examples/go.mod index 9d9c6e01b8..04b308f3d2 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -5,6 +5,7 @@ go 1.24.0 toolchain go1.24.2 require ( + connectrpc.com/connect v1.18.1 github.com/opentdf/platform/lib/ocrypto v0.1.9 github.com/opentdf/platform/protocol/go v0.3.2 github.com/opentdf/platform/sdk v0.4.4 diff --git a/examples/go.sum b/examples/go.sum index 07aa828837..ca8720a84f 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -1,5 +1,7 @@ buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1 h1:LEXWFH/xZ5oOWrC3oOtHbUyBdzRWMCPpAQmKC9v05mA= buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1/go.mod h1:XF+P8+RmfdufmIYpGUC+6bF7S+IlmHDEnCrO3OXaUAQ= +connectrpc.com/connect v1.18.1 h1:PAg7CjSAGvscaf6YZKUefjoih5Z/qYkyaTrBW8xvYPw= +connectrpc.com/connect v1.18.1/go.mod h1:0292hj1rnx8oFrStN7cB4jjVBeqs+Yx5yDIC2prWDO8= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= diff --git a/sdk/audit/metadata_adding_interceptor.go b/sdk/audit/metadata_adding_interceptor.go index f042e4f530..0f05de6a5f 100644 --- a/sdk/audit/metadata_adding_interceptor.go +++ b/sdk/audit/metadata_adding_interceptor.go @@ -3,6 +3,7 @@ package audit import ( "context" + "connectrpc.com/connect" "github.com/google/uuid" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "google.golang.org/grpc" @@ -46,3 +47,33 @@ func MetadataAddingClientInterceptor( return err } + +func MetadataAddingConnectInterceptor() connect.UnaryInterceptorFunc { + return connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + // Only apply to outgoing client requests + if !req.Spec().IsClient { + return next(ctx, req) + } + + // Get any existing request ID from context + requestID, ok := ctx.Value(RequestIDContextKey).(uuid.UUID) + if !ok || requestID == uuid.Nil { + requestID = uuid.New() + } + req.Header().Set(string(RequestIDHeaderKey), requestID.String()) + + // Add the request IP to a custom header so it is preserved + if requestIP, ok := ctx.Value(RequestIPContextKey).(string); ok { + req.Header().Set(string(RequestIPHeaderKey), requestIP) + } + + // Add the actor ID from the request so it is preserved if we need it + if actorID, ok := ctx.Value(ActorIDContextKey).(string); ok { + req.Header().Set(string(ActorIDHeaderKey), actorID) + } + + return next(ctx, req) + } + }) +} diff --git a/sdk/auth/token_adding_interceptor.go b/sdk/auth/token_adding_interceptor.go index 901e26b9fe..72e7b4c469 100644 --- a/sdk/auth/token_adding_interceptor.go +++ b/sdk/auth/token_adding_interceptor.go @@ -11,6 +11,7 @@ import ( "net/http" "time" + "connectrpc.com/connect" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" @@ -82,6 +83,38 @@ func (i TokenAddingInterceptor) AddCredentials( return err } +func (i TokenAddingInterceptor) AddCredentialsConnect() connect.UnaryInterceptorFunc { + return connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + return func( + ctx context.Context, + req connect.AnyRequest, + ) (connect.AnyResponse, error) { + accessToken, err := i.tokenSource.AccessToken(ctx, i.httpClient) + if err != nil { + slog.ErrorContext(ctx, "error getting access token", "error", err) + return nil, connect.NewError(connect.CodeUnauthenticated, err) + } + + // Add Authorization header + req.Header().Set("Authorization", fmt.Sprintf("DPoP %s", accessToken)) + + // Add DPoP header if possible + dpopTok, err := i.GetDPoPToken(req.Spec().Procedure, http.MethodPost, string(accessToken)) + if err == nil { + req.Header().Set("DPoP", dpopTok) + } else { + // since we don't have a setting about whether DPoP is in use on the client and this request _could_ succeed if + // they are talking to a server where DPoP is not required we will just let this through. this method is extremely + // unlikely to fail so hopefully this isn't confusing + slog.ErrorContext(ctx, "error getting DPoP token for outgoing request. Request will not have DPoP token", "error", err) + } + + // Proceed with the RPC + return next(ctx, req) + } + }) +} + func (i TokenAddingInterceptor) GetDPoPToken(path, method, accessToken string) (string, error) { tok, err := i.tokenSource.MakeToken(func(key jwk.Key) ([]byte, error) { jtiBytes := make([]byte, JTILength) diff --git a/sdk/bulk.go b/sdk/bulk.go index b2ece7f51e..a90957f22b 100644 --- a/sdk/bulk.go +++ b/sdk/bulk.go @@ -167,7 +167,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { } } - kasClient := newKASClient(s.dialOptions, s.tokenSource, s.kasSessionKey) + kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey) allRewrapResp := make(map[string][]kaoResult) var err error for kasurl, rewrapRequests := range kasRewrapRequests { diff --git a/sdk/go.mod b/sdk/go.mod index c6a509b07a..ab6e3f842f 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -5,6 +5,7 @@ go 1.24.0 toolchain go1.24.2 require ( + connectrpc.com/connect v1.18.1 github.com/Masterminds/semver/v3 v3.3.1 github.com/google/uuid v1.6.0 github.com/gowebpki/jcs v1.0.1 diff --git a/sdk/go.sum b/sdk/go.sum index 1ee9aa961e..d3aafe3b5e 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -1,5 +1,7 @@ buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1 h1:LEXWFH/xZ5oOWrC3oOtHbUyBdzRWMCPpAQmKC9v05mA= buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1/go.mod h1:XF+P8+RmfdufmIYpGUC+6bF7S+IlmHDEnCrO3OXaUAQ= +connectrpc.com/connect v1.18.1 h1:PAg7CjSAGvscaf6YZKUefjoih5Z/qYkyaTrBW8xvYPw= +connectrpc.com/connect v1.18.1/go.mod h1:0292hj1rnx8oFrStN7cB4jjVBeqs+Yx5yDIC2prWDO8= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= diff --git a/sdk/granter.go b/sdk/granter.go index 5e0f32c5ab..67fbfdcd4a 100644 --- a/sdk/granter.go +++ b/sdk/granter.go @@ -10,8 +10,10 @@ import ( "sort" "strings" + "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" ) var ErrInvalid = errors.New("invalid type") @@ -221,18 +223,18 @@ func (r granter) byAttribute(fqn AttributeValueFQN) *keyAccessGrant { } // Gets a list of directory of KAS grants for a list of attribute FQNs -func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as attributes.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) { +func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as attributesconnect.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) { fqnsStr := make([]string, len(fqns)) for i, v := range fqns { fqnsStr[i] = v.String() } - av, err := as.GetAttributeValuesByFqns(ctx, &attributes.GetAttributeValuesByFqnsRequest{ + av, err := as.GetAttributeValuesByFqns(ctx, connect.NewRequest(&attributes.GetAttributeValuesByFqnsRequest{ Fqns: fqnsStr, WithValue: &policy.AttributeValueSelector{ WithKeyAccessGrants: true, }, - }) + })) if err != nil { return granter{}, err } @@ -241,7 +243,7 @@ func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as attrib policy: fqns, grants: make(map[string]*keyAccessGrant), } - for fqnstr, pair := range av.GetFqnAttributeValues() { + for fqnstr, pair := range av.Msg.GetFqnAttributeValues() { fqn, err := NewAttributeValueFQN(fqnstr) if err != nil { return grants, err diff --git a/sdk/granter_test.go b/sdk/granter_test.go index cc748d6889..aaded98df1 100644 --- a/sdk/granter_test.go +++ b/sdk/granter_test.go @@ -9,11 +9,12 @@ import ( "strings" "testing" + "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc" ) const ( @@ -503,16 +504,16 @@ func TestReasonerConstructAttributeBoolean(t *testing.T) { var listAttributeResp attributes.ListAttributesResponse type mockAttributesClient struct { - attributes.AttributesServiceClient + attributesconnect.AttributesServiceClient } -func (*mockAttributesClient) ListAttributes(_ context.Context, _ *attributes.ListAttributesRequest, _ ...grpc.CallOption) (*attributes.ListAttributesResponse, error) { - return &listAttributeResp, nil +func (*mockAttributesClient) ListAttributes(_ context.Context, _ *connect.Request[attributes.ListAttributesRequest]) (*connect.Response[attributes.ListAttributesResponse], error) { + return connect.NewResponse(&listAttributeResp), nil } -func (*mockAttributesClient) GetAttributeValuesByFqns(_ context.Context, req *attributes.GetAttributeValuesByFqnsRequest, _ ...grpc.CallOption) (*attributes.GetAttributeValuesByFqnsResponse, error) { +func (*mockAttributesClient) GetAttributeValuesByFqns(_ context.Context, req *connect.Request[attributes.GetAttributeValuesByFqnsRequest]) (*connect.Response[attributes.GetAttributeValuesByFqnsResponse], error) { av := make(map[string]*attributes.GetAttributeValuesByFqnsResponse_AttributeAndValue) - for _, v := range req.GetFqns() { + for _, v := range req.Msg.GetFqns() { vfqn, err := NewAttributeValueFQN(v) if err != nil { return nil, err @@ -524,9 +525,9 @@ func (*mockAttributesClient) GetAttributeValuesByFqns(_ context.Context, req *at } } - return &attributes.GetAttributeValuesByFqnsResponse{ + return connect.NewResponse(&attributes.GetAttributeValuesByFqnsResponse{ FqnAttributeValues: av, - }, nil + }), nil } // Tests titles are written in the form [{attr}.{value}] => [{resulting kas boolean exp}] diff --git a/sdk/kas_client.go b/sdk/kas_client.go index ae6dcde964..18aabd2fed 100644 --- a/sdk/kas_client.go +++ b/sdk/kas_client.go @@ -6,17 +6,19 @@ import ( "errors" "fmt" "net" + "net/http" "net/url" "time" + "connectrpc.com/connect" "google.golang.org/protobuf/encoding/protojson" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/protocol/go/kas" + "github.com/opentdf/platform/protocol/go/kas/kasconnect" "github.com/opentdf/platform/sdk/auth" - "google.golang.org/grpc" ) const ( @@ -27,7 +29,8 @@ const ( type KASClient struct { accessTokenSource auth.AccessTokenSource - dialOptions []grpc.DialOption + httpClient *http.Client + connectOptions []connect.ClientOption sessionKey ocrypto.KeyPair // Set this to enable legacy, non-batch rewrap requests @@ -45,10 +48,11 @@ type decryptor interface { Decrypt(ctx context.Context, results []kaoResult) (int, error) } -func newKASClient(dialOptions []grpc.DialOption, accessTokenSource auth.AccessTokenSource, sessionKey ocrypto.KeyPair) *KASClient { +func newKASClient(httpClient *http.Client, options []connect.ClientOption, accessTokenSource auth.AccessTokenSource, sessionKey ocrypto.KeyPair) *KASClient { return &KASClient{ accessTokenSource: accessTokenSource, - dialOptions: dialOptions, + httpClient: httpClient, + connectOptions: options, sessionKey: sessionKey, supportSingleRewrapEndpoint: true, } @@ -60,27 +64,22 @@ func (k *KASClient) makeRewrapRequest(ctx context.Context, requests []*kas.Unsig if err != nil { return nil, err } - grpcAddress, err := getGRPCAddress(requests[0].GetKeyAccessObjects()[0].GetKeyAccessObject().GetKasUrl()) + kasURL := requests[0].GetKeyAccessObjects()[0].GetKeyAccessObject().GetKasUrl() + _, err = url.Parse(kasURL) if err != nil { - return nil, err - } - - conn, err := grpc.NewClient(grpcAddress, k.dialOptions...) - if err != nil { - return nil, fmt.Errorf("error connecting to sas: %w", err) + return nil, fmt.Errorf("cannot parse kas url(%s): %w", kasURL, err) } - defer conn.Close() - serviceClient := kas.NewAccessServiceClient(conn) + serviceClient := kasconnect.NewAccessServiceClient(k.httpClient, kasURL, k.connectOptions...) - response, err := serviceClient.Rewrap(ctx, rewrapRequest) + response, err := serviceClient.Rewrap(ctx, connect.NewRequest(rewrapRequest)) if err != nil { return upgradeRewrapErrorV1(err, requests) } - upgradeRewrapResponseV1(response, requests) + upgradeRewrapResponseV1(response.Msg, requests) - return response, nil + return response.Msg, nil } // convert v1 responses to v2 @@ -422,23 +421,18 @@ func (c *kasKeyCache) store(ki KASInfo) { c.c[cacheKey] = timeStampedKASInfo{ki, time.Now()} } -func (s SDK) getPublicKey(ctx context.Context, url, algorithm string) (*KASInfo, error) { +func (s SDK) getPublicKey(ctx context.Context, kasurl, algorithm string) (*KASInfo, error) { if s.kasKeyCache != nil { - if cachedValue := s.kasKeyCache.get(url, algorithm); nil != cachedValue { + if cachedValue := s.kasKeyCache.get(kasurl, algorithm); nil != cachedValue { return cachedValue, nil } } - grpcAddress, err := getGRPCAddress(url) - if err != nil { - return nil, err - } - conn, err := grpc.NewClient(grpcAddress, s.dialOptions...) + _, err := url.Parse(kasurl) if err != nil { - return nil, fmt.Errorf("error connecting to grpc service at %s: %w", url, err) + return nil, fmt.Errorf("cannot parse kas url(%s): %w", kasurl, err) } - defer conn.Close() - serviceClient := kas.NewAccessServiceClient(conn) + serviceClient := kasconnect.NewAccessServiceClient(s.conn.Client, kasurl, s.conn.Options...) req := kas.PublicKeyRequest{ Algorithm: algorithm, @@ -446,21 +440,21 @@ func (s SDK) getPublicKey(ctx context.Context, url, algorithm string) (*KASInfo, if s.config.tdfFeatures.noKID { req.V = "1" } - resp, err := serviceClient.PublicKey(ctx, &req) + resp, err := serviceClient.PublicKey(ctx, connect.NewRequest(&req)) if err != nil { return nil, fmt.Errorf("error making request to KAS: %w", err) } - kid := resp.GetKid() + kid := resp.Msg.GetKid() if s.config.tdfFeatures.noKID { kid = "" } ki := KASInfo{ - URL: url, + URL: kasurl, Algorithm: algorithm, KID: kid, - PublicKey: resp.GetPublicKey(), + PublicKey: resp.Msg.GetPublicKey(), } if s.kasKeyCache != nil { s.kasKeyCache.store(ki) diff --git a/sdk/kas_client_test.go b/sdk/kas_client_test.go index 496c195f5d..0750bbd9c5 100644 --- a/sdk/kas_client_test.go +++ b/sdk/kas_client_test.go @@ -5,6 +5,7 @@ import ( "net/http" "testing" + "connectrpc.com/connect" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" @@ -15,7 +16,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "google.golang.org/grpc" "google.golang.org/protobuf/encoding/protojson" ) @@ -58,12 +58,12 @@ func getTokenSource(t *testing.T) FakeAccessTokenSource { } func TestCreatingRequest(t *testing.T) { - var dialOption []grpc.DialOption + var options []connect.ClientOption tokenSource := getTokenSource(t) kasKey, err := ocrypto.NewRSAKeyPair(tdf3KeySize) require.NoError(t, err, "error creating RSA Key") - client := newKASClient(dialOption, tokenSource, &kasKey) + client := newKASClient(nil, options, tokenSource, &kasKey) require.NoError(t, err) keyAccess := []*kaspb.UnsignedRewrapRequest_WithPolicyRequest{ @@ -125,7 +125,7 @@ func TestCreatingRequest(t *testing.T) { } func Test_StoreKASKeys(t *testing.T) { - s, err := New("localhost:8080", + s, err := New("http://localhost:8080", WithPlatformConfiguration(PlatformConfiguration{ "idp": map[string]interface{}{ "issuer": "https://example.org", diff --git a/sdk/nanotdf.go b/sdk/nanotdf.go index b5ef3f153e..97d97d1823 100644 --- a/sdk/nanotdf.go +++ b/sdk/nanotdf.go @@ -1079,7 +1079,7 @@ func (s SDK) getNanoRewrapKey(ctx context.Context, decryptor *NanoTDFDecryptHand } } - client := newKASClient(s.dialOptions, s.tokenSource, nil) + client := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, nil) kasURL, err := decryptor.header.kasURL.GetURL() if err != nil { return nil, fmt.Errorf("nano header kasUrl: %w", err) diff --git a/sdk/nanotdf_test.go b/sdk/nanotdf_test.go index 6cb7990f41..2854548ec3 100644 --- a/sdk/nanotdf_test.go +++ b/sdk/nanotdf_test.go @@ -308,7 +308,7 @@ func TestCreateNanoTDF(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s, err := New("localhost:8080", WithPlatformConfiguration(PlatformConfiguration{})) + s, err := New("http://localhost:8080", WithPlatformConfiguration(PlatformConfiguration{})) require.NoError(t, err) _, err = s.CreateNanoTDF(tt.writer, tt.reader, tt.config) if tt.expectedError != "" { diff --git a/sdk/options.go b/sdk/options.go index 78d4f5ad7b..4e25fc7e2b 100644 --- a/sdk/options.go +++ b/sdk/options.go @@ -5,29 +5,33 @@ import ( "crypto/tls" "net/http" + "connectrpc.com/connect" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/sdk/auth/oauth" "github.com/opentdf/platform/sdk/httputil" "golang.org/x/oauth2" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" ) type Option func(*config) +type ConnectRpcConnection struct { + Client *http.Client + Endpoint string + Options []connect.ClientOption +} + // Internal config struct for building SDK options. type config struct { // Platform configuration structure is subject to change. Consume via accessor methods. - PlatformConfiguration PlatformConfiguration - dialOption grpc.DialOption - httpClient *http.Client - clientCredentials *oauth.ClientCredentials - tokenExchange *oauth.TokenExchangeInfo - tokenEndpoint string - scopes []string - extraDialOptions []grpc.DialOption + PlatformConfiguration PlatformConfiguration + extraClientOptions []connect.ClientOption + httpClient *http.Client + clientCredentials *oauth.ClientCredentials + tokenExchange *oauth.TokenExchangeInfo + tokenEndpoint string + scopes []string + // extraDialOptions []grpc.DialOption certExchange *oauth.CertExchangeInfo kasSessionKey *ocrypto.RsaKeyPair dpopKey *ocrypto.RsaKeyPair @@ -36,8 +40,8 @@ type config struct { nanoFeatures nanoFeatures customAccessTokenSource auth.AccessTokenSource oauthAccessTokenSource oauth2.TokenSource - coreConn *grpc.ClientConn - entityResolutionConn *grpc.ClientConn + coreConn *ConnectRpcConnection + entityResolutionConn *ConnectRpcConnection collectionStore *collectionStore shouldValidatePlatformConnectivity bool } @@ -56,9 +60,9 @@ type nanoFeatures struct { type PlatformConfiguration map[string]interface{} -func (c *config) build() []grpc.DialOption { - return []grpc.DialOption{c.dialOption} -} +// func (c *config) build() []grpc.DialOption { +// return []grpc.DialOption{c.dialOption} +// } // WithInsecureSkipVerifyConn returns an Option that sets up HTTPS connection without verification. func WithInsecureSkipVerifyConn() Option { @@ -66,9 +70,7 @@ func WithInsecureSkipVerifyConn() Option { tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, InsecureSkipVerify: true, // #nosec G402 - } - c.dialOption = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) - // used by http client + } // used by http client c.httpClient = httputil.SafeHTTPClientWithTLSConfig(tlsConfig) } } @@ -83,7 +85,6 @@ func WithStoreCollectionHeaders() Option { // WithInsecurePlaintextConn returns an Option that sets up HTTP connection sent in the clear. func WithInsecurePlaintextConn() Option { return func(c *config) { - c.dialOption = grpc.WithTransportCredentials(insecure.NewCredentials()) // used by http client // FIXME anything to do here c.httpClient = httputil.SafeHTTPClient() @@ -126,20 +127,20 @@ func WithOAuthAccessTokenSource(t oauth2.TokenSource) Option { } // Deprecated: Use WithCustomCoreConnection instead -func WithCustomPolicyConnection(conn *grpc.ClientConn) Option { +func WithCustomPolicyConnection(conn *ConnectRpcConnection) Option { return func(c *config) { c.coreConn = conn } } // Deprecated: Use WithCustomCoreConnection instead -func WithCustomAuthorizationConnection(conn *grpc.ClientConn) Option { +func WithCustomAuthorizationConnection(conn *ConnectRpcConnection) Option { return func(c *config) { c.coreConn = conn } } -func WithCustomEntityResolutionConnection(conn *grpc.ClientConn) Option { +func WithCustomEntityResolutionConnection(conn *ConnectRpcConnection) Option { return func(c *config) { c.entityResolutionConn = conn } @@ -156,11 +157,11 @@ func WithTokenExchange(subjectToken string, audience []string) Option { } } -func WithExtraDialOptions(dialOptions ...grpc.DialOption) Option { - return func(c *config) { - c.extraDialOptions = dialOptions - } -} +// func WithExtraDialOptions(dialOptions ...grpc.DialOption) Option { +// return func(c *config) { +// c.extraDialOptions = dialOptions +// } +// } // The session key pair is used to encrypt responses from KAS for a given session // and can be reused across an entire session. @@ -182,7 +183,7 @@ func WithSessionSignerRSA(key *rsa.PrivateKey) Option { } } -func WithCustomWellknownConnection(conn *grpc.ClientConn) Option { +func WithCustomWellknownConnection(conn *ConnectRpcConnection) Option { return func(c *config) { c.coreConn = conn } @@ -220,12 +221,18 @@ func WithNoKIDInKAO() Option { } // WithCoreConnection returns an Option that sets up a connection to the core platform -func WithCustomCoreConnection(conn *grpc.ClientConn) Option { +func WithCustomCoreConnection(conn *ConnectRpcConnection) Option { return func(c *config) { c.coreConn = conn } } +func WithExtraClientOptions(opts ...connect.ClientOption) Option { + return func(c *config) { + c.extraClientOptions = opts + } +} + // WithNoKIDInNano disables storing the KID in the KAS ResourceLocator. // This allows generating NanoTDF files that are compatible with legacy file formats (no KID). func WithNoKIDInNano() Option { diff --git a/sdk/sdk.go b/sdk/sdk.go index dd4bfcc851..0d852169e9 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -9,32 +9,31 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/url" "strings" + "connectrpc.com/connect" "github.com/opentdf/platform/lib/ocrypto" - "github.com/opentdf/platform/protocol/go/authorization" - "github.com/opentdf/platform/protocol/go/entityresolution" + "github.com/opentdf/platform/protocol/go/authorization/authorizationconnect" + "github.com/opentdf/platform/protocol/go/entityresolution/entityresolutionconnect" "github.com/opentdf/platform/protocol/go/policy" - "github.com/opentdf/platform/protocol/go/policy/actions" - "github.com/opentdf/platform/protocol/go/policy/attributes" - "github.com/opentdf/platform/protocol/go/policy/kasregistry" - "github.com/opentdf/platform/protocol/go/policy/keymanagement" - "github.com/opentdf/platform/protocol/go/policy/namespaces" - "github.com/opentdf/platform/protocol/go/policy/registeredresources" - "github.com/opentdf/platform/protocol/go/policy/resourcemapping" - "github.com/opentdf/platform/protocol/go/policy/subjectmapping" - "github.com/opentdf/platform/protocol/go/policy/unsafe" + "github.com/opentdf/platform/protocol/go/policy/actions/actionsconnect" + "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" + "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" + "github.com/opentdf/platform/protocol/go/policy/keymanagement/keymanagementconnect" + "github.com/opentdf/platform/protocol/go/policy/namespaces/namespacesconnect" + "github.com/opentdf/platform/protocol/go/policy/registeredresources/registeredresourcesconnect" + "github.com/opentdf/platform/protocol/go/policy/resourcemapping/resourcemappingconnect" + "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" + "github.com/opentdf/platform/protocol/go/policy/unsafe/unsafeconnect" "github.com/opentdf/platform/protocol/go/wellknownconfiguration" + "github.com/opentdf/platform/protocol/go/wellknownconfiguration/wellknownconfigurationconnect" "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/sdk/httputil" "github.com/opentdf/platform/sdk/internal/archive" "github.com/xeipuuv/gojsonschema" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" healthpb "google.golang.org/grpc/health/grpc_health_v1" ) @@ -64,35 +63,34 @@ type SDK struct { config *kasKeyCache *collectionStore - conn *grpc.ClientConn - dialOptions []grpc.DialOption + conn *ConnectRpcConnection tokenSource auth.AccessTokenSource - Actions actions.ActionServiceClient - Attributes attributes.AttributesServiceClient - Authorization authorization.AuthorizationServiceClient - EntityResoution entityresolution.EntityResolutionServiceClient - KeyAccessServerRegistry kasregistry.KeyAccessServerRegistryServiceClient - Namespaces namespaces.NamespaceServiceClient - RegisteredResources registeredresources.RegisteredResourcesServiceClient - ResourceMapping resourcemapping.ResourceMappingServiceClient - SubjectMapping subjectmapping.SubjectMappingServiceClient - Unsafe unsafe.UnsafeServiceClient - KeyManagement keymanagement.KeyManagementServiceClient - wellknownConfiguration wellknownconfiguration.WellKnownServiceClient + Actions actionsconnect.ActionServiceClient + Attributes attributesconnect.AttributesServiceClient + Authorization authorizationconnect.AuthorizationServiceClient + EntityResoution entityresolutionconnect.EntityResolutionServiceClient + KeyAccessServerRegistry kasregistryconnect.KeyAccessServerRegistryServiceClient + Namespaces namespacesconnect.NamespaceServiceClient + RegisteredResources registeredresourcesconnect.RegisteredResourcesServiceClient + ResourceMapping resourcemappingconnect.ResourceMappingServiceClient + SubjectMapping subjectmappingconnect.SubjectMappingServiceClient + Unsafe unsafeconnect.UnsafeServiceClient + KeyManagement keymanagementconnect.KeyManagementServiceClient + wellknownConfiguration wellknownconfigurationconnect.WellKnownServiceClient } func New(platformEndpoint string, opts ...Option) (*SDK, error) { var ( - platformConn *grpc.ClientConn // Connection to the platform - ersConn *grpc.ClientConn // Connection to ERS (possibly remote) + platformConn *ConnectRpcConnection // Connection to the platform + ersConn *ConnectRpcConnection // Connection to ERS (possibly remote) err error ) // Set default options cfg := &config{ - dialOption: grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + httpClient: httputil.SafeHTTPClientWithTLSConfig(&tls.Config{ MinVersion: tls.VersionTLS12, - })), + }), } // Apply options @@ -114,25 +112,23 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { cfg.kasSessionKey = &key } - // once we change KAS to use standard DPoP we can put this all in the `build()` method - dialOptions := append([]grpc.DialOption{}, cfg.build()...) - // Add extra grpc dial options if provided. This is useful during tests. - if len(cfg.extraDialOptions) > 0 { - dialOptions = append(dialOptions, cfg.extraDialOptions...) - } - - unsanitizedPlatformEndpoint := platformEndpoint // IF IPC is disabled we build a validated healthy connection to the platform if !cfg.ipc { - platformEndpoint, err = SanitizePlatformEndpoint(platformEndpoint) - if err != nil { - return nil, fmt.Errorf("%w [%v]: %w", ErrPlatformEndpointMalformed, platformEndpoint, err) + if IsPlatformEndpointMalformed(platformEndpoint) { + return nil, fmt.Errorf("%w [%v]", ErrPlatformEndpointMalformed, platformEndpoint) } if cfg.shouldValidatePlatformConnectivity { - err = ValidateHealthyPlatformConnection(platformEndpoint, dialOptions) - if err != nil { - return nil, err + if cfg.coreConn != nil { + err = ValidateHealthyPlatformConnection(cfg.coreConn.Endpoint, cfg.coreConn.Client) + if err != nil { + return nil, err + } + } else { + err = ValidateHealthyPlatformConnection(platformEndpoint, cfg.httpClient) + if err != nil { + return nil, err + } } } } @@ -148,7 +144,7 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { return nil, errors.Join(ErrPlatformConfigFailed, err) } } else { - pcfg, err = fetchPlatformConfiguration(platformEndpoint, dialOptions) + pcfg, err = getPlatformConfiguration(&ConnectRpcConnection{Endpoint: platformEndpoint, Client: cfg.httpClient}) if err != nil { return nil, errors.Join(ErrPlatformConfigFailed, err) } @@ -162,13 +158,13 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { } } if cfg.PlatformConfiguration != nil { - cfg.PlatformConfiguration["platform_endpoint"] = unsanitizedPlatformEndpoint + cfg.PlatformConfiguration["platform_endpoint"] = platformEndpoint } - var uci []grpc.UnaryClientInterceptor + var uci []connect.Interceptor // Add request ID interceptor - uci = append(uci, audit.MetadataAddingClientInterceptor) + uci = append(uci, audit.MetadataAddingConnectInterceptor()) accessTokenSource, err := buildIDPTokenSource(cfg) if err != nil { @@ -176,19 +172,14 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { } if accessTokenSource != nil { interceptor := auth.NewTokenAddingInterceptorWithClient(accessTokenSource, cfg.httpClient) - uci = append(uci, interceptor.AddCredentials) + uci = append(uci, interceptor.AddCredentialsConnect()) } - dialOptions = append(dialOptions, grpc.WithChainUnaryInterceptor(uci...)) - // If coreConn is provided, use it as the platform connection if cfg.coreConn != nil { platformConn = cfg.coreConn } else { - platformConn, err = grpc.NewClient(platformEndpoint, dialOptions...) - if err != nil { - return nil, errors.Join(ErrGrpcDialFailed, err) - } + platformConn = &ConnectRpcConnection{Endpoint: platformEndpoint, Client: cfg.httpClient, Options: append(cfg.extraClientOptions, connect.WithInterceptors(uci...))} } if cfg.entityResolutionConn != nil { @@ -201,56 +192,29 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { config: *cfg, collectionStore: cfg.collectionStore, kasKeyCache: newKasKeyCache(), - conn: platformConn, - dialOptions: dialOptions, + conn: &ConnectRpcConnection{Client: platformConn.Client, Endpoint: platformConn.Endpoint, Options: platformConn.Options}, tokenSource: accessTokenSource, - Actions: actions.NewActionServiceClient(platformConn), - Attributes: attributes.NewAttributesServiceClient(platformConn), - Namespaces: namespaces.NewNamespaceServiceClient(platformConn), - RegisteredResources: registeredresources.NewRegisteredResourcesServiceClient(platformConn), - ResourceMapping: resourcemapping.NewResourceMappingServiceClient(platformConn), - SubjectMapping: subjectmapping.NewSubjectMappingServiceClient(platformConn), - Unsafe: unsafe.NewUnsafeServiceClient(platformConn), - KeyAccessServerRegistry: kasregistry.NewKeyAccessServerRegistryServiceClient(platformConn), - Authorization: authorization.NewAuthorizationServiceClient(platformConn), - EntityResoution: entityresolution.NewEntityResolutionServiceClient(ersConn), - KeyManagement: keymanagement.NewKeyManagementServiceClient(platformConn), - wellknownConfiguration: wellknownconfiguration.NewWellKnownServiceClient(platformConn), + Actions: actionsconnect.NewActionServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Attributes: attributesconnect.NewAttributesServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Namespaces: namespacesconnect.NewNamespaceServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + RegisteredResources: registeredresourcesconnect.NewRegisteredResourcesServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + ResourceMapping: resourcemappingconnect.NewResourceMappingServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + SubjectMapping: subjectmappingconnect.NewSubjectMappingServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Unsafe: unsafeconnect.NewUnsafeServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + KeyAccessServerRegistry: kasregistryconnect.NewKeyAccessServerRegistryServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Authorization: authorizationconnect.NewAuthorizationServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + EntityResoution: entityresolutionconnect.NewEntityResolutionServiceClient(ersConn.Client, ersConn.Endpoint, ersConn.Options...), + KeyManagement: keymanagementconnect.NewKeyManagementServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + wellknownConfiguration: wellknownconfigurationconnect.NewWellKnownServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), }, nil } -func SanitizePlatformEndpoint(e string) (string, error) { - // check if there's a scheme, if not, add https +func IsPlatformEndpointMalformed(e string) bool { u, err := url.ParseRequestURI(e) - if err != nil { - return "", errors.Join(fmt.Errorf("cannot parse platform endpoint [%s]", e), err) - } - if u.Host == "" { - // if the schema is missing add https. when the schema is missing the host is parsed as the scheme - newE := "https://" + e - u, err = url.ParseRequestURI(newE) - if err != nil { - return "", errors.Join(fmt.Errorf("cannot parse platform endpoint [%s]", newE), err) - } - if u.Host == "" { - return "", fmt.Errorf("invalid URL [%s], got empty hostname", newE) - } - } - - if strings.Contains(u.Hostname(), ":") { - return "", fmt.Errorf("invalid hostname [%s]. IPv6 addresses are not supported", u.Hostname()) - } - - p := u.Port() - if p == "" { - if u.Scheme == "http" { - p = "80" - } else { - p = "443" - } + if err != nil || u.Hostname() == "" || strings.Contains(u.Hostname(), ":") { + return true } - - return net.JoinHostPort(u.Hostname(), p), nil + return false } func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { @@ -303,23 +267,8 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { return ts, err } -// Close closes the underlying grpc.ClientConn. -func (s SDK) Close() error { - if s.collectionStore != nil { - s.collectionStore.close() - } - - if s.conn == nil { - return nil - } - if err := s.conn.Close(); err != nil { - return errors.Join(ErrShutdownFailed, err) - } - return nil -} - -// Conn returns the underlying grpc.ClientConn. -func (s SDK) Conn() *grpc.ClientConn { +// Conn returns the underlying http connection +func (s SDK) Conn() *ConnectRpcConnection { return s.conn } @@ -444,43 +393,34 @@ func IsValidNanoTdf(reader io.ReadSeeker) (bool, error) { return err == nil, err } -func fetchPlatformConfiguration(platformEndpoint string, dialOptions []grpc.DialOption) (PlatformConfiguration, error) { - conn, err := grpc.NewClient(platformEndpoint, dialOptions...) - if err != nil { - return nil, errors.Join(ErrGrpcDialFailed, err) - } - defer conn.Close() - - return getPlatformConfiguration(conn) -} - // Test connectability to the platform and validate a healthy status -func ValidateHealthyPlatformConnection(platformEndpoint string, dialOptions []grpc.DialOption) error { - conn, err := grpc.NewClient(platformEndpoint, dialOptions...) - if err != nil { - return errors.Join(ErrGrpcDialFailed, err) - } - defer conn.Close() +func ValidateHealthyPlatformConnection(platformEndpoint string, httpClient *http.Client) error { - req := healthpb.HealthCheckRequest{} - healthService := healthpb.NewHealthClient(conn) - resp, err := healthService.Check(context.Background(), &req) - if err != nil || resp.GetStatus() != healthpb.HealthCheckResponse_SERVING { + healthClient := connect.NewClient[healthpb.HealthCheckRequest, healthpb.HealthCheckResponse]( + httpClient, + platformEndpoint+"/grpc.health.v1.Health/Check", + ) + res, err := healthClient.CallUnary( + context.Background(), + connect.NewRequest(&healthpb.HealthCheckRequest{}), + ) + if err != nil || res.Msg.GetStatus() != healthpb.HealthCheckResponse_SERVING { return errors.Join(ErrPlatformUnreachable, err) } + return nil } -func getPlatformConfiguration(conn *grpc.ClientConn) (PlatformConfiguration, error) { +func getPlatformConfiguration(conn *ConnectRpcConnection) (PlatformConfiguration, error) { req := wellknownconfiguration.GetWellKnownConfigurationRequest{} - wellKnownConfig := wellknownconfiguration.NewWellKnownServiceClient(conn) + wellKnownConfig := wellknownconfigurationconnect.NewWellKnownServiceClient(conn.Client, conn.Endpoint) - response, err := wellKnownConfig.GetWellKnownConfiguration(context.Background(), &req) + response, err := wellKnownConfig.GetWellKnownConfiguration(context.Background(), connect.NewRequest(&req)) if err != nil { return nil, errors.Join(errors.New("unable to retrieve config information, and none was provided"), err) } // Get token endpoint - configuration := response.GetConfiguration() + configuration := response.Msg.GetConfiguration() return configuration.AsMap(), nil } diff --git a/sdk/sdk_test.go b/sdk/sdk_test.go index e3b67176c7..e92c3ede8d 100644 --- a/sdk/sdk_test.go +++ b/sdk/sdk_test.go @@ -6,18 +6,18 @@ import ( "reflect" "testing" - "github.com/opentdf/platform/protocol/go/policy/attributes" - "github.com/opentdf/platform/protocol/go/policy/kasregistry" - "github.com/opentdf/platform/protocol/go/policy/resourcemapping" - "github.com/opentdf/platform/protocol/go/policy/subjectmapping" + "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" + "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" + "github.com/opentdf/platform/protocol/go/policy/resourcemapping/resourcemappingconnect" + "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" "github.com/opentdf/platform/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) var ( - goodPlatformEndpoint = "localhost:8080" - badPlatformEndpoint = "localhost:9999" + goodPlatformEndpoint = "http://localhost:8080" + badPlatformEndpoint = "http://localhost:9999" ) func GetMethods(i interface{}) []string { @@ -122,18 +122,6 @@ func Test_ShouldCreateNewSDK_NoCredentials(t *testing.T) { assert.NotNil(t, s) } -func TestNew_ShouldCloseConnections(t *testing.T) { - s, err := sdk.New(goodPlatformEndpoint, - sdk.WithPlatformConfiguration(sdk.PlatformConfiguration{ - "platform_issuer": "https://example.org", - }), - sdk.WithClientCredentials("myid", "mysecret", nil), - sdk.WithTokenEndpoint("https://example.org/token"), - ) - require.NoError(t, err) - require.NoError(t, s.Close()) -} - func TestNew_ShouldValidateGoodNanoTdf(t *testing.T) { goodNanoTdfStr := "TDFMABJsb2NhbGhvc3Q6ODA4MC9rYXOAAQIA2qvjMRfg7b27lT2kf9SwHRkDIg8ZXtfRoiIvdMUHq/gL5AUMfmv4Di8sKCyLkmUm/WITVj5hDeV/z4JmQ0JL7ZxqSmgZoK6TAHvkKhUly4zMEWMRXH8IktKhFKy1+fD+3qwDopqWAO5Nm2nYQqi75atEFckstulpNKg3N+Ul22OHr/ZuR127oPObBDYNRfktBdzoZbEQcPlr8q1B57q6y5SPZFjEzL9weK+uS5bUJWkF3nsHASo2bZw7IPhTZxoFVmCDjwvj6MbxNa7zG6aClHJ162zKxLLnD9TtIHuZ59R7LgiSieipXeExj+ky9OgIw5DfwyUuxsQLtKpMIAFPmLY9Hy2naUJxke0MT1EUBgastCq+YtFGslV9LJo/A8FtrRqludwtM0O+Z9FlAkZ1oNL7M7uOkLrh7eRrv+C1AAAX6FaBQoOtqnmyu6Jp+VzkxDddEeLRUyI=" goodDecodedData, err := base64.StdEncoding.DecodeString(goodNanoTdfStr) @@ -228,22 +216,22 @@ func TestNew_ShouldHaveSameMethods(t *testing.T) { }{ { name: "Attributes", - expected: GetMethods(reflect.TypeOf(attributes.NewAttributesServiceClient(s.Conn()))), + expected: GetMethods(reflect.TypeOf(attributesconnect.NewAttributesServiceClient(s.Conn().Client, s.Conn().Endpoint))), actual: GetMethods(reflect.TypeOf(s.Attributes)), }, { name: "ResourceEncoding", - expected: GetMethods(reflect.TypeOf(resourcemapping.NewResourceMappingServiceClient(s.Conn()))), + expected: GetMethods(reflect.TypeOf(resourcemappingconnect.NewResourceMappingServiceClient(s.Conn().Client, s.Conn().Endpoint))), actual: GetMethods(reflect.TypeOf(s.ResourceMapping)), }, { name: "SubjectEncoding", - expected: GetMethods(reflect.TypeOf(subjectmapping.NewSubjectMappingServiceClient(s.Conn()))), + expected: GetMethods(reflect.TypeOf(subjectmappingconnect.NewSubjectMappingServiceClient(s.Conn().Client, s.Conn().Endpoint))), actual: GetMethods(reflect.TypeOf(s.SubjectMapping)), }, { name: "KeyAccessGrants", - expected: GetMethods(reflect.TypeOf(kasregistry.NewKeyAccessServerRegistryServiceClient(s.Conn()))), + expected: GetMethods(reflect.TypeOf(kasregistryconnect.NewKeyAccessServerRegistryServiceClient(s.Conn().Client, s.Conn().Endpoint))), actual: GetMethods(reflect.TypeOf(s.KeyAccessServerRegistry)), }, } @@ -273,97 +261,70 @@ func Test_New_ShouldFailWithDisconnectedPlatform(t *testing.T) { assert.Nil(t, s) } -func Test_ShouldSanitizePlatformEndpoint(t *testing.T) { +func TestIsPlatformEndpointMalformed(t *testing.T) { tests := []struct { - name string - endpoint string - expected string + name string + input string + expected bool + description string }{ { - name: "No scheme", - endpoint: "localhost:8080", - expected: "localhost:8080", - }, - { - name: "HTTP scheme with port", - endpoint: "http://localhost:8080", - expected: "localhost:8080", - }, - { - name: "HTTPS scheme with port", - endpoint: "https://localhost:8080", - expected: "localhost:8080", - }, - { - name: "HTTP scheme no port", - endpoint: "http://localhost", - expected: "localhost:80", - }, - { - name: "HTTPS scheme no port", - endpoint: "https://localhost", - expected: "localhost:443", - }, - { - name: "HTTPS scheme port (IP)", - endpoint: "https://192.168.1.1:8080", - expected: "192.168.1.1:8080", + name: "Valid URL with scheme and host", + input: "https://example.com", + expected: false, + description: "A valid URL with scheme and host should not be considered malformed.", }, { - name: "HTTPS scheme no port (IP)", - endpoint: "https://192.168.1.1", - expected: "192.168.1.1:443", + name: "Valid URL with scheme, host, and port", + input: "https://example.com:8080", + expected: false, + description: "A valid URL with scheme, host, and port should not be considered malformed.", }, { - name: "Malformed url", - endpoint: "http://localhost:8080:8080", - expected: "", + name: "Valid URL with path", + input: "https://example.com/path", + expected: false, + description: "A valid URL with a path should not be considered malformed.", }, { - name: "Malformed url", - endpoint: "http://localhost:8080:", - expected: "", + name: "Invalid URL with missing host", + input: "https://:8080", + expected: true, + description: "A URL with a missing host should be considered malformed.", }, { - name: "Malformed url", - endpoint: "http//localhost:8080:", - expected: "", + name: "Invalid URL with missing scheme", + input: "example.com", + expected: true, + description: "A URL without a scheme should be considered malformed.", }, { - name: "Malformed url", - endpoint: "//localhost", - expected: "", + name: "Invalid URL with invalid characters", + input: "https://exa mple.com", + expected: true, + description: "A URL with invalid characters should be considered malformed.", }, { - name: "Malformed url", - endpoint: "://localhost", - expected: "", + name: "Invalid URL with colon in hostname", + input: "https://example:com", + expected: true, + description: "A URL with a colon in the hostname should be considered malformed.", }, { - name: "Malformed url", - endpoint: "http/localhost", - expected: "", - }, - { - name: "Malformed url", - endpoint: "http:localhost", - expected: "", + name: "Empty input", + input: "", + expected: true, + description: "An empty input should be considered malformed.", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - actual, err := sdk.SanitizePlatformEndpoint(tt.endpoint) - if tt.expected == "" { - require.Error(t, err) - } else { - require.NoError(t, err) - assert.Equal(t, tt.expected, actual) - } + result := sdk.IsPlatformEndpointMalformed(tt.input) + assert.Equal(t, tt.expected, result, tt.description) }) } } - func Test_GetType_NanoTDF(t *testing.T) { nano := "TDFMABJsb2NhbGhvc3Q6ODA4MC9rYXOAAQIA2qvjMRfg7b27lT2kf9SwHRkDIg8ZXtfRoiIvdMUHq/gL5AUMfmv4Di8sKCyLkmUm/WITVj5hDeV/z4JmQ0JL7ZxqSmgZoK6TAHvkKhUly4zMEWMRXH8IktKhFKy1+fD+3qwDopqWAO5Nm2nYQqi75atEFckstulpNKg3N+Ul22OHr/ZuR127oPObBDYNRfktBdzoZbEQcPlr8q1B57q6y5SPZFjEzL9weK+uS5bUJWkF3nsHASo2bZw7IPhTZxoFVmCDjwvj6MbxNa7zG6aClHJ162zKxLLnD9TtIHuZ59R7LgiSieipXeExj+ky9OgIw5DfwyUuxsQLtKpMIAFPmLY9Hy2naUJxke0MT1EUBgastCq+YtFGslV9LJo/A8FtrRqludwtM0O+Z9FlAkZ1oNL7M7uOkLrh7eRrv+C1AAAX6FaBQoOtqnmyu6Jp+VzkxDddEeLRUyI=" nanoDecoded, err := base64.StdEncoding.DecodeString(nano) diff --git a/sdk/tdf.go b/sdk/tdf.go index 8aa8e8cbe0..049611b0df 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -11,18 +11,20 @@ import ( "io" "log/slog" "math" + "net/http" "strconv" "strings" + "connectrpc.com/connect" "github.com/Masterminds/semver/v3" "github.com/opentdf/platform/protocol/go/kas" "github.com/opentdf/platform/protocol/go/policy/kasregistry" + "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" "github.com/google/uuid" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/sdk/internal/archive" - "google.golang.org/grpc" "google.golang.org/grpc/codes" ) @@ -63,7 +65,8 @@ const ( // Loads and reads ZTDF files type Reader struct { tokenSource auth.AccessTokenSource - dialOptions []grpc.DialOption + httpClient *http.Client + connectOptions []connect.ClientOption manifest Manifest unencryptedMetadata []byte tdfReader archive.TDFReader @@ -654,13 +657,13 @@ func createPolicyObject(attributes []AttributeValueFQN) (PolicyObject, error) { return policyObj, nil } -func allowListFromKASRegistry(ctx context.Context, kasRegistryClient kasregistry.KeyAccessServerRegistryServiceClient, platformURL string) (AllowList, error) { - kases, err := kasRegistryClient.ListKeyAccessServers(ctx, &kasregistry.ListKeyAccessServersRequest{}) +func allowListFromKASRegistry(ctx context.Context, kasRegistryClient kasregistryconnect.KeyAccessServerRegistryServiceClient, platformURL string) (AllowList, error) { + kases, err := kasRegistryClient.ListKeyAccessServers(ctx, connect.NewRequest(&kasregistry.ListKeyAccessServersRequest{})) if err != nil { return nil, fmt.Errorf("kasregistry.ListKeyAccessServers failed: %w", err) } kasAllowlist := AllowList{} - for _, kas := range kases.GetKeyAccessServers() { + for _, kas := range kases.Msg.GetKeyAccessServers() { err = kasAllowlist.Add(kas.GetUri()) if err != nil { return nil, fmt.Errorf("kasAllowlist.Add failed: %w", err) @@ -732,12 +735,13 @@ func (s SDK) LoadTDF(reader io.ReadSeeker, opts ...TDFReaderOption) (*Reader, er } return &Reader{ - tokenSource: s.tokenSource, - dialOptions: s.dialOptions, - tdfReader: tdfReader, - manifest: *manifestObj, - kasSessionKey: config.kasSessionKey, - config: *config, + tokenSource: s.tokenSource, + httpClient: s.conn.Client, + connectOptions: s.conn.Options, + tdfReader: tdfReader, + manifest: *manifestObj, + kasSessionKey: config.kasSessionKey, + config: *config, }, nil } @@ -1248,7 +1252,7 @@ func (r *Reader) buildKey(_ context.Context, results []kaoResult) error { // Unwraps the payload key, if possible, using the access service func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocognit // Better readability keeping it as is - kasClient := newKASClient(r.dialOptions, r.tokenSource, r.kasSessionKey) + kasClient := newKASClient(r.httpClient, r.connectOptions, r.tokenSource, r.kasSessionKey) var kaoResults []kaoResult reqFail := func(err error, req *kas.UnsignedRewrapRequest_WithPolicyRequest) { diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index 8b3eb0699f..8b10e7b544 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -15,8 +15,8 @@ import ( "fmt" "io" "log/slog" - "net" - "net/url" + "net/http" + "net/http/httptest" "os" "path/filepath" "strconv" @@ -24,20 +24,23 @@ import ( "testing" "time" + "connectrpc.com/connect" "google.golang.org/protobuf/encoding/protojson" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/lib/ocrypto" kaspb "github.com/opentdf/platform/protocol/go/kas" + "github.com/opentdf/platform/protocol/go/kas/kasconnect" "github.com/opentdf/platform/protocol/go/policy" attributespb "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" "github.com/opentdf/platform/protocol/go/policy/kasregistry" + "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" wellknownpb "github.com/opentdf/platform/protocol/go/wellknownconfiguration" - "google.golang.org/grpc" + wellknownconnect "github.com/opentdf/platform/protocol/go/wellknownconfiguration/wellknownconfigurationconnect" "google.golang.org/grpc/codes" "google.golang.org/grpc/resolver" "google.golang.org/grpc/status" - "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/types/known/structpb" "github.com/stretchr/testify/assert" @@ -283,8 +286,9 @@ type keyInfo struct { type TDFSuite struct { suite.Suite - sdk *SDK - kases []FakeKas + sdk *SDK + kases []FakeKas + kasTestUrlLookup map[string]string } func (s *TDFSuite) SetupSuite() { @@ -323,31 +327,35 @@ func (s *TDFSuite) Test_SimpleTDF() { { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestUrlLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), WithDataAttributes(attributes...), }, - tdfReadOptions: []TDFReaderOption{}, + tdfReadOptions: []TDFReaderOption{ + WithKasAllowlist([]string{s.kasTestUrlLookup["https://a.kas/"]}), + }, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestUrlLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), WithDataAttributes(attributes...), WithTargetMode("0.0.0"), }, - tdfReadOptions: []TDFReaderOption{}, - useHex: true, + tdfReadOptions: []TDFReaderOption{ + WithKasAllowlist([]string{s.kasTestUrlLookup["https://a.kas/"]}), + }, + useHex: true, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://d.kas/", + URL: s.kasTestUrlLookup["https://d.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -356,12 +364,13 @@ func (s *TDFSuite) Test_SimpleTDF() { }, tdfReadOptions: []TDFReaderOption{ WithSessionKeyType(ocrypto.EC256Key), + WithKasAllowlist([]string{s.kasTestUrlLookup["https://d.kas/"]}), }, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://d.kas/", + URL: s.kasTestUrlLookup["https://d.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -371,6 +380,7 @@ func (s *TDFSuite) Test_SimpleTDF() { }, tdfReadOptions: []TDFReaderOption{ WithSessionKeyType(ocrypto.EC256Key), + WithKasAllowlist([]string{s.kasTestUrlLookup["https://d.kas/"]}), }, useHex: true, }, @@ -490,20 +500,20 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestUrlLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), WithDataAttributes(attributes...), }, tdfReadOptions: []TDFReaderOption{ - WithKasAllowlist([]string{"https://a.kas/"}), + WithKasAllowlist([]string{s.kasTestUrlLookup["https://a.kas/"]}), }, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestUrlLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -512,12 +522,12 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { tdfReadOptions: []TDFReaderOption{ WithKasAllowlist([]string{"https://nope-not-a-kas.com/kas"}), }, - expectedError: "KasAllowlist: kas url https://a.kas/ is not allowed", + expectedError: "KasAllowlist: kas url " + s.kasTestUrlLookup["https://a.kas/"] + " is not allowed", }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestUrlLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -526,12 +536,12 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { tdfReadOptions: []TDFReaderOption{ withKasAllowlist(AllowList{"nope-not-a-kas.com": true}), }, - expectedError: "KasAllowlist: kas url https://a.kas/ is not allowed", + expectedError: "KasAllowlist: kas url " + s.kasTestUrlLookup["https://a.kas/"] + " is not allowed", }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestUrlLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -545,7 +555,7 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: "https://a.kas/", + URL: s.kasTestUrlLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -821,7 +831,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() { { kasURLs := []KASInfo{ { - URL: "https://a.kas/", + URL: s.kasTestUrlLookup["https://a.kas/"], PublicKey: "", }, } @@ -865,11 +875,12 @@ func (s *TDFSuite) Test_TDFWithAssertion() { var r *Reader if test.verifiers == nil { - r, err = s.sdk.LoadTDF(readSeeker, WithDisableAssertionVerification(test.disableAssertionVerification)) + r, err = s.sdk.LoadTDF(readSeeker, WithDisableAssertionVerification(test.disableAssertionVerification), WithKasAllowlist([]string{s.kasTestUrlLookup["https://a.kas/"]})) } else { r, err = s.sdk.LoadTDF(readSeeker, WithAssertionVerificationKeys(*test.verifiers), - WithDisableAssertionVerification(test.disableAssertionVerification)) + WithDisableAssertionVerification(test.disableAssertionVerification), + WithKasAllowlist([]string{s.kasTestUrlLookup["https://a.kas/"]})) } s.Require().NoError(err) @@ -1139,7 +1150,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() { { kasURLs := []KASInfo{ { - URL: "https://a.kas/", + URL: s.kasTestUrlLookup["https://a.kas/"], PublicKey: "", }, } @@ -1198,11 +1209,11 @@ func (s *TDFSuite) Test_TDFReader() { //nolint:gocognit // requires for testing payload: payload, // len: 62 kasInfoList: []KASInfo{ { - URL: "http://localhost:65432/api/kas", + URL: s.kasTestUrlLookup["http://localhost:65432/"], PublicKey: mockRSAPublicKey1, }, { - URL: "http://localhost:65432/api/kas", + URL: s.kasTestUrlLookup["http://localhost:65432/"], PublicKey: mockRSAPublicKey1, }, }, @@ -1298,11 +1309,11 @@ func (s *TDFSuite) Test_TDFReader() { //nolint:gocognit // requires for testing func (s *TDFSuite) Test_TDFReaderFail() { kasInfoList := []KASInfo{ { - URL: "http://localhost:65432/api/kas", + URL: s.kasTestUrlLookup["http://localhost:65432/api/kas"], PublicKey: mockRSAPublicKey1, }, { - URL: "http://localhost:65432/api/kas", + URL: s.kasTestUrlLookup["http://localhost:65432/api/kas"], PublicKey: mockRSAPublicKey1, }, } @@ -1631,9 +1642,9 @@ func (s *TDFSuite) Test_KeySplits() { tdfFileSize: 2759, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ - {KAS: "https://a.kas/", SplitID: "a"}, - {KAS: "https://b.kas/", SplitID: "a"}, - {KAS: `https://c.kas/`, SplitID: "a"}, + {KAS: s.kasTestUrlLookup["https://a.kas/"], SplitID: "a"}, + {KAS: s.kasTestUrlLookup["https://b.kas/"], SplitID: "a"}, + {KAS: s.kasTestUrlLookup[`https://c.kas/`], SplitID: "a"}, }, }, { @@ -1642,9 +1653,9 @@ func (s *TDFSuite) Test_KeySplits() { tdfFileSize: 2759, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ - {KAS: "https://a.kas/", SplitID: "a"}, - {KAS: "https://b.kas/", SplitID: "b"}, - {KAS: "https://c.kas/", SplitID: "c"}, + {KAS: s.kasTestUrlLookup["https://a.kas/"], SplitID: "a"}, + {KAS: s.kasTestUrlLookup["https://b.kas/"], SplitID: "b"}, + {KAS: s.kasTestUrlLookup["https://c.kas/"], SplitID: "c"}, }, }, { @@ -1653,10 +1664,10 @@ func (s *TDFSuite) Test_KeySplits() { tdfFileSize: 3351, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ - {KAS: "https://a.kas/", SplitID: "a"}, - {KAS: "https://b.kas/", SplitID: "a"}, - {KAS: "https://b.kas/", SplitID: "b"}, - {KAS: "https://c.kas/", SplitID: "b"}, + {KAS: s.kasTestUrlLookup["https://a.kas/"], SplitID: "a"}, + {KAS: s.kasTestUrlLookup["https://b.kas/"], SplitID: "a"}, + {KAS: s.kasTestUrlLookup["https://b.kas/"], SplitID: "b"}, + {KAS: s.kasTestUrlLookup["https://c.kas/"], SplitID: "b"}, }, }, } { @@ -1880,7 +1891,7 @@ func (s *TDFSuite) startBackend() { } fwk := &FakeWellKnown{v: wellknownCfg} - fa := &FakeAttributes{} + fa := &FakeAttributes{s: s} kasesToMake := []struct { url, private, public string }{ @@ -1897,137 +1908,130 @@ func (s *TDFSuite) startBackend() { {kasNz, mockRSAPrivateKey3, mockRSAPublicKey3}, {kasUs, mockRSAPrivateKey1, mockRSAPublicKey1}, } - fkar := &FakeKASRegistry{kases: kasesToMake} - - listeners := make(map[string]*bufconn.Listener) - dialer := func(ctx context.Context, host string) (net.Conn, error) { - l, ok := listeners[host] - if !ok { - slog.ErrorContext(ctx, "bufconn: unable to dial host!", "ctx", ctx, "host", host) - return nil, fmt.Errorf("unknown host [%s]", host) - } - slog.InfoContext(ctx, "bufconn: dialing (local grpc)", "ctx", ctx, "host", host) - return l.Dial() - } + fkar := &FakeKASRegistry{kases: kasesToMake, s: s} s.kases = make([]FakeKas, 12) + s.kasTestUrlLookup = make(map[string]string, 12) + + var sdkPlatformUrl string + for i, ki := range kasesToMake { - grpcListener := bufconn.Listen(1024 * 1024) - url, err := url.Parse(ki.url) - s.Require().NoError(err) - var origin string - switch { - case url.Port() == "80": - origin = url.Host - case url.Scheme == "https": - origin = url.Host + ":443" - case url.Port() != "": - origin = url.Host - default: - origin = url.Hostname() - } - listeners[origin] = grpcListener - grpcServer := grpc.NewServer() + mux := http.NewServeMux() + s.kases[i] = FakeKas{ s: s, privateKey: ki.private, KASInfo: KASInfo{ URL: ki.url, PublicKey: ki.public, KID: "r1", Algorithm: "rsa:2048", }, legakeys: map[string]keyInfo{}, } - attributespb.RegisterAttributesServiceServer(grpcServer, fa) - kaspb.RegisterAccessServiceServer(grpcServer, &s.kases[i]) - wellknownpb.RegisterWellKnownServiceServer(grpcServer, fwk) - kasregistry.RegisterKeyAccessServerRegistryServiceServer(grpcServer, fkar) - - go func() { - err := grpcServer.Serve(grpcListener) - s.NoError(err) - }() + path, handler := attributesconnect.NewAttributesServiceHandler(fa) + mux.Handle(path, handler) + kasPath, kasHandler := kasconnect.NewAccessServiceHandler(&s.kases[i]) + mux.Handle(kasPath, kasHandler) + path, handler = wellknownconnect.NewWellKnownServiceHandler(fwk) + mux.Handle(path, handler) + path, handler = kasregistryconnect.NewKeyAccessServerRegistryServiceHandler(fkar) + mux.Handle(path, handler) + + server := httptest.NewServer(mux) + + // add to lookup reg + s.kasTestUrlLookup[s.kases[i].KASInfo.URL] = server.URL + // replace kasinfo url with httptest server url + s.kases[i].KASInfo.URL = server.URL + + if i == 0 { + sdkPlatformUrl = server.URL + } } ats := getTokenSource(s.T()) - sdk, err := New("localhost:65432", + sdk, err := New(sdkPlatformUrl, WithClientCredentials("test", "test", nil), withCustomAccessTokenSource(&ats), WithTokenEndpoint("http://localhost:65432/auth/token"), - WithInsecurePlaintextConn(), - WithExtraDialOptions(grpc.WithContextDialer(dialer))) + WithInsecurePlaintextConn()) s.Require().NoError(err) s.sdk = sdk } type FakeWellKnown struct { - wellknownpb.UnimplementedWellKnownServiceServer + wellknownconnect.UnimplementedWellKnownServiceHandler v map[string]interface{} } -func (f *FakeWellKnown) GetWellKnownConfiguration(_ context.Context, _ *wellknownpb.GetWellKnownConfigurationRequest) (*wellknownpb.GetWellKnownConfigurationResponse, error) { +func (f *FakeWellKnown) GetWellKnownConfiguration(_ context.Context, _ *connect.Request[wellknownpb.GetWellKnownConfigurationRequest]) (*connect.Response[wellknownpb.GetWellKnownConfigurationResponse], error) { cfg, err := structpb.NewStruct(f.v) if err != nil { return nil, err } - return &wellknownpb.GetWellKnownConfigurationResponse{ + return connect.NewResponse(&wellknownpb.GetWellKnownConfigurationResponse{ Configuration: cfg, - }, nil + }), nil } type FakeAttributes struct { - attributespb.UnimplementedAttributesServiceServer + attributesconnect.UnimplementedAttributesServiceHandler + s *TDFSuite } -func (f *FakeAttributes) GetAttributeValuesByFqns(_ context.Context, in *attributespb.GetAttributeValuesByFqnsRequest) (*attributespb.GetAttributeValuesByFqnsResponse, error) { +func (f *FakeAttributes) GetAttributeValuesByFqns(_ context.Context, in *connect.Request[attributespb.GetAttributeValuesByFqnsRequest]) (*connect.Response[attributespb.GetAttributeValuesByFqnsResponse], error) { r := make(map[string]*attributespb.GetAttributeValuesByFqnsResponse_AttributeAndValue) - for _, fqn := range in.GetFqns() { + for _, fqn := range in.Msg.GetFqns() { av, err := NewAttributeValueFQN(fqn) if err != nil { slog.Error("invalid fqn", "notfqn", fqn, "error", err) return nil, status.New(codes.InvalidArgument, fmt.Sprintf("invalid attribute fqn [%s]", fqn)).Err() } v := mockValueFor(av) + for i := range v.GetGrants() { + v.Grants[i].Uri = f.s.kasTestUrlLookup[v.Grants[i].Uri] + } r[fqn] = &attributespb.GetAttributeValuesByFqnsResponse_AttributeAndValue{ Attribute: v.GetAttribute(), Value: v, } } - return &attributespb.GetAttributeValuesByFqnsResponse{FqnAttributeValues: r}, nil + return connect.NewResponse(&attributespb.GetAttributeValuesByFqnsResponse{FqnAttributeValues: r}), nil } type FakeKASRegistry struct { - kasregistry.UnimplementedKeyAccessServerRegistryServiceServer + kasregistryconnect.UnimplementedKeyAccessServerRegistryServiceHandler + s *TDFSuite kases []struct { url, private, public string } } -func (f *FakeKASRegistry) ListKeyAccessServers(_ context.Context, _ *kasregistry.ListKeyAccessServersRequest) (*kasregistry.ListKeyAccessServersResponse, error) { +func (f *FakeKASRegistry) ListKeyAccessServers(_ context.Context, _ *connect.Request[kasregistry.ListKeyAccessServersRequest]) (*connect.Response[kasregistry.ListKeyAccessServersResponse], error) { resp := &kasregistry.ListKeyAccessServersResponse{ KeyAccessServers: make([]*policy.KeyAccessServer, 0, len(f.kases)), } for _, k := range f.kases { kas := &policy.KeyAccessServer{ - Uri: k.url, + Uri: f.s.kasTestUrlLookup[k.url], } resp.KeyAccessServers = append(resp.KeyAccessServers, kas) } - return resp, nil + return connect.NewResponse(resp), nil } type FakeKas struct { - kaspb.UnimplementedAccessServiceServer + kasconnect.UnimplementedAccessServiceHandler KASInfo privateKey string s *TDFSuite legakeys map[string]keyInfo } -func (f *FakeKas) Rewrap(_ context.Context, in *kaspb.RewrapRequest) (*kaspb.RewrapResponse, error) { - signedRequestToken := in.GetSignedRequestToken() +func (f *FakeKas) Rewrap(_ context.Context, in *connect.Request[kaspb.RewrapRequest]) (*connect.Response[kaspb.RewrapResponse], error) { + signedRequestToken := in.Msg.GetSignedRequestToken() token, err := jwt.ParseInsecure([]byte(signedRequestToken)) if err != nil { @@ -2045,11 +2049,11 @@ func (f *FakeKas) Rewrap(_ context.Context, in *kaspb.RewrapRequest) (*kaspb.Rew } result := f.getRewrapResponse(requestBodyStr) - return result, nil + return connect.NewResponse(result), nil } -func (f *FakeKas) PublicKey(_ context.Context, _ *kaspb.PublicKeyRequest) (*kaspb.PublicKeyResponse, error) { - return &kaspb.PublicKeyResponse{PublicKey: f.KASInfo.PublicKey, Kid: f.KID}, nil +func (f *FakeKas) PublicKey(_ context.Context, _ *connect.Request[kaspb.PublicKeyRequest]) (*connect.Response[kaspb.PublicKeyResponse], error) { + return connect.NewResponse(&kaspb.PublicKeyResponse{PublicKey: f.KASInfo.PublicKey, Kid: f.KID}), nil } func (f *FakeKas) getRewrapResponse(rewrapRequest string) *kaspb.RewrapResponse { diff --git a/service/authorization/authorization.go b/service/authorization/authorization.go index b88ce67cdf..e888be4ea4 100644 --- a/service/authorization/authorization.go +++ b/service/authorization/authorization.go @@ -197,7 +197,7 @@ func (as *AuthorizationService) GetDecisionsByToken(ctx context.Context, req *co // for each token decision request for _, tdr := range req.Msg.GetDecisionRequests() { - ecResp, err := as.sdk.EntityResoution.CreateEntityChainFromJwt(ctx, &entityresolution.CreateEntityChainFromJwtRequest{Tokens: tdr.GetTokens()}) + ecResp, err := as.sdk.EntityResoution.CreateEntityChainFromJwt(ctx, connect.NewRequest(&entityresolution.CreateEntityChainFromJwtRequest{Tokens: tdr.GetTokens()})) if err != nil { as.logger.Error("Error calling ERS to get entity chains from jwts") return nil, err @@ -206,7 +206,7 @@ func (as *AuthorizationService) GetDecisionsByToken(ctx context.Context, req *co // form a decision request for the token decision request decisionsRequests = append(decisionsRequests, &authorization.DecisionRequest{ Actions: tdr.GetActions(), - EntityChains: ecResp.GetEntityChains(), + EntityChains: ecResp.Msg.GetEntityChains(), ResourceAttributes: tdr.GetResourceAttributes(), }) } @@ -549,19 +549,19 @@ func (as *AuthorizationService) GetEntitlements(ctx context.Context, req *connec // If quantity of attributes exceeds maximum list pagination, all are needed to determine entitlements for { - listed, err := as.sdk.Attributes.ListAttributes(ctx, &attr.ListAttributesRequest{ + listed, err := as.sdk.Attributes.ListAttributes(ctx, connect.NewRequest(&attr.ListAttributesRequest{ State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, Pagination: &policy.PageRequest{ Offset: nextOffset, }, - }) + })) if err != nil { as.logger.ErrorContext(ctx, "failed to list attributes", slog.String("error", err.Error())) return nil, connect.NewError(connect.CodeInternal, errors.New("failed to list attributes")) } - nextOffset = listed.GetPagination().GetNextOffset() - attrsList = append(attrsList, listed.GetAttributes()...) + nextOffset = listed.Msg.GetPagination().GetNextOffset() + attrsList = append(attrsList, listed.Msg.GetAttributes()...) // offset becomes zero when list is exhausted if nextOffset <= 0 { @@ -572,18 +572,18 @@ func (as *AuthorizationService) GetEntitlements(ctx context.Context, req *connec // If quantity of subject mappings exceeds maximum list pagination, all are needed to determine entitlements nextOffset = 0 for { - listed, err := as.sdk.SubjectMapping.ListSubjectMappings(ctx, &subjectmapping.ListSubjectMappingsRequest{ + listed, err := as.sdk.SubjectMapping.ListSubjectMappings(ctx, connect.NewRequest(&subjectmapping.ListSubjectMappingsRequest{ Pagination: &policy.PageRequest{ Offset: nextOffset, }, - }) + })) if err != nil { as.logger.ErrorContext(ctx, "failed to list subject mappings", slog.String("error", err.Error())) return nil, connect.NewError(connect.CodeInternal, errors.New("failed to list subject mappings")) } - nextOffset = listed.GetPagination().GetNextOffset() - subjectMappingsList = append(subjectMappingsList, listed.GetSubjectMappings()...) + nextOffset = listed.Msg.GetPagination().GetNextOffset() + subjectMappingsList = append(subjectMappingsList, listed.Msg.GetSubjectMappings()...) // offset becomes zero when list is exhausted if nextOffset <= 0 { @@ -611,14 +611,14 @@ func (as *AuthorizationService) GetEntitlements(ctx context.Context, req *connec } // call ERS on all entities - ersResp, err := as.sdk.EntityResoution.ResolveEntities(ctx, &entityresolution.ResolveEntitiesRequest{Entities: req.Msg.GetEntities()}) + ersResp, err := as.sdk.EntityResoution.ResolveEntities(ctx, connect.NewRequest(&entityresolution.ResolveEntitiesRequest{Entities: req.Msg.GetEntities()})) if err != nil { as.logger.ErrorContext(ctx, "error calling ERS to resolve entities", "entities", req.Msg.GetEntities()) return nil, err } // call rego on all entities - in, err := entitlements.OpaInput(subjectMappings, ersResp) + in, err := entitlements.OpaInput(subjectMappings, ersResp.Msg) if err != nil { as.logger.ErrorContext(ctx, "failed to build rego input", slog.String("error", err.Error())) return nil, connect.NewError(connect.CodeInternal, errors.New("failed to build rego input")) @@ -725,16 +725,16 @@ func retrieveAttributeDefinitions(ctx context.Context, attrFqns []string, sdk *o return make(map[string]*attr.GetAttributeValuesByFqnsResponse_AttributeAndValue), nil } - resp, err := sdk.Attributes.GetAttributeValuesByFqns(ctx, &attr.GetAttributeValuesByFqnsRequest{ + resp, err := sdk.Attributes.GetAttributeValuesByFqns(ctx, connect.NewRequest(&attr.GetAttributeValuesByFqnsRequest{ WithValue: &policy.AttributeValueSelector{ WithSubjectMaps: false, }, Fqns: attrFqns, - }) + })) if err != nil { return nil, err } - return resp.GetFqnAttributeValues(), nil + return resp.Msg.GetFqnAttributeValues(), nil } func getComprehensiveHierarchy(attributesMap map[string]*policy.Attribute, avf *attr.GetAttributeValuesByFqnsResponse, entitlement string, as *AuthorizationService, entitlements []string) []string { diff --git a/service/authorization/authorization_test.go b/service/authorization/authorization_test.go index 5b69bdc72d..737fca31ac 100644 --- a/service/authorization/authorization_test.go +++ b/service/authorization/authorization_test.go @@ -16,13 +16,14 @@ import ( "github.com/opentdf/platform/protocol/go/entityresolution" "github.com/opentdf/platform/protocol/go/policy" attr "github.com/opentdf/platform/protocol/go/policy/attributes" + attrconnect "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" sm "github.com/opentdf/platform/protocol/go/policy/subjectmapping" + smconnect "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" otdf "github.com/opentdf/platform/sdk" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/pkg/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/structpb" @@ -47,15 +48,15 @@ var ( ) type myAttributesClient struct { - attr.AttributesServiceClient + attrconnect.AttributesServiceClient } -func (*myAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest, _ ...grpc.CallOption) (*attr.ListAttributesResponse, error) { - return &listAttributeResp, errListAttributes +func (*myAttributesClient) ListAttributes(_ context.Context, _ *connect.Request[attr.ListAttributesRequest]) (*connect.Response[attr.ListAttributesResponse], error) { + return connect.NewResponse(&listAttributeResp), errListAttributes } -func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.GetAttributeValuesByFqnsRequest, _ ...grpc.CallOption) (*attr.GetAttributeValuesByFqnsResponse, error) { - return &getAttributesByValueFqnsResponse, errGetAttributesByValueFqns +func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *connect.Request[attr.GetAttributeValuesByFqnsRequest]) (*connect.Response[attr.GetAttributeValuesByFqnsResponse], error) { + return connect.NewResponse(&getAttributesByValueFqnsResponse), errGetAttributesByValueFqns } type myERSClient struct { @@ -63,23 +64,23 @@ type myERSClient struct { } type mySubjectMappingClient struct { - sm.SubjectMappingServiceClient + smconnect.SubjectMappingServiceClient } type paginatedMockSubjectMappingClient struct { - sm.SubjectMappingServiceClient + smconnect.SubjectMappingServiceClient } -func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*sm.ListSubjectMappingsResponse, error) { - return &listSubjectMappings, nil +func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *connect.Request[sm.ListSubjectMappingsRequest]) (*connect.Response[sm.ListSubjectMappingsResponse], error) { + return connect.NewResponse(&listSubjectMappings), nil } -func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *entityresolution.CreateEntityChainFromJwtRequest, _ ...grpc.CallOption) (*entityresolution.CreateEntityChainFromJwtResponse, error) { - return &createEntityChainResp, nil +func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *connect.Request[entityresolution.CreateEntityChainFromJwtRequest]) (*connect.Response[entityresolution.CreateEntityChainFromJwtResponse], error) { + return connect.NewResponse(&createEntityChainResp), nil } -func (*myERSClient) ResolveEntities(_ context.Context, _ *entityresolution.ResolveEntitiesRequest, _ ...grpc.CallOption) (*entityresolution.ResolveEntitiesResponse, error) { - return &resolveEntitiesResp, nil +func (*myERSClient) ResolveEntities(_ context.Context, _ *connect.Request[entityresolution.ResolveEntitiesRequest]) (*connect.Response[entityresolution.ResolveEntitiesResponse], error) { + return connect.NewResponse(&resolveEntitiesResp), nil } var ( @@ -87,7 +88,7 @@ var ( smListCallCount = 0 ) -func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*sm.ListSubjectMappingsResponse, error) { +func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *connect.Request[sm.ListSubjectMappingsRequest]) (*connect.Response[sm.ListSubjectMappingsResponse], error) { smListCallCount++ // simulate paginated list and policy LIST behavior if smPaginationOffset > 0 { @@ -98,13 +99,13 @@ func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, }, } smPaginationOffset = 0 - return rsp, nil + return connect.NewResponse(rsp), nil } - return &listSubjectMappings, nil + return connect.NewResponse(&listSubjectMappings), nil } type paginatedMockAttributesClient struct { - attr.AttributesServiceClient + attrconnect.AttributesServiceClient } var ( @@ -112,7 +113,7 @@ var ( attrListCallCount = 0 ) -func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest, _ ...grpc.CallOption) (*attr.ListAttributesResponse, error) { +func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *connect.Request[attr.ListAttributesRequest]) (*connect.Response[attr.ListAttributesResponse], error) { attrListCallCount++ // simulate paginated list and policy LIST behavior if attrPaginationOffset > 0 { @@ -123,9 +124,9 @@ func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *attr. }, } attrPaginationOffset = 0 - return rsp, nil + return connect.NewResponse(rsp), nil } - return &listAttributeResp, nil + return connect.NewResponse(&listAttributeResp), nil } func TestGetComprehensiveHierarchy(t *testing.T) { diff --git a/service/internal/server/memhttp/listener.go b/service/internal/server/memhttp/listener.go index 1ec303ddba..cf1d381647 100644 --- a/service/internal/server/memhttp/listener.go +++ b/service/internal/server/memhttp/listener.go @@ -56,4 +56,4 @@ func (*memoryAddr) Network() string { return "memory" } // String implements io.Stringer, returning a value that matches the // certificates used by net/http/httptest. -func (*memoryAddr) String() string { return "opentdf.io" } +func (*memoryAddr) String() string { return "http://opentdf.io" } diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 8d6eae604c..6d0e70b3d4 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -19,6 +19,7 @@ import ( "connectrpc.com/validate" "github.com/go-chi/cors" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/opentdf/platform/sdk" sdkAudit "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/service/internal/auth" "github.com/opentdf/platform/service/internal/security" @@ -134,6 +135,7 @@ type OpenTDFServer struct { // To Deprecate: Use the TrustKeyIndex and TrustKeyManager instead CryptoProvider security.CryptoProvider + Listener net.Listener logger *logger.Logger } @@ -230,6 +232,12 @@ func NewOpenTDFServer(config Config, logger *logger.Logger) (*OpenTDFServer, err logger: logger, } + listener, err := o.openHTTPServerPort() + if err != nil { + return nil, err + } + o.Listener = listener + if !config.CryptoProvider.IsEmpty() { // Create crypto provider logger.Info("creating crypto provider", slog.String("type", config.CryptoProvider.Type)) @@ -440,12 +448,7 @@ func (s OpenTDFServer) Start() error { s.ConnectRPCInProcess.Mux.Handle(grpcreflect.NewHandlerV1(reflector)) s.ConnectRPCInProcess.Mux.Handle(grpcreflect.NewHandlerV1Alpha(reflector)) - // Start Http Server - ln, err := s.openHTTPServerPort() - if err != nil { - return err - } - go s.startHTTPServer(ln) + go s.startHTTPServer(s.Listener) return nil } @@ -468,7 +471,24 @@ func (s OpenTDFServer) Stop() { s.logger.Info("shutdown complete") } -func (s inProcessServer) Conn() *grpc.ClientConn { +func (s inProcessServer) Conn() *sdk.ConnectRpcConnection { + var clientInterceptors []connect.Interceptor + + // Add audit interceptor + clientInterceptors = append(clientInterceptors, sdkAudit.MetadataAddingConnectInterceptor()) + + conn := sdk.ConnectRpcConnection{ + Client: s.srv.Client(), + Endpoint: s.srv.Listener.Addr().String(), + Options: []connect.ClientOption{ + connect.WithInterceptors(clientInterceptors...), + connect.WithReadMaxBytes(s.maxCallRecvMsgSize), + connect.WithSendMaxBytes(s.maxCallSendMsgSize)}, + } + return &conn +} + +func (s inProcessServer) GrpcConn() *grpc.ClientConn { var clientInterceptors []grpc.UnaryClientInterceptor // Add audit interceptor diff --git a/service/kas/access/accessPdp.go b/service/kas/access/accessPdp.go index 218ddf39ee..8a684d5ace 100644 --- a/service/kas/access/accessPdp.go +++ b/service/kas/access/accessPdp.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" + "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/authorization" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/service/tracing" @@ -77,10 +78,10 @@ func (p *Provider) checkAttributes(ctx context.Context, ras []*authorization.Res } ctx = tracing.InjectTraceContext(ctx) - dr, err := p.SDK.Authorization.GetDecisionsByToken(ctx, &in) + dr, err := p.SDK.Authorization.GetDecisionsByToken(ctx, connect.NewRequest(&in)) if err != nil { p.Logger.ErrorContext(ctx, "Error received from GetDecisionsByToken", "err", err) return nil, errors.Join(ErrDecisionUnexpected, err) } - return dr, nil + return dr.Msg, nil } diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index 6f9eed1616..f5c45bd325 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -190,9 +190,11 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF } // Register GRPC Gateway Handler using the in-process connect rpc - if err := svc.RegisterGRPCGatewayHandler(ctx, otdf.GRPCGatewayMux, otdf.ConnectRPCInProcess.Conn()); err != nil { + grpcConn := otdf.ConnectRPCInProcess.GrpcConn() + if err := svc.RegisterGRPCGatewayHandler(ctx, otdf.GRPCGatewayMux, grpcConn); err != nil { logger.Info("service did not register a grpc gateway handler", slog.String("namespace", ns)) } + defer grpcConn.Close() // Register Extra Handlers if err := svc.RegisterHTTPHandlers(ctx, otdf.GRPCGatewayMux); err != nil { diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index cea8f3c33f..9fc643740f 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -6,13 +6,12 @@ import ( "errors" "fmt" "log/slog" - "net" - "net/url" "os" "os/signal" "slices" "syscall" + "connectrpc.com/connect" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/sdk" sdkauth "github.com/opentdf/platform/sdk/auth" @@ -25,9 +24,6 @@ import ( "github.com/opentdf/platform/service/pkg/serviceregistry" "github.com/opentdf/platform/service/tracing" wellknown "github.com/opentdf/platform/service/wellknownconfiguration" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" ) const devModeMessage = ` @@ -90,6 +86,17 @@ func Start(f ...StartOptions) error { logger.Debug("config loaded", slog.Any("config", cfg.LogValue())) + // If the mode is not all, does not include both core and entityresolution, or is not entityresolution on its own, we need to have a valid SDK config + // entityresolution does not connect to other services and can run on its own + // core only connects to entityresolution + if !(slices.Contains(cfg.Mode, "all") || // no config required for all mode + (slices.Contains(cfg.Mode, "core") && slices.Contains(cfg.Mode, "entityresolution")) || // or core and entityresolution modes togethor + (slices.Contains(cfg.Mode, "entityresolution") && len(cfg.Mode) == 1)) && // or entityresolution on its own + cfg.SDKConfig == (config.SDKConfig{}) { + logger.Error("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided") + return errors.New("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided") + } + logger.Info("starting opentdf services") // Set allowed public routes when platform is being extended @@ -191,17 +198,6 @@ func Start(f ...StartOptions) error { oidcconfig *auth.OIDCConfiguration ) - // If the mode is not all, does not include both core and entityresolution, or is not entityresolution on its own, we need to have a valid SDK config - // entityresolution does not connect to other services and can run on its own - // core only connects to entityresolution - if !(slices.Contains(cfg.Mode, "all") || // no config required for all mode - (slices.Contains(cfg.Mode, "core") && slices.Contains(cfg.Mode, "entityresolution")) || // or core and entityresolution modes togethor - (slices.Contains(cfg.Mode, "entityresolution") && len(cfg.Mode) == 1)) && // or entityresolution on its own - cfg.SDKConfig == (config.SDKConfig{}) { - logger.Error("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided") - return errors.New("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided") - } - // If client credentials are provided, use them if cfg.SDKConfig.ClientID != "" && cfg.SDKConfig.ClientSecret != "" { sdkOptions = append(sdkOptions, sdk.WithClientCredentials(cfg.SDKConfig.ClientID, cfg.SDKConfig.ClientSecret, nil)) @@ -231,18 +227,19 @@ func Start(f ...StartOptions) error { return errors.New("entityresolution endpoint must be provided in core mode") } - ersDialOptions := []grpc.DialOption{} + ersConnectRpcConn := sdk.ConnectRpcConnection{} + var tlsConfig *tls.Config if cfg.SDKConfig.EntityResolutionConnection.Insecure { tlsConfig = &tls.Config{ MinVersion: tls.VersionTLS12, InsecureSkipVerify: true, // #nosec G402 } - ersDialOptions = append(ersDialOptions, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + ersConnectRpcConn.Client = httputil.SafeHTTPClientWithTLSConfig(tlsConfig) } if cfg.SDKConfig.EntityResolutionConnection.Plaintext { tlsConfig = &tls.Config{} - ersDialOptions = append(ersDialOptions, grpc.WithTransportCredentials(insecure.NewCredentials())) + ersConnectRpcConn.Client = httputil.SafeHTTPClient() } if cfg.SDKConfig.ClientID != "" && cfg.SDKConfig.ClientSecret != "" { @@ -268,30 +265,16 @@ func Start(f ...StartOptions) error { interceptor := sdkauth.NewTokenAddingInterceptorWithClient(ts, httputil.SafeHTTPClientWithTLSConfig(tlsConfig)) - ersDialOptions = append(ersDialOptions, grpc.WithChainUnaryInterceptor(interceptor.AddCredentials)) + ersConnectRpcConn.Options = append(ersConnectRpcConn.Options, connect.WithInterceptors(interceptor.AddCredentialsConnect())) } - parsedURL, err := url.Parse(cfg.SDKConfig.EntityResolutionConnection.Endpoint) - if err != nil { - return fmt.Errorf("cannot parse ers url(%s): %w", cfg.SDKConfig.EntityResolutionConnection.Endpoint, err) - } - // Needed to support buffconn for testing - if parsedURL.Host == "" { - return errors.New("ERS host is empty when parsing") - } - port := parsedURL.Port() - // if port is empty, default to 443. - if port == "" { - port = "443" + if sdk.IsPlatformEndpointMalformed(cfg.SDKConfig.EntityResolutionConnection.Endpoint) { + return fmt.Errorf("entityresolution endpoint is malformed: %s", cfg.SDKConfig.EntityResolutionConnection.Endpoint) } - ersGRPCEndpoint := net.JoinHostPort(parsedURL.Hostname(), port) + ersConnectRpcConn.Endpoint = cfg.SDKConfig.EntityResolutionConnection.Endpoint - conn, err := grpc.NewClient(ersGRPCEndpoint, ersDialOptions...) - if err != nil { - return fmt.Errorf("could not connect to ERS: %w", err) - } - sdkOptions = append(sdkOptions, sdk.WithCustomEntityResolutionConnection(conn)) - logger.Info("added with custom ers connection for ", "", ersGRPCEndpoint) + sdkOptions = append(sdkOptions, sdk.WithCustomEntityResolutionConnection(&ersConnectRpcConn)) + logger.Info("added with custom ers connection for ", "", ersConnectRpcConn.Endpoint) } client, err = sdk.New("", sdkOptions...) @@ -314,8 +297,6 @@ func Start(f ...StartOptions) error { } } - defer client.Close() - logger.Info("starting services") err = startServices(ctx, cfg, otdf, client, logger, svcRegistry) if err != nil { diff --git a/service/rttests/rt_test.go b/service/rttests/rt_test.go index aa0a3b52b5..a8e987a975 100644 --- a/service/rttests/rt_test.go +++ b/service/rttests/rt_test.go @@ -10,6 +10,7 @@ import ( "strings" "testing" + "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" "github.com/opentdf/platform/protocol/go/policy/namespaces" @@ -166,12 +167,12 @@ func (s *RoundtripSuite) CreateTestData() error { // create namespace example.com var exampleNamespace *policy.Namespace slog.Info("listing namespaces") - listResp, err := client.Namespaces.ListNamespaces(context.Background(), &namespaces.ListNamespacesRequest{}) + listResp, err := client.Namespaces.ListNamespaces(context.Background(), connect.NewRequest(&namespaces.ListNamespacesRequest{})) if err != nil { return err } - slog.Info(fmt.Sprintf("found %d namespaces", len(listResp.GetNamespaces()))) - for _, ns := range listResp.GetNamespaces() { + slog.Info(fmt.Sprintf("found %d namespaces", len(listResp.Msg.GetNamespaces()))) + for _, ns := range listResp.Msg.GetNamespaces() { slog.Info(fmt.Sprintf("existing namespace; name: %s, id: %s", ns.GetName(), ns.GetId())) if ns.GetName() == "example.com" { exampleNamespace = ns @@ -180,20 +181,20 @@ func (s *RoundtripSuite) CreateTestData() error { if exampleNamespace == nil { slog.Info("creating new namespace") - resp, err := client.Namespaces.CreateNamespace(context.Background(), &namespaces.CreateNamespaceRequest{ + resp, err := client.Namespaces.CreateNamespace(context.Background(), connect.NewRequest(&namespaces.CreateNamespaceRequest{ Name: "example.com", - }) + })) if err != nil { return err } - exampleNamespace = resp.GetNamespace() + exampleNamespace = resp.Msg.GetNamespace() } slog.Info("##################################\n#######################################") // Create the attributes slog.Info("creating attribute language with allOf rule") - _, err = client.Attributes.CreateAttribute(context.Background(), &attributes.CreateAttributeRequest{ + _, err = client.Attributes.CreateAttribute(context.Background(), connect.NewRequest(&attributes.CreateAttributeRequest{ Name: "language", NamespaceId: exampleNamespace.GetId(), Rule: *policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ALL_OF.Enum(), @@ -202,7 +203,7 @@ func (s *RoundtripSuite) CreateTestData() error { "french", "spanish", }, - }) + })) if err != nil { if returnStatus, ok := status.FromError(err); ok && returnStatus.Code() == codes.AlreadyExists { slog.Info("attribute already exists") @@ -215,7 +216,7 @@ func (s *RoundtripSuite) CreateTestData() error { } slog.Info("creating attribute color with anyOf rule") - _, err = client.Attributes.CreateAttribute(context.Background(), &attributes.CreateAttributeRequest{ + _, err = client.Attributes.CreateAttribute(context.Background(), connect.NewRequest(&attributes.CreateAttributeRequest{ Name: "color", NamespaceId: exampleNamespace.GetId(), Rule: *policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF.Enum(), @@ -224,7 +225,7 @@ func (s *RoundtripSuite) CreateTestData() error { "green", "blue", }, - }) + })) if err != nil { if returnStatus, ok := status.FromError(err); ok && returnStatus.Code() == codes.AlreadyExists { slog.Info("attribute already exists") @@ -237,7 +238,7 @@ func (s *RoundtripSuite) CreateTestData() error { } slog.Info("creating attribute cards with hierarchy rule") - _, err = client.Attributes.CreateAttribute(context.Background(), &attributes.CreateAttributeRequest{ + _, err = client.Attributes.CreateAttribute(context.Background(), connect.NewRequest(&attributes.CreateAttributeRequest{ Name: "cards", NamespaceId: exampleNamespace.GetId(), Rule: *policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_HIERARCHY.Enum(), @@ -246,7 +247,7 @@ func (s *RoundtripSuite) CreateTestData() error { "queen", "jack", }, - }) + })) if err != nil { if returnStatus, ok := status.FromError(err); ok && returnStatus.Code() == codes.AlreadyExists { slog.Info("attribute already exists") @@ -260,33 +261,33 @@ func (s *RoundtripSuite) CreateTestData() error { slog.Info("##################################\n#######################################") - allAttr, err := client.Attributes.ListAttributes(context.Background(), &attributes.ListAttributesRequest{}) + allAttr, err := client.Attributes.ListAttributes(context.Background(), connect.NewRequest(&attributes.ListAttributesRequest{})) if err != nil { slog.Error("could not list attributes", slog.String("error", err.Error())) return err } - slog.Info(fmt.Sprintf("list attributes response: %s", protojson.Format(allAttr))) + slog.Info(fmt.Sprintf("list attributes response: %s", protojson.Format(allAttr.Msg))) slog.Info("##################################\n#######################################") // get the attribute ids for the values were mapping to the client var attributeValueIDs []string - fqnResp, err := client.Attributes.GetAttributeValuesByFqns(context.Background(), &attributes.GetAttributeValuesByFqnsRequest{ + fqnResp, err := client.Attributes.GetAttributeValuesByFqns(context.Background(), connect.NewRequest(&attributes.GetAttributeValuesByFqnsRequest{ Fqns: attributesToMap, WithValue: &policy.AttributeValueSelector{}, - }) + })) if err != nil { slog.Error("get attribute values by fqn ", slog.String("error", err.Error())) return err } for _, attribute := range attributesToMap { - attributeValueIDs = append(attributeValueIDs, fqnResp.GetFqnAttributeValues()[attribute].GetValue().GetId()) + attributeValueIDs = append(attributeValueIDs, fqnResp.Msg.GetFqnAttributeValues()[attribute].GetValue().GetId()) } // create subject mappings slog.Info("creating subject mappings for client " + s.TestConfig.ClientID) for _, attributeID := range attributeValueIDs { - _, err = client.SubjectMapping.CreateSubjectMapping(context.Background(), &subjectmapping.CreateSubjectMappingRequest{ + _, err = client.SubjectMapping.CreateSubjectMapping(context.Background(), connect.NewRequest(&subjectmapping.CreateSubjectMappingRequest{ AttributeValueId: attributeID, Actions: []*policy.Action{ {Name: actions.ActionNameCreate}, @@ -306,7 +307,7 @@ func (s *RoundtripSuite) CreateTestData() error { }}, }, }, - }) + })) if err != nil { if returnStatus, ok := status.FromError(err); ok && returnStatus.Code() == codes.AlreadyExists { slog.Info("subject mapping already exists") @@ -319,12 +320,12 @@ func (s *RoundtripSuite) CreateTestData() error { } } - allSubMaps, err := client.SubjectMapping.ListSubjectMappings(context.Background(), &subjectmapping.ListSubjectMappingsRequest{}) + allSubMaps, err := client.SubjectMapping.ListSubjectMappings(context.Background(), connect.NewRequest(&subjectmapping.ListSubjectMappingsRequest{})) if err != nil { slog.Error("could not list subject mappings", slog.String("error", err.Error())) return err } - slog.Info(fmt.Sprintf("list subject mappings response: %s", protojson.Format(allSubMaps))) + slog.Info(fmt.Sprintf("list subject mappings response: %s", protojson.Format(allSubMaps.Msg))) return nil } From d6fc97def98836270c6eb1a9372375c953e654fe Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 9 May 2025 15:31:42 -0400 Subject: [PATCH 02/31] lint examples --- examples/cmd/attributes.go | 8 ++++---- examples/cmd/authorization.go | 2 +- examples/cmd/kas.go | 2 +- sdk/options.go | 26 ++++++++------------------ 4 files changed, 14 insertions(+), 24 deletions(-) diff --git a/examples/cmd/attributes.go b/examples/cmd/attributes.go index ca78eb8ca0..ec258bbcb7 100644 --- a/examples/cmd/attributes.go +++ b/examples/cmd/attributes.go @@ -116,7 +116,7 @@ func listAttributes(cmd *cobra.Command) error { if err != nil { return err } - slog.Info(fmt.Sprintf("found %d namespaces", len(listResp.Msg.Namespaces))) + slog.Info(fmt.Sprintf("found %d namespaces", len(listResp.Msg.GetNamespaces()))) for _, n := range listResp.Msg.GetNamespaces() { nsuris = append(nsuris, n.GetFqn()) } @@ -215,7 +215,7 @@ func addNamespace(ctx context.Context, s *sdk.SDK, u string) (string, error) { slog.Error("CreateNamespace", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - return resp.Msg.Namespace.GetId(), nil + return resp.Msg.GetNamespace().GetId(), nil } func addAttribute(cmd *cobra.Command) error { @@ -373,7 +373,7 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { if err != nil { return err } - for _, b := range ar.Msg.Attribute.GetGrants() { + for _, b := range ar.Msg.GetAttribute().GetGrants() { kasids = append(kasids, b.GetId()) kasById[b.GetId()] = b.GetUri() } @@ -480,5 +480,5 @@ func upsertAttr(ctx context.Context, s *sdk.SDK, auth, name string, values []str slog.Error("CreateAttribute", "err", err, "auth", auth, "name", name, "values", values, "rule", ruler()) return "", err } - return av.Msg.Attribute.GetId(), nil + return av.Msg.GetAttribute().GetId(), nil } diff --git a/examples/cmd/authorization.go b/examples/cmd/authorization.go index afb3d54bcb..0186a98505 100644 --- a/examples/cmd/authorization.go +++ b/examples/cmd/authorization.go @@ -70,7 +70,7 @@ func authorizationExamples() error { // map response back to entity chain id decisionsByEntityChain := make(map[string]*authorization.DecisionResponse) - for _, dr := range decisionResponse.Msg.DecisionResponses { + for _, dr := range decisionResponse.Msg.GetDecisionResponses() { decisionsByEntityChain[dr.EntityChainId] = dr } diff --git a/examples/cmd/kas.go b/examples/cmd/kas.go index 0939a0dca6..983b0d19f5 100644 --- a/examples/cmd/kas.go +++ b/examples/cmd/kas.go @@ -138,7 +138,7 @@ func upsertKasRegistration(ctx context.Context, s *sdk.SDK, uri string, pk *poli slog.Error("CreateKeyAccessServer", "uri", uri, "publicKey", uri+"/v2/kas_public_key") return "", err } - return ur.Msg.KeyAccessServer.GetId(), nil + return ur.Msg.GetKeyAccessServer().GetId(), nil } func algString2Proto(a string) policy.KasPublicKeyAlgEnum { diff --git a/sdk/options.go b/sdk/options.go index 4e25fc7e2b..eb5179e15b 100644 --- a/sdk/options.go +++ b/sdk/options.go @@ -24,14 +24,13 @@ type ConnectRpcConnection struct { // Internal config struct for building SDK options. type config struct { // Platform configuration structure is subject to change. Consume via accessor methods. - PlatformConfiguration PlatformConfiguration - extraClientOptions []connect.ClientOption - httpClient *http.Client - clientCredentials *oauth.ClientCredentials - tokenExchange *oauth.TokenExchangeInfo - tokenEndpoint string - scopes []string - // extraDialOptions []grpc.DialOption + PlatformConfiguration PlatformConfiguration + extraClientOptions []connect.ClientOption + httpClient *http.Client + clientCredentials *oauth.ClientCredentials + tokenExchange *oauth.TokenExchangeInfo + tokenEndpoint string + scopes []string certExchange *oauth.CertExchangeInfo kasSessionKey *ocrypto.RsaKeyPair dpopKey *ocrypto.RsaKeyPair @@ -60,10 +59,6 @@ type nanoFeatures struct { type PlatformConfiguration map[string]interface{} -// func (c *config) build() []grpc.DialOption { -// return []grpc.DialOption{c.dialOption} -// } - // WithInsecureSkipVerifyConn returns an Option that sets up HTTPS connection without verification. func WithInsecureSkipVerifyConn() Option { return func(c *config) { @@ -157,12 +152,6 @@ func WithTokenExchange(subjectToken string, audience []string) Option { } } -// func WithExtraDialOptions(dialOptions ...grpc.DialOption) Option { -// return func(c *config) { -// c.extraDialOptions = dialOptions -// } -// } - // The session key pair is used to encrypt responses from KAS for a given session // and can be reused across an entire session. // Please use with caution. @@ -227,6 +216,7 @@ func WithCustomCoreConnection(conn *ConnectRpcConnection) Option { } } +// WithExtraClientOptions returns an Option that adds extra connect rpc client options to the conect rpc clients func WithExtraClientOptions(opts ...connect.ClientOption) Option { return func(c *config) { c.extraClientOptions = opts From 515444e8e399413a7c02d73d2fc59c2e1937838c Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 9 May 2025 15:37:11 -0400 Subject: [PATCH 03/31] sdk linting --- sdk/audit/metadata_adding_interceptor.go | 4 +- sdk/options.go | 16 ++--- sdk/sdk.go | 16 ++--- sdk/tdf_test.go | 84 ++++++++++++------------ service/internal/server/server.go | 4 +- service/pkg/server/start.go | 2 +- 6 files changed, 63 insertions(+), 63 deletions(-) diff --git a/sdk/audit/metadata_adding_interceptor.go b/sdk/audit/metadata_adding_interceptor.go index 0f05de6a5f..b4bc9fb20d 100644 --- a/sdk/audit/metadata_adding_interceptor.go +++ b/sdk/audit/metadata_adding_interceptor.go @@ -64,12 +64,12 @@ func MetadataAddingConnectInterceptor() connect.UnaryInterceptorFunc { req.Header().Set(string(RequestIDHeaderKey), requestID.String()) // Add the request IP to a custom header so it is preserved - if requestIP, ok := ctx.Value(RequestIPContextKey).(string); ok { + if requestIP, okIP := ctx.Value(RequestIPContextKey).(string); okIP { req.Header().Set(string(RequestIPHeaderKey), requestIP) } // Add the actor ID from the request so it is preserved if we need it - if actorID, ok := ctx.Value(ActorIDContextKey).(string); ok { + if actorID, okAct := ctx.Value(ActorIDContextKey).(string); okAct { req.Header().Set(string(ActorIDHeaderKey), actorID) } diff --git a/sdk/options.go b/sdk/options.go index eb5179e15b..7ad993388d 100644 --- a/sdk/options.go +++ b/sdk/options.go @@ -15,7 +15,7 @@ import ( type Option func(*config) -type ConnectRpcConnection struct { +type ConnectRPCConnection struct { Client *http.Client Endpoint string Options []connect.ClientOption @@ -39,8 +39,8 @@ type config struct { nanoFeatures nanoFeatures customAccessTokenSource auth.AccessTokenSource oauthAccessTokenSource oauth2.TokenSource - coreConn *ConnectRpcConnection - entityResolutionConn *ConnectRpcConnection + coreConn *ConnectRPCConnection + entityResolutionConn *ConnectRPCConnection collectionStore *collectionStore shouldValidatePlatformConnectivity bool } @@ -122,20 +122,20 @@ func WithOAuthAccessTokenSource(t oauth2.TokenSource) Option { } // Deprecated: Use WithCustomCoreConnection instead -func WithCustomPolicyConnection(conn *ConnectRpcConnection) Option { +func WithCustomPolicyConnection(conn *ConnectRPCConnection) Option { return func(c *config) { c.coreConn = conn } } // Deprecated: Use WithCustomCoreConnection instead -func WithCustomAuthorizationConnection(conn *ConnectRpcConnection) Option { +func WithCustomAuthorizationConnection(conn *ConnectRPCConnection) Option { return func(c *config) { c.coreConn = conn } } -func WithCustomEntityResolutionConnection(conn *ConnectRpcConnection) Option { +func WithCustomEntityResolutionConnection(conn *ConnectRPCConnection) Option { return func(c *config) { c.entityResolutionConn = conn } @@ -172,7 +172,7 @@ func WithSessionSignerRSA(key *rsa.PrivateKey) Option { } } -func WithCustomWellknownConnection(conn *ConnectRpcConnection) Option { +func WithCustomWellknownConnection(conn *ConnectRPCConnection) Option { return func(c *config) { c.coreConn = conn } @@ -210,7 +210,7 @@ func WithNoKIDInKAO() Option { } // WithCoreConnection returns an Option that sets up a connection to the core platform -func WithCustomCoreConnection(conn *ConnectRpcConnection) Option { +func WithCustomCoreConnection(conn *ConnectRPCConnection) Option { return func(c *config) { c.coreConn = conn } diff --git a/sdk/sdk.go b/sdk/sdk.go index 0d852169e9..63b65bf641 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -63,7 +63,7 @@ type SDK struct { config *kasKeyCache *collectionStore - conn *ConnectRpcConnection + conn *ConnectRPCConnection tokenSource auth.AccessTokenSource Actions actionsconnect.ActionServiceClient Attributes attributesconnect.AttributesServiceClient @@ -81,8 +81,8 @@ type SDK struct { func New(platformEndpoint string, opts ...Option) (*SDK, error) { var ( - platformConn *ConnectRpcConnection // Connection to the platform - ersConn *ConnectRpcConnection // Connection to ERS (possibly remote) + platformConn *ConnectRPCConnection // Connection to the platform + ersConn *ConnectRPCConnection // Connection to ERS (possibly remote) err error ) @@ -144,7 +144,7 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { return nil, errors.Join(ErrPlatformConfigFailed, err) } } else { - pcfg, err = getPlatformConfiguration(&ConnectRpcConnection{Endpoint: platformEndpoint, Client: cfg.httpClient}) + pcfg, err = getPlatformConfiguration(&ConnectRPCConnection{Endpoint: platformEndpoint, Client: cfg.httpClient}) if err != nil { return nil, errors.Join(ErrPlatformConfigFailed, err) } @@ -179,7 +179,7 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { if cfg.coreConn != nil { platformConn = cfg.coreConn } else { - platformConn = &ConnectRpcConnection{Endpoint: platformEndpoint, Client: cfg.httpClient, Options: append(cfg.extraClientOptions, connect.WithInterceptors(uci...))} + platformConn = &ConnectRPCConnection{Endpoint: platformEndpoint, Client: cfg.httpClient, Options: append(cfg.extraClientOptions, connect.WithInterceptors(uci...))} } if cfg.entityResolutionConn != nil { @@ -192,7 +192,7 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { config: *cfg, collectionStore: cfg.collectionStore, kasKeyCache: newKasKeyCache(), - conn: &ConnectRpcConnection{Client: platformConn.Client, Endpoint: platformConn.Endpoint, Options: platformConn.Options}, + conn: &ConnectRPCConnection{Client: platformConn.Client, Endpoint: platformConn.Endpoint, Options: platformConn.Options}, tokenSource: accessTokenSource, Actions: actionsconnect.NewActionServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), Attributes: attributesconnect.NewAttributesServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), @@ -268,7 +268,7 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { } // Conn returns the underlying http connection -func (s SDK) Conn() *ConnectRpcConnection { +func (s SDK) Conn() *ConnectRPCConnection { return s.conn } @@ -411,7 +411,7 @@ func ValidateHealthyPlatformConnection(platformEndpoint string, httpClient *http return nil } -func getPlatformConfiguration(conn *ConnectRpcConnection) (PlatformConfiguration, error) { +func getPlatformConfiguration(conn *ConnectRPCConnection) (PlatformConfiguration, error) { req := wellknownconfiguration.GetWellKnownConfigurationRequest{} wellKnownConfig := wellknownconfigurationconnect.NewWellKnownServiceClient(conn.Client, conn.Endpoint) diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index 8b10e7b544..b71f83eccf 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -288,7 +288,7 @@ type TDFSuite struct { suite.Suite sdk *SDK kases []FakeKas - kasTestUrlLookup map[string]string + kasTestURLLookup map[string]string } func (s *TDFSuite) SetupSuite() { @@ -327,20 +327,20 @@ func (s *TDFSuite) Test_SimpleTDF() { { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: s.kasTestUrlLookup["https://a.kas/"], + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), WithDataAttributes(attributes...), }, tdfReadOptions: []TDFReaderOption{ - WithKasAllowlist([]string{s.kasTestUrlLookup["https://a.kas/"]}), + WithKasAllowlist([]string{s.kasTestURLLookup["https://a.kas/"]}), }, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: s.kasTestUrlLookup["https://a.kas/"], + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -348,14 +348,14 @@ func (s *TDFSuite) Test_SimpleTDF() { WithTargetMode("0.0.0"), }, tdfReadOptions: []TDFReaderOption{ - WithKasAllowlist([]string{s.kasTestUrlLookup["https://a.kas/"]}), + WithKasAllowlist([]string{s.kasTestURLLookup["https://a.kas/"]}), }, useHex: true, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: s.kasTestUrlLookup["https://d.kas/"], + URL: s.kasTestURLLookup["https://d.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -364,13 +364,13 @@ func (s *TDFSuite) Test_SimpleTDF() { }, tdfReadOptions: []TDFReaderOption{ WithSessionKeyType(ocrypto.EC256Key), - WithKasAllowlist([]string{s.kasTestUrlLookup["https://d.kas/"]}), + WithKasAllowlist([]string{s.kasTestURLLookup["https://d.kas/"]}), }, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: s.kasTestUrlLookup["https://d.kas/"], + URL: s.kasTestURLLookup["https://d.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -380,7 +380,7 @@ func (s *TDFSuite) Test_SimpleTDF() { }, tdfReadOptions: []TDFReaderOption{ WithSessionKeyType(ocrypto.EC256Key), - WithKasAllowlist([]string{s.kasTestUrlLookup["https://d.kas/"]}), + WithKasAllowlist([]string{s.kasTestURLLookup["https://d.kas/"]}), }, useHex: true, }, @@ -500,20 +500,20 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: s.kasTestUrlLookup["https://a.kas/"], + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), WithDataAttributes(attributes...), }, tdfReadOptions: []TDFReaderOption{ - WithKasAllowlist([]string{s.kasTestUrlLookup["https://a.kas/"]}), + WithKasAllowlist([]string{s.kasTestURLLookup["https://a.kas/"]}), }, }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: s.kasTestUrlLookup["https://a.kas/"], + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -522,12 +522,12 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { tdfReadOptions: []TDFReaderOption{ WithKasAllowlist([]string{"https://nope-not-a-kas.com/kas"}), }, - expectedError: "KasAllowlist: kas url " + s.kasTestUrlLookup["https://a.kas/"] + " is not allowed", + expectedError: "KasAllowlist: kas url " + s.kasTestURLLookup["https://a.kas/"] + " is not allowed", }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: s.kasTestUrlLookup["https://a.kas/"], + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -536,12 +536,12 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { tdfReadOptions: []TDFReaderOption{ withKasAllowlist(AllowList{"nope-not-a-kas.com": true}), }, - expectedError: "KasAllowlist: kas url " + s.kasTestUrlLookup["https://a.kas/"] + " is not allowed", + expectedError: "KasAllowlist: kas url " + s.kasTestURLLookup["https://a.kas/"] + " is not allowed", }, { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: s.kasTestUrlLookup["https://a.kas/"], + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -555,7 +555,7 @@ func (s *TDFSuite) Test_TDF_KAS_Allowlist() { { tdfOptions: []TDFOption{ WithKasInformation(KASInfo{ - URL: s.kasTestUrlLookup["https://a.kas/"], + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }), WithMetaData(string(metaData)), @@ -831,7 +831,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() { { kasURLs := []KASInfo{ { - URL: s.kasTestUrlLookup["https://a.kas/"], + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }, } @@ -875,12 +875,12 @@ func (s *TDFSuite) Test_TDFWithAssertion() { var r *Reader if test.verifiers == nil { - r, err = s.sdk.LoadTDF(readSeeker, WithDisableAssertionVerification(test.disableAssertionVerification), WithKasAllowlist([]string{s.kasTestUrlLookup["https://a.kas/"]})) + r, err = s.sdk.LoadTDF(readSeeker, WithDisableAssertionVerification(test.disableAssertionVerification), WithKasAllowlist([]string{s.kasTestURLLookup["https://a.kas/"]})) } else { r, err = s.sdk.LoadTDF(readSeeker, WithAssertionVerificationKeys(*test.verifiers), WithDisableAssertionVerification(test.disableAssertionVerification), - WithKasAllowlist([]string{s.kasTestUrlLookup["https://a.kas/"]})) + WithKasAllowlist([]string{s.kasTestURLLookup["https://a.kas/"]})) } s.Require().NoError(err) @@ -1150,7 +1150,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() { { kasURLs := []KASInfo{ { - URL: s.kasTestUrlLookup["https://a.kas/"], + URL: s.kasTestURLLookup["https://a.kas/"], PublicKey: "", }, } @@ -1209,11 +1209,11 @@ func (s *TDFSuite) Test_TDFReader() { //nolint:gocognit // requires for testing payload: payload, // len: 62 kasInfoList: []KASInfo{ { - URL: s.kasTestUrlLookup["http://localhost:65432/"], + URL: s.kasTestURLLookup["http://localhost:65432/"], PublicKey: mockRSAPublicKey1, }, { - URL: s.kasTestUrlLookup["http://localhost:65432/"], + URL: s.kasTestURLLookup["http://localhost:65432/"], PublicKey: mockRSAPublicKey1, }, }, @@ -1309,11 +1309,11 @@ func (s *TDFSuite) Test_TDFReader() { //nolint:gocognit // requires for testing func (s *TDFSuite) Test_TDFReaderFail() { kasInfoList := []KASInfo{ { - URL: s.kasTestUrlLookup["http://localhost:65432/api/kas"], + URL: s.kasTestURLLookup["http://localhost:65432/api/kas"], PublicKey: mockRSAPublicKey1, }, { - URL: s.kasTestUrlLookup["http://localhost:65432/api/kas"], + URL: s.kasTestURLLookup["http://localhost:65432/api/kas"], PublicKey: mockRSAPublicKey1, }, } @@ -1642,9 +1642,9 @@ func (s *TDFSuite) Test_KeySplits() { tdfFileSize: 2759, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ - {KAS: s.kasTestUrlLookup["https://a.kas/"], SplitID: "a"}, - {KAS: s.kasTestUrlLookup["https://b.kas/"], SplitID: "a"}, - {KAS: s.kasTestUrlLookup[`https://c.kas/`], SplitID: "a"}, + {KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"}, + {KAS: s.kasTestURLLookup["https://b.kas/"], SplitID: "a"}, + {KAS: s.kasTestURLLookup[`https://c.kas/`], SplitID: "a"}, }, }, { @@ -1653,9 +1653,9 @@ func (s *TDFSuite) Test_KeySplits() { tdfFileSize: 2759, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ - {KAS: s.kasTestUrlLookup["https://a.kas/"], SplitID: "a"}, - {KAS: s.kasTestUrlLookup["https://b.kas/"], SplitID: "b"}, - {KAS: s.kasTestUrlLookup["https://c.kas/"], SplitID: "c"}, + {KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"}, + {KAS: s.kasTestURLLookup["https://b.kas/"], SplitID: "b"}, + {KAS: s.kasTestURLLookup["https://c.kas/"], SplitID: "c"}, }, }, { @@ -1664,10 +1664,10 @@ func (s *TDFSuite) Test_KeySplits() { tdfFileSize: 3351, checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", splitPlan: []keySplitStep{ - {KAS: s.kasTestUrlLookup["https://a.kas/"], SplitID: "a"}, - {KAS: s.kasTestUrlLookup["https://b.kas/"], SplitID: "a"}, - {KAS: s.kasTestUrlLookup["https://b.kas/"], SplitID: "b"}, - {KAS: s.kasTestUrlLookup["https://c.kas/"], SplitID: "b"}, + {KAS: s.kasTestURLLookup["https://a.kas/"], SplitID: "a"}, + {KAS: s.kasTestURLLookup["https://b.kas/"], SplitID: "a"}, + {KAS: s.kasTestURLLookup["https://b.kas/"], SplitID: "b"}, + {KAS: s.kasTestURLLookup["https://c.kas/"], SplitID: "b"}, }, }, } { @@ -1912,9 +1912,9 @@ func (s *TDFSuite) startBackend() { s.kases = make([]FakeKas, 12) - s.kasTestUrlLookup = make(map[string]string, 12) + s.kasTestURLLookup = make(map[string]string, 12) - var sdkPlatformUrl string + var sdkPlatformURL string for i, ki := range kasesToMake { @@ -1938,18 +1938,18 @@ func (s *TDFSuite) startBackend() { server := httptest.NewServer(mux) // add to lookup reg - s.kasTestUrlLookup[s.kases[i].KASInfo.URL] = server.URL + s.kasTestURLLookup[s.kases[i].KASInfo.URL] = server.URL // replace kasinfo url with httptest server url s.kases[i].KASInfo.URL = server.URL if i == 0 { - sdkPlatformUrl = server.URL + sdkPlatformURL = server.URL } } ats := getTokenSource(s.T()) - sdk, err := New(sdkPlatformUrl, + sdk, err := New(sdkPlatformURL, WithClientCredentials("test", "test", nil), withCustomAccessTokenSource(&ats), WithTokenEndpoint("http://localhost:65432/auth/token"), @@ -1989,7 +1989,7 @@ func (f *FakeAttributes) GetAttributeValuesByFqns(_ context.Context, in *connect } v := mockValueFor(av) for i := range v.GetGrants() { - v.Grants[i].Uri = f.s.kasTestUrlLookup[v.Grants[i].Uri] + v.Grants[i].Uri = f.s.kasTestURLLookup[v.GetGrants()[i].GetUri()] } r[fqn] = &attributespb.GetAttributeValuesByFqnsResponse_AttributeAndValue{ Attribute: v.GetAttribute(), @@ -2014,7 +2014,7 @@ func (f *FakeKASRegistry) ListKeyAccessServers(_ context.Context, _ *connect.Req for _, k := range f.kases { kas := &policy.KeyAccessServer{ - Uri: f.s.kasTestUrlLookup[k.url], + Uri: f.s.kasTestURLLookup[k.url], } resp.KeyAccessServers = append(resp.KeyAccessServers, kas) } diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 6d0e70b3d4..5583e505ee 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -471,13 +471,13 @@ func (s OpenTDFServer) Stop() { s.logger.Info("shutdown complete") } -func (s inProcessServer) Conn() *sdk.ConnectRpcConnection { +func (s inProcessServer) Conn() *sdk.ConnectRPCConnection { var clientInterceptors []connect.Interceptor // Add audit interceptor clientInterceptors = append(clientInterceptors, sdkAudit.MetadataAddingConnectInterceptor()) - conn := sdk.ConnectRpcConnection{ + conn := sdk.ConnectRPCConnection{ Client: s.srv.Client(), Endpoint: s.srv.Listener.Addr().String(), Options: []connect.ClientOption{ diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 9fc643740f..5aba1e06dd 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -227,7 +227,7 @@ func Start(f ...StartOptions) error { return errors.New("entityresolution endpoint must be provided in core mode") } - ersConnectRpcConn := sdk.ConnectRpcConnection{} + ersConnectRpcConn := sdk.ConnectRPCConnection{} var tlsConfig *tls.Config if cfg.SDKConfig.EntityResolutionConnection.Insecure { From 78503f430b9685de653ccb913ffdfe0245654d78 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 9 May 2025 15:50:49 -0400 Subject: [PATCH 04/31] sdklint, fix grpc gateway --- sdk/sdk.go | 1 - service/pkg/server/services.go | 16 ++++++++++------ service/pkg/server/services_test.go | 4 +++- service/pkg/server/start.go | 3 ++- service/pkg/server/start_test.go | 3 ++- 5 files changed, 17 insertions(+), 10 deletions(-) diff --git a/sdk/sdk.go b/sdk/sdk.go index 63b65bf641..00749f0efc 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -395,7 +395,6 @@ func IsValidNanoTdf(reader io.ReadSeeker) (bool, error) { // Test connectability to the platform and validate a healthy status func ValidateHealthyPlatformConnection(platformEndpoint string, httpClient *http.Client) error { - healthClient := connect.NewClient[healthpb.HealthCheckRequest, healthpb.HealthCheckResponse]( httpClient, platformEndpoint+"/grpc.health.v1.Health/Check", diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index f5c45bd325..2686fec93c 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -112,7 +112,9 @@ func registerCoreServices(reg serviceregistry.Registry, mode []string) ([]string // based on the configuration and namespace mode. It creates a new service logger // and a database client if required. It registers the services with the gRPC server, // in-process gRPC server, and gRPC gateway. Finally, it logs the status of each service. -func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDFServer, client *sdk.SDK, logger *logging.Logger, reg serviceregistry.Registry) error { +func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDFServer, client *sdk.SDK, logger *logging.Logger, reg serviceregistry.Registry) (func(), error) { + var gatewayCleanup func() + // Iterate through the registered namespaces for ns, namespace := range reg { // modeEnabled checks if the mode is enabled based on the configuration and namespace mode. @@ -157,7 +159,7 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF var err error svcDBClient, err = newServiceDBClient(ctx, cfg.Logger, cfg.DB, tracer, ns, svc.DBMigrations()) if err != nil { - return err + return nil, err } } @@ -172,11 +174,11 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF Tracer: tracer, }) if err != nil { - return err + return nil, err } if err := svc.RegisterConfigUpdateHook(ctx, cfg.AddOnConfigChangeHook); err != nil { - return fmt.Errorf("failed to register config update hook: %w", err) + return nil, fmt.Errorf("failed to register config update hook: %w", err) } // Register Connect RPC Services @@ -194,7 +196,9 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF if err := svc.RegisterGRPCGatewayHandler(ctx, otdf.GRPCGatewayMux, grpcConn); err != nil { logger.Info("service did not register a grpc gateway handler", slog.String("namespace", ns)) } - defer grpcConn.Close() + gatewayCleanup = func() { + grpcConn.Close() + } // Register Extra Handlers if err := svc.RegisterHTTPHandlers(ctx, otdf.GRPCGatewayMux); err != nil { @@ -213,7 +217,7 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF } } - return nil + return gatewayCleanup, nil } func extractServiceLoggerConfig(cfg config.ServiceConfig) (string, error) { diff --git a/service/pkg/server/services_test.go b/service/pkg/server/services_test.go index bbe74a75b0..fe0b046da9 100644 --- a/service/pkg/server/services_test.go +++ b/service/pkg/server/services_test.go @@ -253,7 +253,7 @@ func (suite *ServiceTestSuite) TestStartServicesWithVariousCases() { newLogger, err := logger.NewLogger(logger.Config{Output: "stdout", Level: "info", Type: "json"}) suite.Require().NoError(err) - err = startServices(ctx, &config.Config{ + cleanup, err := startServices(ctx, &config.Config{ Mode: []string{"test"}, Logger: logger.Config{Output: "stdout", Level: "info", Type: "json"}, // DB: db.Config{ @@ -270,6 +270,8 @@ func (suite *ServiceTestSuite) TestStartServicesWithVariousCases() { "foobar": {}, }, }, otdf, nil, newLogger, registry) + defer cleanup() + suite.Require().NoError(err) // require.NotNil(t, cF) // assert.Lenf(t, services, 2, "expected 2 services enabled, got %d", len(services)) diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 5aba1e06dd..3f0253c4a7 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -298,11 +298,12 @@ func Start(f ...StartOptions) error { } logger.Info("starting services") - err = startServices(ctx, cfg, otdf, client, logger, svcRegistry) + gatewayCleanup, err := startServices(ctx, cfg, otdf, client, logger, svcRegistry) if err != nil { logger.Error("issue starting services", slog.String("error", err.Error())) return fmt.Errorf("issue starting services: %w", err) } + defer gatewayCleanup() // Start watching the configuration for changes with registered config change service hooks if err := cfg.Watch(ctx); err != nil { diff --git a/service/pkg/server/start_test.go b/service/pkg/server/start_test.go index 8c1581392b..53a63d9d3e 100644 --- a/service/pkg/server/start_test.go +++ b/service/pkg/server/start_test.go @@ -183,12 +183,13 @@ func (suite *StartTestSuite) Test_Start_When_Extra_Service_Registered_Expect_Res err = registry.RegisterService(registerTestService, "test") suite.Require().NoError(err) // Start services with test service - err = startServices(context.Background(), &config.Config{ + cleanup, err := startServices(context.Background(), &config.Config{ Mode: []string{"all"}, Services: map[string]config.ServiceConfig{ "test": {}, }, }, s, nil, logger, registry) + defer cleanup() require.NoError(t, err) require.NoError(t, s.Start()) From 78ec5f9baf1d2ae47086dcaae00abee0440152e3 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 9 May 2025 16:47:08 -0400 Subject: [PATCH 05/31] linting, default to https in examples --- examples/cmd/examples.go | 5 +++-- service/internal/server/server.go | 3 ++- service/pkg/server/start.go | 14 +++++++------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/examples/cmd/examples.go b/examples/cmd/examples.go index 66d426c041..d3a0db85a5 100644 --- a/examples/cmd/examples.go +++ b/examples/cmd/examples.go @@ -29,8 +29,8 @@ func init() { log.SetFlags(log.LstdFlags | log.Llongfile) f := ExamplesCmd.PersistentFlags() f.StringVarP(&clientCredentials, "creds", "", "opentdf-sdk:secret", "client id:secret credentials") - f.StringVarP(&platformEndpoint, "platformEndpoint", "e", "http://localhost:8080", "Platform Endpoint") - f.StringVarP(&tokenEndpoint, "tokenEndpoint", "t", "http://localhost:8888/auth/realms/opentdf/protocol/openid-connect/token", "OAuth token endpoint") + f.StringVarP(&platformEndpoint, "platformEndpoint", "e", "https://localhost:8080", "Platform Endpoint") + f.StringVarP(&tokenEndpoint, "tokenEndpoint", "t", "https://localhost:8888/auth/realms/opentdf/protocol/openid-connect/token", "OAuth token endpoint") f.BoolVar(&storeCollectionHeaders, "storeCollectionHeaders", false, "Store collection headers") f.BoolVar(&insecurePlaintextConn, "insecurePlaintextConn", false, "Use insecure plaintext connection") f.BoolVar(&insecureSkipVerify, "insecureSkipVerify", false, "Skip server certificate verification") @@ -40,6 +40,7 @@ func newSDK() (*sdk.SDK, error) { resolver.SetDefaultScheme("passthrough") opts := []sdk.Option{} if insecurePlaintextConn { + platformEndpoint = strings.Replace(platformEndpoint, "https://", "http://", 1) opts = append(opts, sdk.WithInsecurePlaintextConn()) } if insecureSkipVerify { diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 5583e505ee..43adff4093 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -483,7 +483,8 @@ func (s inProcessServer) Conn() *sdk.ConnectRPCConnection { Options: []connect.ClientOption{ connect.WithInterceptors(clientInterceptors...), connect.WithReadMaxBytes(s.maxCallRecvMsgSize), - connect.WithSendMaxBytes(s.maxCallSendMsgSize)}, + connect.WithSendMaxBytes(s.maxCallSendMsgSize), + }, } return &conn } diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 3f0253c4a7..4632ce8a27 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -227,7 +227,7 @@ func Start(f ...StartOptions) error { return errors.New("entityresolution endpoint must be provided in core mode") } - ersConnectRpcConn := sdk.ConnectRPCConnection{} + ersConnectRPCConn := sdk.ConnectRPCConnection{} var tlsConfig *tls.Config if cfg.SDKConfig.EntityResolutionConnection.Insecure { @@ -235,11 +235,11 @@ func Start(f ...StartOptions) error { MinVersion: tls.VersionTLS12, InsecureSkipVerify: true, // #nosec G402 } - ersConnectRpcConn.Client = httputil.SafeHTTPClientWithTLSConfig(tlsConfig) + ersConnectRPCConn.Client = httputil.SafeHTTPClientWithTLSConfig(tlsConfig) } if cfg.SDKConfig.EntityResolutionConnection.Plaintext { tlsConfig = &tls.Config{} - ersConnectRpcConn.Client = httputil.SafeHTTPClient() + ersConnectRPCConn.Client = httputil.SafeHTTPClient() } if cfg.SDKConfig.ClientID != "" && cfg.SDKConfig.ClientSecret != "" { @@ -265,16 +265,16 @@ func Start(f ...StartOptions) error { interceptor := sdkauth.NewTokenAddingInterceptorWithClient(ts, httputil.SafeHTTPClientWithTLSConfig(tlsConfig)) - ersConnectRpcConn.Options = append(ersConnectRpcConn.Options, connect.WithInterceptors(interceptor.AddCredentialsConnect())) + ersConnectRPCConn.Options = append(ersConnectRPCConn.Options, connect.WithInterceptors(interceptor.AddCredentialsConnect())) } if sdk.IsPlatformEndpointMalformed(cfg.SDKConfig.EntityResolutionConnection.Endpoint) { return fmt.Errorf("entityresolution endpoint is malformed: %s", cfg.SDKConfig.EntityResolutionConnection.Endpoint) } - ersConnectRpcConn.Endpoint = cfg.SDKConfig.EntityResolutionConnection.Endpoint + ersConnectRPCConn.Endpoint = cfg.SDKConfig.EntityResolutionConnection.Endpoint - sdkOptions = append(sdkOptions, sdk.WithCustomEntityResolutionConnection(&ersConnectRpcConn)) - logger.Info("added with custom ers connection for ", "", ersConnectRpcConn.Endpoint) + sdkOptions = append(sdkOptions, sdk.WithCustomEntityResolutionConnection(&ersConnectRPCConn)) + logger.Info("added with custom ers connection for ", "", ersConnectRPCConn.Endpoint) } client, err = sdk.New("", sdkOptions...) From a8e2075f3a29f5aa24c6dd0765a8ac3824aa6961 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Mon, 12 May 2025 09:46:43 -0400 Subject: [PATCH 06/31] http token endpoint --- examples/cmd/examples.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cmd/examples.go b/examples/cmd/examples.go index d3a0db85a5..924897fa49 100644 --- a/examples/cmd/examples.go +++ b/examples/cmd/examples.go @@ -30,7 +30,7 @@ func init() { f := ExamplesCmd.PersistentFlags() f.StringVarP(&clientCredentials, "creds", "", "opentdf-sdk:secret", "client id:secret credentials") f.StringVarP(&platformEndpoint, "platformEndpoint", "e", "https://localhost:8080", "Platform Endpoint") - f.StringVarP(&tokenEndpoint, "tokenEndpoint", "t", "https://localhost:8888/auth/realms/opentdf/protocol/openid-connect/token", "OAuth token endpoint") + f.StringVarP(&tokenEndpoint, "tokenEndpoint", "t", "http://localhost:8888/auth/realms/opentdf/protocol/openid-connect/token", "OAuth token endpoint") f.BoolVar(&storeCollectionHeaders, "storeCollectionHeaders", false, "Store collection headers") f.BoolVar(&insecurePlaintextConn, "insecurePlaintextConn", false, "Use insecure plaintext connection") f.BoolVar(&insecureSkipVerify, "insecureSkipVerify", false, "Skip server certificate verification") From 48a4ae313bb72ff275bda103eafd7c101f00aafd Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Mon, 12 May 2025 10:01:13 -0400 Subject: [PATCH 07/31] close the listener on server close --- service/internal/server/server.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 43adff4093..7353f7ebf7 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -461,6 +461,10 @@ func (s OpenTDFServer) Stop() { s.logger.Error("failed to shutdown http server", slog.String("error", err.Error())) return } + // Close the listener + if s.Listener != nil { + s.Listener.Close() + } s.logger.Info("shutting down in process grpc server") if err := s.ConnectRPCInProcess.srv.Shutdown(ctx); err != nil { From 5d5029e20c378f17f3e6df3a8bfd7936df3b0766 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Mon, 12 May 2025 10:46:18 -0400 Subject: [PATCH 08/31] only execute gateway cleanup if necessary + linting --- examples/cmd/authorization.go | 9 ++++----- service/pkg/server/services.go | 14 ++++++++++---- service/pkg/server/services_test.go | 5 ++++- service/pkg/server/start.go | 4 +++- service/pkg/server/start_test.go | 4 +++- 5 files changed, 24 insertions(+), 12 deletions(-) diff --git a/examples/cmd/authorization.go b/examples/cmd/authorization.go index 0186a98505..5d1f105a38 100644 --- a/examples/cmd/authorization.go +++ b/examples/cmd/authorization.go @@ -2,7 +2,6 @@ package cmd import ( "context" - "fmt" "log/slog" "connectrpc.com/connect" @@ -61,12 +60,12 @@ func authorizationExamples() error { }) decisionRequest := &authorization.GetDecisionsRequest{DecisionRequests: drs} - slog.Info(fmt.Sprintf("Submitting decision request: %s", protojson.Format(decisionRequest))) + slog.Info("Submitting decision request: " + protojson.Format(decisionRequest)) decisionResponse, err := s.Authorization.GetDecisions(context.Background(), connect.NewRequest(decisionRequest)) if err != nil { return err } - slog.Info(fmt.Sprintf("Received decision response: %s", protojson.Format(decisionResponse.Msg))) + slog.Info("Received decision response: " + protojson.Format(decisionResponse.Msg)) // map response back to entity chain id decisionsByEntityChain := make(map[string]*authorization.DecisionResponse) @@ -74,8 +73,8 @@ func authorizationExamples() error { decisionsByEntityChain[dr.EntityChainId] = dr } - slog.Info(fmt.Sprintf("decision for bob: %s", protojson.Format(decisionsByEntityChain["ec1"]))) - slog.Info(fmt.Sprintf("decision for alice: %s", protojson.Format(decisionsByEntityChain["ec2"]))) + slog.Info("decision for bob: " + protojson.Format(decisionsByEntityChain["ec1"])) + slog.Info("decision for alice: " + protojson.Format(decisionsByEntityChain["ec2"])) return nil } diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index 2686fec93c..ac0ddb30e3 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -193,11 +193,17 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF // Register GRPC Gateway Handler using the in-process connect rpc grpcConn := otdf.ConnectRPCInProcess.GrpcConn() - if err := svc.RegisterGRPCGatewayHandler(ctx, otdf.GRPCGatewayMux, grpcConn); err != nil { + err := svc.RegisterGRPCGatewayHandler(ctx, otdf.GRPCGatewayMux, otdf.ConnectRPCInProcess.GrpcConn()) + if err != nil { logger.Info("service did not register a grpc gateway handler", slog.String("namespace", ns)) - } - gatewayCleanup = func() { - grpcConn.Close() + } else if gatewayCleanup == nil { + gatewayCleanup = func() { + slog.Info("executing cleanup") + if grpcConn != nil { + grpcConn.Close() + } + slog.Info("cleanup complete") + } } // Register Extra Handlers diff --git a/service/pkg/server/services_test.go b/service/pkg/server/services_test.go index fe0b046da9..1b02d084da 100644 --- a/service/pkg/server/services_test.go +++ b/service/pkg/server/services_test.go @@ -270,7 +270,10 @@ func (suite *ServiceTestSuite) TestStartServicesWithVariousCases() { "foobar": {}, }, }, otdf, nil, newLogger, registry) - defer cleanup() + if cleanup != nil { + // call cleanup function + defer cleanup() + } suite.Require().NoError(err) // require.NotNil(t, cF) diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 4632ce8a27..7a814198ea 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -303,7 +303,9 @@ func Start(f ...StartOptions) error { logger.Error("issue starting services", slog.String("error", err.Error())) return fmt.Errorf("issue starting services: %w", err) } - defer gatewayCleanup() + if gatewayCleanup != nil { + defer gatewayCleanup() + } // Start watching the configuration for changes with registered config change service hooks if err := cfg.Watch(ctx); err != nil { diff --git a/service/pkg/server/start_test.go b/service/pkg/server/start_test.go index 89c58fd4f9..2bc4fc85a6 100644 --- a/service/pkg/server/start_test.go +++ b/service/pkg/server/start_test.go @@ -262,8 +262,10 @@ func (suite *StartTestSuite) Test_Start_When_Extra_Service_Registered() { "test": {}, }, }, s, nil, logger, registry) - defer cleanup() require.NoError(t, err) + if cleanup != nil { + defer cleanup() + } require.NoError(t, s.Start()) defer s.Stop() From 47e314e337d560dabd3eac599b3efb64e216950f Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Mon, 12 May 2025 11:11:38 -0400 Subject: [PATCH 09/31] better handle https in examples and bats --- examples/cmd/benchmark.go | 31 +++++++++++++++++++++++-------- examples/cmd/benchmark_bulk.go | 33 +++++++++++++++++++++++++-------- test/tdf-roundtrips.bats | 8 ++++---- 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/examples/cmd/benchmark.go b/examples/cmd/benchmark.go index 96904d9da2..3ec9ce990c 100644 --- a/examples/cmd/benchmark.go +++ b/examples/cmd/benchmark.go @@ -102,7 +102,11 @@ func runBenchmark(cmd *cobra.Command, args []string) error { } nanoTDFConfig.SetAttributes(dataAttributes) nanoTDFConfig.EnableECDSAPolicyBinding() - err = nanoTDFConfig.SetKasURL(fmt.Sprintf("http://%s/kas", "localhost:8080")) + if insecurePlaintextConn || strings.Contains(platformEndpoint, "http://") { + err = nanoTDFConfig.SetKasURL(fmt.Sprintf("http://%s/kas", "localhost:8080")) + } else { + err = nanoTDFConfig.SetKasURL(fmt.Sprintf("https://%s/kas", "localhost:8080")) + } if err != nil { return err } @@ -119,16 +123,27 @@ func runBenchmark(cmd *cobra.Command, args []string) error { // } // } } else { + opts := []sdk.TDFOption{sdk.WithDataAttributes(dataAttributes...), sdk.WithAutoconfigure(false)} + if insecurePlaintextConn || strings.Contains(platformEndpoint, "http://") { + opts = append(opts, sdk.WithKasInformation( + sdk.KASInfo{ + URL: fmt.Sprintf("http://%s", "localhost:8080"), + PublicKey: "", + }), + ) + } else { + opts = append(opts, sdk.WithKasInformation( + sdk.KASInfo{ + URL: fmt.Sprintf("https://%s", "localhost:8080"), + PublicKey: "", + }), + ) + } tdf, err := client.CreateTDF( out, in, - sdk.WithDataAttributes(dataAttributes...), - sdk.WithKasInformation( - sdk.KASInfo{ - URL: fmt.Sprintf("http://%s", "localhost:8080"), - PublicKey: "", - }), - sdk.WithAutoconfigure(false)) + opts..., + ) if err != nil { return err } diff --git a/examples/cmd/benchmark_bulk.go b/examples/cmd/benchmark_bulk.go index cf0a648b79..b28aa92f91 100644 --- a/examples/cmd/benchmark_bulk.go +++ b/examples/cmd/benchmark_bulk.go @@ -57,7 +57,12 @@ func runBenchmarkBulk(cmd *cobra.Command, args []string) error { } nanoTDFConfig.SetAttributes(dataAttributes) nanoTDFConfig.EnableECDSAPolicyBinding() - err = nanoTDFConfig.SetKasURL(fmt.Sprintf("http://%s/kas", "localhost:8080")) + // if plaintext or platform endpoint is http, set kas url to http, otherwise https + if insecurePlaintextConn || strings.Contains(platformEndpoint, "http://") { + err = nanoTDFConfig.SetKasURL(fmt.Sprintf("http://%s/kas", "localhost:8080")) + } else { + err = nanoTDFConfig.SetKasURL(fmt.Sprintf("https://%s/kas", "localhost:8080")) + } if err != nil { return err } @@ -74,16 +79,28 @@ func runBenchmarkBulk(cmd *cobra.Command, args []string) error { } } } else { + opts := []sdk.TDFOption{sdk.WithDataAttributes(dataAttributes...), sdk.WithAutoconfigure(false)} + if insecurePlaintextConn || strings.Contains(platformEndpoint, "http://") { + opts = append(opts, sdk.WithKasInformation( + sdk.KASInfo{ + URL: fmt.Sprintf("http://%s", "localhost:8080"), + PublicKey: "", + }), + ) + } else { + opts = append(opts, sdk.WithKasInformation( + sdk.KASInfo{ + URL: fmt.Sprintf("https://%s", "localhost:8080"), + PublicKey: "", + }), + ) + } tdf, err := client.CreateTDF( out, in, - sdk.WithDataAttributes(dataAttributes...), - sdk.WithKasInformation( - sdk.KASInfo{ - URL: fmt.Sprintf("http://%s", "localhost:8080"), - PublicKey: "", - }), - sdk.WithAutoconfigure(false)) + + opts..., + ) if err != nil { return err } diff --git a/test/tdf-roundtrips.bats b/test/tdf-roundtrips.bats index 4181d552cc..bf1d4a679b 100755 --- a/test/tdf-roundtrips.bats +++ b/test/tdf-roundtrips.bats @@ -6,10 +6,10 @@ @test "examples: roundtrip Z-TDF" { # TODO: add subject mapping here to remove reliance on `provision fixtures` echo "[INFO] configure attribute with grant for local kas" - go run ./examples --creds opentdf:secret kas add --kas http://localhost:8080 --algorithm "rsa:2048" --kid r1 --public-key "$(<${BATS_TEST_DIRNAME}/../kas-cert.pem)" + go run ./examples --creds opentdf:secret kas add --kas https://localhost:8080 --algorithm "rsa:2048" --kid r1 --public-key "$(<${BATS_TEST_DIRNAME}/../kas-cert.pem)" go run ./examples --creds opentdf:secret attributes unassign -a https://example.com/attr/attr1 -v value1 go run ./examples --creds opentdf:secret attributes unassign -a https://example.com/attr/attr1 - go run ./examples --creds opentdf:secret attributes assign -a https://example.com/attr/attr1 -v value1 -k http://localhost:8080 + go run ./examples --creds opentdf:secret attributes assign -a https://example.com/attr/attr1 -v value1 -k https://localhost:8080 echo "[INFO] create a tdf3 format file" run go run ./examples encrypt "Hello Zero Trust" @@ -58,11 +58,11 @@ @test "examples: roundtrip Z-TDF with extra unnecessary, invalid kas" { # TODO: add subject mapping here to remove reliance on `provision fixtures` echo "[INFO] configure attribute with grant for local kas" - go run ./examples --creds opentdf:secret kas add --kas http://localhost:8080 --algorithm "rsa:2048" --kid r1 --public-key "$(<${BATS_TEST_DIRNAME}/../kas-cert.pem)" + go run ./examples --creds opentdf:secret kas add --kas https://localhost:8080 --algorithm "rsa:2048" --kid r1 --public-key "$(<${BATS_TEST_DIRNAME}/../kas-cert.pem)" go run ./examples --creds opentdf:secret kas add --kas http://localhost:9090 --algorithm "rsa:2048" --kid r2 --public-key "$(<${BATS_TEST_DIRNAME}/../kas-cert.pem)" go run ./examples --creds opentdf:secret attributes unassign -a https://example.com/attr/attr1 -v value1 go run ./examples --creds opentdf:secret attributes unassign -a https://example.com/attr/attr1 - go run ./examples --creds opentdf:secret attributes assign -a https://example.com/attr/attr1 -v value1 -k "http://localhost:8080,http://localhost:9090" + go run ./examples --creds opentdf:secret attributes assign -a https://example.com/attr/attr1 -v value1 -k "https://localhost:8080,http://localhost:9090" echo "[INFO] create a tdf3 format file" run go run ./examples encrypt "Hello multikao split" From 3b6af9b686feebedd65083e01e466c6a3c7d76e6 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Mon, 12 May 2025 12:46:29 -0400 Subject: [PATCH 10/31] strip the kas path when making connect rpc client --- examples/cmd/benchmark.go | 4 +-- examples/cmd/benchmark_bulk.go | 4 +-- sdk/kas_client.go | 30 +++++++++++----------- sdk/kas_client_test.go | 46 ++++++++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 20 deletions(-) diff --git a/examples/cmd/benchmark.go b/examples/cmd/benchmark.go index 3ec9ce990c..1d648d9aa5 100644 --- a/examples/cmd/benchmark.go +++ b/examples/cmd/benchmark.go @@ -127,14 +127,14 @@ func runBenchmark(cmd *cobra.Command, args []string) error { if insecurePlaintextConn || strings.Contains(platformEndpoint, "http://") { opts = append(opts, sdk.WithKasInformation( sdk.KASInfo{ - URL: fmt.Sprintf("http://%s", "localhost:8080"), + URL: "http://localhost:8080", PublicKey: "", }), ) } else { opts = append(opts, sdk.WithKasInformation( sdk.KASInfo{ - URL: fmt.Sprintf("https://%s", "localhost:8080"), + URL: "https://localhost:8080", PublicKey: "", }), ) diff --git a/examples/cmd/benchmark_bulk.go b/examples/cmd/benchmark_bulk.go index b28aa92f91..0795c3167b 100644 --- a/examples/cmd/benchmark_bulk.go +++ b/examples/cmd/benchmark_bulk.go @@ -83,14 +83,14 @@ func runBenchmarkBulk(cmd *cobra.Command, args []string) error { if insecurePlaintextConn || strings.Contains(platformEndpoint, "http://") { opts = append(opts, sdk.WithKasInformation( sdk.KASInfo{ - URL: fmt.Sprintf("http://%s", "localhost:8080"), + URL: "http://localhost:8080", PublicKey: "", }), ) } else { opts = append(opts, sdk.WithKasInformation( sdk.KASInfo{ - URL: fmt.Sprintf("https://%s", "localhost:8080"), + URL: "https://localhost:8080", PublicKey: "", }), ) diff --git a/sdk/kas_client.go b/sdk/kas_client.go index 18aabd2fed..8df0bcb597 100644 --- a/sdk/kas_client.go +++ b/sdk/kas_client.go @@ -65,12 +65,12 @@ func (k *KASClient) makeRewrapRequest(ctx context.Context, requests []*kas.Unsig return nil, err } kasURL := requests[0].GetKeyAccessObjects()[0].GetKeyAccessObject().GetKasUrl() - _, err = url.Parse(kasURL) + parsedUrl, err := parseBaseUrl(kasURL) if err != nil { return nil, fmt.Errorf("cannot parse kas url(%s): %w", kasURL, err) } - serviceClient := kasconnect.NewAccessServiceClient(k.httpClient, kasURL, k.connectOptions...) + serviceClient := kasconnect.NewAccessServiceClient(k.httpClient, parsedUrl, k.connectOptions...) response, err := serviceClient.Rewrap(ctx, connect.NewRequest(rewrapRequest)) if err != nil { @@ -312,24 +312,22 @@ func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecrypt return policyResults, nil } -func getGRPCAddress(kasURL string) (string, error) { - parsedURL, err := url.Parse(kasURL) +func parseBaseUrl(rawURL string) (string, error) { + u, err := url.Parse(rawURL) if err != nil { - return "", fmt.Errorf("cannot parse kas url(%s): %w", kasURL, err) + return "", err } - // Needed to support buffconn for testing - if parsedURL.Host == "" && parsedURL.Port() == "" { - return "", nil - } + host := u.Hostname() + port := u.Port() - port := parsedURL.Port() - // if port is empty, default to 443. - if port == "" { - port = "443" + // Add port only if it's present + addr := host + if port != "" { + addr = net.JoinHostPort(host, port) } - return net.JoinHostPort(parsedURL.Hostname(), port), nil + return fmt.Sprintf("%s://%s", u.Scheme, addr), nil } func (k *KASClient) getRewrapRequest(reqs []*kas.UnsignedRewrapRequest_WithPolicyRequest, pubKey string) (*kas.RewrapRequest, error) { @@ -427,12 +425,12 @@ func (s SDK) getPublicKey(ctx context.Context, kasurl, algorithm string) (*KASIn return cachedValue, nil } } - _, err := url.Parse(kasurl) + parsedUrl, err := parseBaseUrl(kasurl) if err != nil { return nil, fmt.Errorf("cannot parse kas url(%s): %w", kasurl, err) } - serviceClient := kasconnect.NewAccessServiceClient(s.conn.Client, kasurl, s.conn.Options...) + serviceClient := kasconnect.NewAccessServiceClient(s.conn.Client, parsedUrl, s.conn.Options...) req := kas.PublicKeyRequest{ Algorithm: algorithm, diff --git a/sdk/kas_client_test.go b/sdk/kas_client_test.go index 0750bbd9c5..6a0c9ce6cd 100644 --- a/sdk/kas_client_test.go +++ b/sdk/kas_client_test.go @@ -214,3 +214,49 @@ func (suite *TestUpgradeRewrapRequestV1Suite) TestUpgradeRewrapRequestV1_Empty() func TestUpgradeRewrapRequestV1(t *testing.T) { suite.Run(t, new(TestUpgradeRewrapRequestV1Suite)) } + +func TestParseBaseUrl(t *testing.T) { + tests := []struct { + name string + input string + expected string + expectError bool + }{ + { + name: "Valid URL with scheme and port", + input: "https://example.com:8080/path", + expected: "https://example.com:8080", + expectError: false, + }, + { + name: "Valid URL with scheme and no port", + input: "https://example.com/path", + expected: "https://example.com", + expectError: false, + }, + { + name: "Valid URL with default port", + input: "http://example.com", + expected: "http://example.com", + expectError: false, + }, + { + name: "Invalid URL with invalid characters", + input: "https://exa mple.com", + expected: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseBaseUrl(tt.input) + if tt.expectError { + assert.Error(t, err, "Expected an error for test case: %s", tt.name) + } else { + assert.NoError(t, err, "Did not expect an error for test case: %s", tt.name) + assert.Equal(t, tt.expected, result, "Unexpected result for test case: %s", tt.name) + } + }) + } +} From 0ca057dead56e8c58e2ebb0b1c647a8307ead68e Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Mon, 12 May 2025 13:00:02 -0400 Subject: [PATCH 11/31] linting, fix scheme rt tests --- sdk/kas_client.go | 10 +++++----- sdk/kas_client_test.go | 6 +++--- service/rttests/rt_test.go | 17 ++++++++++------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/sdk/kas_client.go b/sdk/kas_client.go index 8df0bcb597..93bfdf1a27 100644 --- a/sdk/kas_client.go +++ b/sdk/kas_client.go @@ -65,12 +65,12 @@ func (k *KASClient) makeRewrapRequest(ctx context.Context, requests []*kas.Unsig return nil, err } kasURL := requests[0].GetKeyAccessObjects()[0].GetKeyAccessObject().GetKasUrl() - parsedUrl, err := parseBaseUrl(kasURL) + parsedURL, err := parseBaseURL(kasURL) if err != nil { return nil, fmt.Errorf("cannot parse kas url(%s): %w", kasURL, err) } - serviceClient := kasconnect.NewAccessServiceClient(k.httpClient, parsedUrl, k.connectOptions...) + serviceClient := kasconnect.NewAccessServiceClient(k.httpClient, parsedURL, k.connectOptions...) response, err := serviceClient.Rewrap(ctx, connect.NewRequest(rewrapRequest)) if err != nil { @@ -312,7 +312,7 @@ func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecrypt return policyResults, nil } -func parseBaseUrl(rawURL string) (string, error) { +func parseBaseURL(rawURL string) (string, error) { u, err := url.Parse(rawURL) if err != nil { return "", err @@ -425,12 +425,12 @@ func (s SDK) getPublicKey(ctx context.Context, kasurl, algorithm string) (*KASIn return cachedValue, nil } } - parsedUrl, err := parseBaseUrl(kasurl) + parsedURL, err := parseBaseURL(kasurl) if err != nil { return nil, fmt.Errorf("cannot parse kas url(%s): %w", kasurl, err) } - serviceClient := kasconnect.NewAccessServiceClient(s.conn.Client, parsedUrl, s.conn.Options...) + serviceClient := kasconnect.NewAccessServiceClient(s.conn.Client, parsedURL, s.conn.Options...) req := kas.PublicKeyRequest{ Algorithm: algorithm, diff --git a/sdk/kas_client_test.go b/sdk/kas_client_test.go index 6a0c9ce6cd..d15d785755 100644 --- a/sdk/kas_client_test.go +++ b/sdk/kas_client_test.go @@ -250,11 +250,11 @@ func TestParseBaseUrl(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := parseBaseUrl(tt.input) + result, err := parseBaseURL(tt.input) if tt.expectError { - assert.Error(t, err, "Expected an error for test case: %s", tt.name) + require.Error(t, err, "Expected an error for test case: %s", tt.name) } else { - assert.NoError(t, err, "Did not expect an error for test case: %s", tt.name) + require.NoError(t, err, "Did not expect an error for test case: %s", tt.name) assert.Equal(t, tt.expected, result, "Unexpected result for test case: %s", tt.name) } }) diff --git a/service/rttests/rt_test.go b/service/rttests/rt_test.go index f602b75548..ca45939226 100644 --- a/service/rttests/rt_test.go +++ b/service/rttests/rt_test.go @@ -29,10 +29,11 @@ import ( // then those will need to be updated. type TestConfig struct { - PlatformEndpoint string - TokenEndpoint string - ClientID string - ClientSecret string + PlatformEndpoint string + PlatformEndpointWithScheme string + TokenEndpoint string + ClientID string + ClientSecret string } var attributesToMap = []string{ @@ -112,11 +113,14 @@ func (s *RoundtripSuite) SetupSuite() { opts := []sdk.Option{} if os.Getenv("TLS_ENABLED") == "" { opts = append(opts, sdk.WithInsecurePlaintextConn()) + s.TestConfig.PlatformEndpointWithScheme = "http://" + s.TestConfig.PlatformEndpoint + } else { + s.TestConfig.PlatformEndpointWithScheme = "https://" + s.TestConfig.PlatformEndpoint } opts = append(opts, sdk.WithClientCredentials(s.TestConfig.ClientID, s.TestConfig.ClientSecret, nil)) - sdk, err := sdk.New("http://"+s.TestConfig.PlatformEndpoint, opts...) + sdk, err := sdk.New(s.TestConfig.PlatformEndpointWithScheme, opts...) s.Require().NoError(err) s.client = sdk @@ -343,8 +347,7 @@ func encrypt(client *sdk.SDK, testConfig TestConfig, plaintext string, attribute sdk.WithDataAttributes(attributes...), sdk.WithKasInformation( sdk.KASInfo{ - // examples assume insecure http - URL: "http://" + testConfig.PlatformEndpoint, + URL: testConfig.PlatformEndpointWithScheme, PublicKey: "", })) if err != nil { From 3eeda9887251716b28a554d839219429427d739a Mon Sep 17 00:00:00 2001 From: Elizabeth Healy <35498075+elizabethhealy@users.noreply.github.com> Date: Thu, 15 May 2025 01:00:55 -0400 Subject: [PATCH 12/31] feat(sdk): Hide connect rpc client side (#2206) ### Proposed Changes * generate connect wrappers for the grpc client interfaces, allows the sdk object to still have grpc client interfaces (meaning if someone is running a seperate ERS for example, it doesnt have to be connectrpc, same for the other services) ### Checklist - [ ] I have added or updated unit tests - [ ] I have added or updated integration tests (if appropriate) - [ ] I have added or updated documentation ### Testing Instructions --- .github/workflows/checks.yaml | 5 +- Makefile | 7 +- examples/cmd/attributes.go | 87 ++++---- examples/cmd/authorization.go | 7 +- examples/cmd/benchmark_decision.go | 7 +- examples/cmd/kas.go | 25 ++- examples/go.mod | 1 - examples/go.sum | 2 - sdk/go.mod | 11 +- sdk/go.sum | 26 ++- sdk/granter.go | 10 +- sdk/granter_test.go | 5 +- sdk/internal/codegen/main.go | 13 ++ sdk/internal/codegen/runner/generate.go | 216 ++++++++++++++++++++ sdk/sdk.go | 71 +++---- sdk/sdkconnect/actions.go | 63 ++++++ sdk/sdkconnect/attributes.go | 189 +++++++++++++++++ sdk/sdkconnect/authorization.go | 45 ++++ sdk/sdkconnect/entityresolution.go | 36 ++++ sdk/sdkconnect/kasregistry.go | 117 +++++++++++ sdk/sdkconnect/keymanagement.go | 63 ++++++ sdk/sdkconnect/namespaces.go | 99 +++++++++ sdk/sdkconnect/registeredresources.go | 117 +++++++++++ sdk/sdkconnect/resourcemapping.go | 117 +++++++++++ sdk/sdkconnect/subjectmapping.go | 126 ++++++++++++ sdk/sdkconnect/unsafe.go | 108 ++++++++++ sdk/sdkconnect/wellknownconfiguration.go | 27 +++ sdk/tdf.go | 7 +- service/authorization/authorization.go | 30 +-- service/authorization/authorization_test.go | 118 ++++++----- service/kas/access/accessPdp.go | 5 +- service/rttests/rt_test.go | 43 ++-- 32 files changed, 1581 insertions(+), 222 deletions(-) create mode 100644 sdk/internal/codegen/main.go create mode 100644 sdk/internal/codegen/runner/generate.go create mode 100644 sdk/sdkconnect/actions.go create mode 100644 sdk/sdkconnect/attributes.go create mode 100644 sdk/sdkconnect/authorization.go create mode 100644 sdk/sdkconnect/entityresolution.go create mode 100644 sdk/sdkconnect/kasregistry.go create mode 100644 sdk/sdkconnect/keymanagement.go create mode 100644 sdk/sdkconnect/namespaces.go create mode 100644 sdk/sdkconnect/registeredresources.go create mode 100644 sdk/sdkconnect/resourcemapping.go create mode 100644 sdk/sdkconnect/subjectmapping.go create mode 100644 sdk/sdkconnect/unsafe.go create mode 100644 sdk/sdkconnect/wellknownconfiguration.go diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index ee14decb0e..3e7ffce328 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -346,6 +346,7 @@ jobs: focus-sdk: go # use commit instead of ref so we can "go get" specific sdk version platform-ref: ${{ github.event.pull_request.head.sha || github.sha }} lts + otdfctl-ref: 107e016c326564234757a55e55086fdf66e83078 # test latest otdfctl CLI 'main' against platform PR branch otdfctl-test: @@ -392,6 +393,8 @@ jobs: - run: cd service && go get github.com/pseudomuto/protoc-gen-doc/cmd/protoc-gen-doc - run: cd service && go install github.com/pseudomuto/protoc-gen-doc/cmd/protoc-gen-doc - run: make proto-generate + - name: generate connect wrappers + run: make connect-wrapper-generate - name: Restore go.mod after installing protoc-gen-doc run: git restore {service,protocol/go}/go.{mod,sum} - name: validate go mod tidy @@ -401,7 +404,7 @@ jobs: git restore go.sum - run: git diff - run: git diff-files --ignore-submodules - - name: Check that make proto-generate has run before PR submission; see above for error details + - name: Check that make proto-generate and connect-wrapper-generate have run before PR submission; see above for error details run: git diff-files --quiet --ignore-submodules ci: diff --git a/Makefile b/Makefile index 2c316db529..c31aa3fa3b 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ # make # To run all lint checks: `LINT_OPTIONS= make lint` -.PHONY: all build clean docker-build fix fmt go-lint license lint proto-generate proto-lint sdk/sdk test tidy toolcheck +.PHONY: all build clean docker-build fix fmt go-lint license lint proto-generate connect-wrapper-generate proto-lint sdk/sdk test tidy toolcheck MODS=protocol/go lib/ocrypto lib/fixtures lib/flattening lib/identifier sdk service examples HAND_MODS=lib/ocrypto lib/fixtures lib/flattening lib/identifier sdk service examples @@ -71,6 +71,9 @@ proto-generate: buf generate buf.build/grpc-ecosystem/grpc-gateway -o tmp-gen --template buf.gen.grpc.docs.yaml buf generate buf.build/grpc-ecosystem/grpc-gateway -o tmp-gen --template buf.gen.openapi.docs.yaml +connect-wrapper-generate: + go run ./sdk/internal/codegen + policy-sql-gen: @which sqlc > /dev/null || { echo "sqlc not found, please install it: https://docs.sqlc.dev/en/stable/overview/install.html"; exit 1; } sqlc generate -f service/policy/db/sqlc.yaml @@ -95,7 +98,7 @@ clean: for m in $(MODS); do (cd $$m && go clean) || exit 1; done rm -f opentdf examples/examples -build: proto-generate opentdf sdk/sdk examples/examples +build: proto-generate connect-wrapper-generate opentdf sdk/sdk examples/examples opentdf: $(shell find service) go build -o opentdf -v service/main.go diff --git a/examples/cmd/attributes.go b/examples/cmd/attributes.go index ec258bbcb7..b33d08fc74 100644 --- a/examples/cmd/attributes.go +++ b/examples/cmd/attributes.go @@ -9,7 +9,6 @@ import ( "regexp" "strings" - "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/common" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" @@ -112,12 +111,12 @@ func listAttributes(cmd *cobra.Command) error { var nsuris []string if ns == "" { slog.Info("listing namespaces") - listResp, err := s.Namespaces.ListNamespaces(ctx, connect.NewRequest(&namespaces.ListNamespacesRequest{})) + listResp, err := s.Namespaces.ListNamespaces(ctx, &namespaces.ListNamespacesRequest{}) if err != nil { return err } - slog.Info(fmt.Sprintf("found %d namespaces", len(listResp.Msg.GetNamespaces()))) - for _, n := range listResp.Msg.GetNamespaces() { + slog.Info(fmt.Sprintf("found %d namespaces", len(listResp.GetNamespaces()))) + for _, n := range listResp.GetNamespaces() { nsuris = append(nsuris, n.GetFqn()) } } else { @@ -128,15 +127,15 @@ func listAttributes(cmd *cobra.Command) error { if err != nil { return err } - lsr, err := s.Attributes.ListAttributes(ctx, connect.NewRequest(&attributes.ListAttributesRequest{ + lsr, err := s.Attributes.ListAttributes(ctx, &attributes.ListAttributesRequest{ // namespace here must be the namespace name Namespace: u.Host, - })) + }) if err != nil { return err } - slog.Info(fmt.Sprintf("found %d attributes in namespace", len(lsr.Msg.GetAttributes())), "ns", n) - for _, a := range lsr.Msg.GetAttributes() { + slog.Info(fmt.Sprintf("found %d attributes in namespace", len(lsr.GetAttributes())), "ns", n) + for _, a := range lsr.GetAttributes() { if longformat { fmt.Printf("%s\t%s\n", a.GetFqn(), a.GetId()) } else { @@ -160,12 +159,12 @@ func nsuuid(ctx context.Context, s *sdk.SDK, u string) (string, error) { slog.Error("namespace url.Parse", "err", err, "url", u) return "", errors.Join(err, ErrInvalidArgument) } - listResp, err := s.Namespaces.ListNamespaces(ctx, connect.NewRequest(&namespaces.ListNamespacesRequest{})) + listResp, err := s.Namespaces.ListNamespaces(ctx, &namespaces.ListNamespacesRequest{}) if err != nil { slog.Error("ListNamespaces", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - for _, n := range listResp.Msg.GetNamespaces() { + for _, n := range listResp.GetNamespaces() { if n.GetName() == url.Hostname() { return n.GetId(), nil } @@ -174,15 +173,15 @@ func nsuuid(ctx context.Context, s *sdk.SDK, u string) (string, error) { } func attruuid(ctx context.Context, s *sdk.SDK, nsu, fqn string) (string, error) { - resp, err := s.Attributes.ListAttributes(ctx, connect.NewRequest(&attributes.ListAttributesRequest{ + resp, err := s.Attributes.ListAttributes(ctx, &attributes.ListAttributesRequest{ Namespace: nsu, State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ANY, - })) + }) if err != nil { slog.Error("ListAttributes", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - for _, a := range resp.Msg.GetAttributes() { + for _, a := range resp.GetAttributes() { if strings.ToLower(a.GetFqn()) == strings.ToLower(fqn) { return a.GetId(), nil } @@ -191,12 +190,12 @@ func attruuid(ctx context.Context, s *sdk.SDK, nsu, fqn string) (string, error) } func avuuid(ctx context.Context, s *sdk.SDK, auuid, vs string) (string, error) { - resp, err := s.Attributes.GetAttribute(ctx, connect.NewRequest(&attributes.GetAttributeRequest{Id: auuid})) + resp, err := s.Attributes.GetAttribute(ctx, &attributes.GetAttributeRequest{Id: auuid}) if err != nil { slog.Error("GetAttribute", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - for _, v := range resp.Msg.GetAttribute().GetValues() { + for _, v := range resp.GetAttribute().GetValues() { if strings.ToLower(v.GetValue()) == strings.ToLower(vs) { return v.GetId(), nil } @@ -210,12 +209,12 @@ func addNamespace(ctx context.Context, s *sdk.SDK, u string) (string, error) { slog.Error("url.Parse", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - resp, err := s.Namespaces.CreateNamespace(ctx, connect.NewRequest(&namespaces.CreateNamespaceRequest{Name: url.Hostname()})) + resp, err := s.Namespaces.CreateNamespace(ctx, &namespaces.CreateNamespaceRequest{Name: url.Hostname()}) if err != nil { slog.Error("CreateNamespace", "err", err) return "", errors.Join(err, ErrInvalidArgument) } - return resp.Msg.GetNamespace().GetId(), nil + return resp.GetNamespace().GetId(), nil } func addAttribute(cmd *cobra.Command) error { @@ -275,10 +274,10 @@ func removeAttribute(cmd *cobra.Command) error { } if len(values) == 0 { if unsafeBool { - resp, err := s.Unsafe.UnsafeDeleteAttribute(cmd.Context(), connect.NewRequest(&unsafe.UnsafeDeleteAttributeRequest{ + resp, err := s.Unsafe.UnsafeDeleteAttribute(cmd.Context(), &unsafe.UnsafeDeleteAttributeRequest{ Id: auuid, Fqn: strings.ToLower(attr), - })) + }) if err != nil { slog.Error("UnsafeDeleteAttribute", "err", err, "id", auuid) return err @@ -286,9 +285,9 @@ func removeAttribute(cmd *cobra.Command) error { slog.Info("deleted attribute", "attr", attr, "resp", resp) return nil } - resp, err := s.Attributes.DeactivateAttribute(cmd.Context(), connect.NewRequest(&attributes.DeactivateAttributeRequest{ + resp, err := s.Attributes.DeactivateAttribute(cmd.Context(), &attributes.DeactivateAttributeRequest{ Id: auuid, - })) + }) if err != nil { slog.Error("DeactivateAttribute", "err", err, "id", auuid) return err @@ -302,19 +301,19 @@ func removeAttribute(cmd *cobra.Command) error { return err } if unsafeBool { - r, err := s.Unsafe.UnsafeDeleteAttributeValue(cmd.Context(), connect.NewRequest(&unsafe.UnsafeDeleteAttributeValueRequest{ + r, err := s.Unsafe.UnsafeDeleteAttributeValue(cmd.Context(), &unsafe.UnsafeDeleteAttributeValueRequest{ Id: avu, Fqn: strings.ToLower(attr + "/value/" + url.PathEscape(v)), - })) + }) if err != nil { slog.Error("UnsafeDeleteAttributeValue", "err", err, "id", avu) return err } slog.Info("deactivated attribute value", "attr", attr, "value", v, "resp", r) } else { - r, err := s.Attributes.DeactivateAttributeValue(cmd.Context(), connect.NewRequest(&attributes.DeactivateAttributeValueRequest{ + r, err := s.Attributes.DeactivateAttributeValue(cmd.Context(), &attributes.DeactivateAttributeValueRequest{ Id: avu, - })) + }) if err != nil { slog.Error("DeactivateAttributeValue", "err", err, "id", avu) return err @@ -369,11 +368,11 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { return fmt.Errorf("assign must take a `--kas` parameter") case len(values) == 0: // look up all kasids associated with the attribute - ar, err := s.Attributes.GetAttribute(cmd.Context(), connect.NewRequest(&attributes.GetAttributeRequest{Id: auuid})) + ar, err := s.Attributes.GetAttribute(cmd.Context(), &attributes.GetAttributeRequest{Id: auuid}) if err != nil { return err } - for _, b := range ar.Msg.GetAttribute().GetGrants() { + for _, b := range ar.GetAttribute().GetGrants() { kasids = append(kasids, b.GetId()) kasById[b.GetId()] = b.GetUri() } @@ -385,11 +384,11 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { if err != nil { return err } - ar, err := s.Attributes.GetAttributeValue(cmd.Context(), connect.NewRequest(&attributes.GetAttributeValueRequest{Id: avu})) + ar, err := s.Attributes.GetAttributeValue(cmd.Context(), &attributes.GetAttributeValueRequest{Id: avu}) if err != nil { return err } - for _, b := range ar.Msg.GetValue().GetGrants() { + for _, b := range ar.GetValue().GetGrants() { kasids = append(kasids, b.GetId()) kasById[b.GetId()] = b.GetUri() } @@ -398,27 +397,27 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { for _, kasid := range kasids { if len(values) == 0 { if assign { - r, err := s.Attributes.AssignKeyAccessServerToAttribute(cmd.Context(), connect.NewRequest(&attributes.AssignKeyAccessServerToAttributeRequest{ + r, err := s.Attributes.AssignKeyAccessServerToAttribute(cmd.Context(), &attributes.AssignKeyAccessServerToAttributeRequest{ AttributeKeyAccessServer: &attributes.AttributeKeyAccessServer{ AttributeId: auuid, KeyAccessServerId: kasid, }, - })) + }) if err != nil { return err } - cmd.Printf("successfully assigned all of [%s] to [%s] (binding [%v])\n", attr, kasById[kasid], *r.Msg.GetAttributeKeyAccessServer()) + cmd.Printf("successfully assigned all of [%s] to [%s] (binding [%v])\n", attr, kasById[kasid], *r.GetAttributeKeyAccessServer()) } else { - r, err := s.Attributes.RemoveKeyAccessServerFromAttribute(cmd.Context(), connect.NewRequest(&attributes.RemoveKeyAccessServerFromAttributeRequest{ + r, err := s.Attributes.RemoveKeyAccessServerFromAttribute(cmd.Context(), &attributes.RemoveKeyAccessServerFromAttributeRequest{ AttributeKeyAccessServer: &attributes.AttributeKeyAccessServer{ AttributeId: auuid, KeyAccessServerId: kasid, }, - })) + }) if err != nil { return err } - cmd.Printf("successfully unassigned [%s] from [%s] (binding %v)\n", attr, kasById[kasid], *r.Msg.GetAttributeKeyAccessServer()) + cmd.Printf("successfully unassigned [%s] from [%s] (binding %v)\n", attr, kasById[kasid], *r.GetAttributeKeyAccessServer()) } } else { for _, v := range values { @@ -427,27 +426,27 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { return err } if assign { - r, err := s.Attributes.AssignKeyAccessServerToValue(cmd.Context(), connect.NewRequest(&attributes.AssignKeyAccessServerToValueRequest{ + r, err := s.Attributes.AssignKeyAccessServerToValue(cmd.Context(), &attributes.AssignKeyAccessServerToValueRequest{ ValueKeyAccessServer: &attributes.ValueKeyAccessServer{ ValueId: avu, KeyAccessServerId: kasid, }, - })) + }) if err != nil { return err } - cmd.Printf("successfully assigned [%s] to [%s] (binding [%v])\n", attr, kasById[kasid], *r.Msg.GetValueKeyAccessServer()) + cmd.Printf("successfully assigned [%s] to [%s] (binding [%v])\n", attr, kasById[kasid], *r.GetValueKeyAccessServer()) } else { - r, err := s.Attributes.RemoveKeyAccessServerFromValue(cmd.Context(), connect.NewRequest(&attributes.RemoveKeyAccessServerFromValueRequest{ + r, err := s.Attributes.RemoveKeyAccessServerFromValue(cmd.Context(), &attributes.RemoveKeyAccessServerFromValueRequest{ ValueKeyAccessServer: &attributes.ValueKeyAccessServer{ ValueId: avu, KeyAccessServerId: kasid, }, - })) + }) if err != nil { return err } - cmd.Printf("successfully unassigned [%s] from [%s] (binding [%v])\n", attr, kasById[kasid], *r.Msg.GetValueKeyAccessServer()) + cmd.Printf("successfully unassigned [%s] from [%s] (binding [%v])\n", attr, kasById[kasid], *r.GetValueKeyAccessServer()) } } } @@ -470,15 +469,15 @@ func ruler() policy.AttributeRuleTypeEnum { func upsertAttr(ctx context.Context, s *sdk.SDK, auth, name string, values []string) (string, error) { av, err := - s.Attributes.CreateAttribute(ctx, connect.NewRequest(&attributes.CreateAttributeRequest{ + s.Attributes.CreateAttribute(ctx, &attributes.CreateAttributeRequest{ NamespaceId: auth, Name: name, Rule: ruler(), Values: values, - })) + }) if err != nil { slog.Error("CreateAttribute", "err", err, "auth", auth, "name", name, "values", values, "rule", ruler()) return "", err } - return av.Msg.GetAttribute().GetId(), nil + return av.GetAttribute().GetId(), nil } diff --git a/examples/cmd/authorization.go b/examples/cmd/authorization.go index 5d1f105a38..bfd9d28375 100644 --- a/examples/cmd/authorization.go +++ b/examples/cmd/authorization.go @@ -4,7 +4,6 @@ import ( "context" "log/slog" - "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/authorization" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/sdk" @@ -61,15 +60,15 @@ func authorizationExamples() error { decisionRequest := &authorization.GetDecisionsRequest{DecisionRequests: drs} slog.Info("Submitting decision request: " + protojson.Format(decisionRequest)) - decisionResponse, err := s.Authorization.GetDecisions(context.Background(), connect.NewRequest(decisionRequest)) + decisionResponse, err := s.Authorization.GetDecisions(context.Background(), decisionRequest) if err != nil { return err } - slog.Info("Received decision response: " + protojson.Format(decisionResponse.Msg)) + slog.Info("Received decision response: " + protojson.Format(decisionResponse)) // map response back to entity chain id decisionsByEntityChain := make(map[string]*authorization.DecisionResponse) - for _, dr := range decisionResponse.Msg.GetDecisionResponses() { + for _, dr := range decisionResponse.GetDecisionResponses() { decisionsByEntityChain[dr.EntityChainId] = dr } diff --git a/examples/cmd/benchmark_decision.go b/examples/cmd/benchmark_decision.go index ea5a66f8cc..d34998bc23 100644 --- a/examples/cmd/benchmark_decision.go +++ b/examples/cmd/benchmark_decision.go @@ -5,7 +5,6 @@ import ( "fmt" "time" - "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/authorization" "github.com/opentdf/platform/protocol/go/policy" "github.com/spf13/cobra" @@ -35,7 +34,7 @@ func runDecisionBenchmark(cmd *cobra.Command, args []string) error { } start := time.Now() - res, err := client.Authorization.GetDecisions(context.Background(), connect.NewRequest(&authorization.GetDecisionsRequest{ + res, err := client.Authorization.GetDecisions(context.Background(), &authorization.GetDecisionsRequest{ DecisionRequests: []*authorization.DecisionRequest{ { Actions: []*policy.Action{{Value: &policy.Action_Standard{ @@ -49,14 +48,14 @@ func runDecisionBenchmark(cmd *cobra.Command, args []string) error { ResourceAttributes: ras, }, }, - })) + }) end := time.Now() totalTime := end.Sub(start) numberApproved := 0 numberDenied := 0 if err == nil { - for _, dr := range res.Msg.GetDecisionResponses() { + for _, dr := range res.GetDecisionResponses() { if dr.Decision == authorization.DecisionResponse_DECISION_PERMIT { numberApproved += 1 } else { diff --git a/examples/cmd/kas.go b/examples/cmd/kas.go index 983b0d19f5..cb58c54d66 100644 --- a/examples/cmd/kas.go +++ b/examples/cmd/kas.go @@ -6,7 +6,6 @@ import ( "log/slog" "strings" - "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/kasregistry" "github.com/opentdf/platform/sdk" @@ -71,7 +70,7 @@ func listKases(cmd *cobra.Command) error { return err } - r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(cmd.Context(), connect.NewRequest(&kasregistry.ListKeyAccessServersRequest{})) + r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(cmd.Context(), &kasregistry.ListKeyAccessServersRequest{}) if err != nil { slog.Error("ListKeyAccessServers", "error", err) return err @@ -79,12 +78,12 @@ func listKases(cmd *cobra.Command) error { slog.Info("listing kas registry") - if len(r.Msg.GetKeyAccessServers()) == 0 { + if len(r.GetKeyAccessServers()) == 0 { cmd.Println("no key access servers registered") return nil } - for _, k := range r.Msg.GetKeyAccessServers() { + for _, k := range r.GetKeyAccessServers() { if longformat { fmt.Printf("%s\t%s\t%s\n", k.GetUri(), k.GetId(), k.GetPublicKey()) } else { @@ -95,12 +94,12 @@ func listKases(cmd *cobra.Command) error { } func upsertKasRegistration(ctx context.Context, s *sdk.SDK, uri string, pk *policy.PublicKey) (string, error) { - r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(ctx, connect.NewRequest(&kasregistry.ListKeyAccessServersRequest{})) + r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(ctx, &kasregistry.ListKeyAccessServersRequest{}) if err != nil { slog.Error("ListKeyAccessServers", "err", err) return "", err } - for _, ki := range r.Msg.GetKeyAccessServers() { + for _, ki := range r.GetKeyAccessServers() { if strings.ToLower(uri) == strings.ToLower(ki.GetUri()) { oldpk := ki.GetPublicKey() recreate := false @@ -114,7 +113,7 @@ func upsertKasRegistration(ctx context.Context, s *sdk.SDK, uri string, pk *poli if !recreate { return ki.GetId(), nil } - _, err := s.KeyAccessServerRegistry.DeleteKeyAccessServer(ctx, connect.NewRequest(&kasregistry.DeleteKeyAccessServerRequest{Id: ki.GetId()})) + _, err := s.KeyAccessServerRegistry.DeleteKeyAccessServer(ctx, &kasregistry.DeleteKeyAccessServerRequest{Id: ki.GetId()}) if err != nil { slog.Error("DeleteKeyAccessServer", "err", err) return "", err @@ -130,15 +129,15 @@ func upsertKasRegistration(ctx context.Context, s *sdk.SDK, uri string, pk *poli Remote: uri + "/v2/kas_public_key", } } - ur, err := s.KeyAccessServerRegistry.CreateKeyAccessServer(ctx, connect.NewRequest(&kasregistry.CreateKeyAccessServerRequest{ + ur, err := s.KeyAccessServerRegistry.CreateKeyAccessServer(ctx, &kasregistry.CreateKeyAccessServerRequest{ Uri: uri, PublicKey: pk, - })) + }) if err != nil { slog.Error("CreateKeyAccessServer", "uri", uri, "publicKey", uri+"/v2/kas_public_key") return "", err } - return ur.Msg.GetKeyAccessServer().GetId(), nil + return ur.GetKeyAccessServer().GetId(), nil } func algString2Proto(a string) policy.KasPublicKeyAlgEnum { @@ -206,15 +205,15 @@ func removeKas(cmd *cobra.Command) error { return err } - r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(cmd.Context(), connect.NewRequest(&kasregistry.ListKeyAccessServersRequest{})) + r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(cmd.Context(), &kasregistry.ListKeyAccessServersRequest{}) if err != nil { slog.Error("ListKeyAccessServers", "err", err) return err } deletedSomething := false - for _, ki := range r.Msg.GetKeyAccessServers() { + for _, ki := range r.GetKeyAccessServers() { if strings.ToLower(kas) == strings.ToLower(ki.GetUri()) { - _, err := s.KeyAccessServerRegistry.DeleteKeyAccessServer(cmd.Context(), connect.NewRequest(&kasregistry.DeleteKeyAccessServerRequest{Id: ki.GetId()})) + _, err := s.KeyAccessServerRegistry.DeleteKeyAccessServer(cmd.Context(), &kasregistry.DeleteKeyAccessServerRequest{Id: ki.GetId()}) if err != nil { slog.Error("DeleteKeyAccessServer", "err", err) return err diff --git a/examples/go.mod b/examples/go.mod index 04b308f3d2..9d9c6e01b8 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -5,7 +5,6 @@ go 1.24.0 toolchain go1.24.2 require ( - connectrpc.com/connect v1.18.1 github.com/opentdf/platform/lib/ocrypto v0.1.9 github.com/opentdf/platform/protocol/go v0.3.2 github.com/opentdf/platform/sdk v0.4.4 diff --git a/examples/go.sum b/examples/go.sum index ca8720a84f..07aa828837 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -1,7 +1,5 @@ buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1 h1:LEXWFH/xZ5oOWrC3oOtHbUyBdzRWMCPpAQmKC9v05mA= buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1/go.mod h1:XF+P8+RmfdufmIYpGUC+6bF7S+IlmHDEnCrO3OXaUAQ= -connectrpc.com/connect v1.18.1 h1:PAg7CjSAGvscaf6YZKUefjoih5Z/qYkyaTrBW8xvYPw= -connectrpc.com/connect v1.18.1/go.mod h1:0292hj1rnx8oFrStN7cB4jjVBeqs+Yx5yDIC2prWDO8= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= diff --git a/sdk/go.mod b/sdk/go.mod index ab6e3f842f..31f175f275 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -18,6 +18,7 @@ require ( github.com/testcontainers/testcontainers-go v0.34.0 github.com/xeipuuv/gojsonschema v1.2.0 golang.org/x/oauth2 v0.26.0 + golang.org/x/tools v0.33.0 google.golang.org/grpc v1.71.0 google.golang.org/protobuf v1.36.6 ) @@ -86,10 +87,12 @@ require ( go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - golang.org/x/crypto v0.36.0 // indirect - golang.org/x/net v0.38.0 // indirect - golang.org/x/sys v0.31.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/crypto v0.38.0 // indirect + golang.org/x/mod v0.24.0 // indirect + golang.org/x/net v0.40.0 // indirect + golang.org/x/sync v0.14.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/sdk/go.sum b/sdk/go.sum index d3aafe3b5e..91b28a1e07 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -200,12 +200,14 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -216,8 +218,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE= golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -225,6 +227,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -242,24 +246,24 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= +golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= @@ -269,6 +273,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/sdk/granter.go b/sdk/granter.go index 67fbfdcd4a..5e0f32c5ab 100644 --- a/sdk/granter.go +++ b/sdk/granter.go @@ -10,10 +10,8 @@ import ( "sort" "strings" - "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" - "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" ) var ErrInvalid = errors.New("invalid type") @@ -223,18 +221,18 @@ func (r granter) byAttribute(fqn AttributeValueFQN) *keyAccessGrant { } // Gets a list of directory of KAS grants for a list of attribute FQNs -func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as attributesconnect.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) { +func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as attributes.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) { fqnsStr := make([]string, len(fqns)) for i, v := range fqns { fqnsStr[i] = v.String() } - av, err := as.GetAttributeValuesByFqns(ctx, connect.NewRequest(&attributes.GetAttributeValuesByFqnsRequest{ + av, err := as.GetAttributeValuesByFqns(ctx, &attributes.GetAttributeValuesByFqnsRequest{ Fqns: fqnsStr, WithValue: &policy.AttributeValueSelector{ WithKeyAccessGrants: true, }, - })) + }) if err != nil { return granter{}, err } @@ -243,7 +241,7 @@ func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as attrib policy: fqns, grants: make(map[string]*keyAccessGrant), } - for fqnstr, pair := range av.Msg.GetFqnAttributeValues() { + for fqnstr, pair := range av.GetFqnAttributeValues() { fqn, err := NewAttributeValueFQN(fqnstr) if err != nil { return grants, err diff --git a/sdk/granter_test.go b/sdk/granter_test.go index 5cacbb3abb..3ee33eb213 100644 --- a/sdk/granter_test.go +++ b/sdk/granter_test.go @@ -13,6 +13,7 @@ import ( "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" + "github.com/opentdf/platform/sdk/sdkconnect" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -612,7 +613,7 @@ func TestReasonerSpecificity(t *testing.T) { }, } { t.Run(tc.n, func(t *testing.T) { - reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...) + reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &sdkconnect.AttributesServiceClientConnectWrapper{AttributesServiceClient: &mockAttributesClient{}}, tc.policy...) require.NoError(t, err) i := 0 plan, err := reasoner.plan(tc.defaults, func() string { @@ -763,7 +764,7 @@ func TestReasonerSpecificityWithNamespaces(t *testing.T) { }, } { t.Run((tc.n + "\n" + tc.desc), func(t *testing.T) { - reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...) + reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &sdkconnect.AttributesServiceClientConnectWrapper{AttributesServiceClient: &mockAttributesClient{}}, tc.policy...) require.NoError(t, err) i := 0 plan, err := reasoner.plan(tc.defaults, func() string { diff --git a/sdk/internal/codegen/main.go b/sdk/internal/codegen/main.go new file mode 100644 index 0000000000..e8fb00c0e0 --- /dev/null +++ b/sdk/internal/codegen/main.go @@ -0,0 +1,13 @@ +package main + +import ( + "log" + + "github.com/opentdf/platform/sdk/internal/codegen/runner" +) + +func main() { + if err := runner.Generate(); err != nil { + log.Fatal(err) + } +} diff --git a/sdk/internal/codegen/runner/generate.go b/sdk/internal/codegen/runner/generate.go new file mode 100644 index 0000000000..e40a593dc3 --- /dev/null +++ b/sdk/internal/codegen/runner/generate.go @@ -0,0 +1,216 @@ +package runner + +import ( + "errors" + "fmt" + "go/ast" + "log/slog" + "os" + "path" + "path/filepath" + "runtime" + + "golang.org/x/tools/go/packages" +) + +type clientsToGenerate struct { + grpcClientInterface string + grpcPackagePath string +} + +var clientsToGenerateList = []clientsToGenerate{ + { + grpcClientInterface: "ActionServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/actions", + }, + { + grpcClientInterface: "AttributesServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/attributes", + }, + { + grpcClientInterface: "AuthorizationServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/authorization", + }, + { + grpcClientInterface: "EntityResolutionServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/entityresolution", + }, + { + grpcClientInterface: "KeyAccessServerRegistryServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/kasregistry", + }, + { + grpcClientInterface: "KeyManagementServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/keymanagement", + }, + { + grpcClientInterface: "NamespaceServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/namespaces", + }, + { + grpcClientInterface: "RegisteredResourcesServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/registeredresources", + }, + { + grpcClientInterface: "ResourceMappingServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/resourcemapping", + }, + { + grpcClientInterface: "SubjectMappingServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/subjectmapping", + }, + { + grpcClientInterface: "UnsafeServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/unsafe", + }, + { + grpcClientInterface: "WellKnownServiceClient", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/wellknownconfiguration", + }, +} + +func Generate() error { + for _, client := range clientsToGenerateList { + slog.Info("Generating wrapper for", "interface", client.grpcClientInterface, "package", client.grpcPackagePath) + // Load the Go package using the import path + cfg := &packages.Config{ + Mode: packages.NeedName | + packages.NeedTypes | + packages.NeedTypesInfo | + packages.NeedSyntax | + packages.NeedCompiledGoFiles, + } + pkgs, err := packages.Load(cfg, client.grpcPackagePath) + if err != nil { + return fmt.Errorf("failed to load package %s: %w", client.grpcPackagePath, err) + } + if packages.PrintErrors(pkgs) > 0 { + return fmt.Errorf("errors loading package %s", client.grpcPackagePath) + } + found := false + err = nil + // Loop through the package and its files + for _, p := range pkgs { + for _, file := range p.Syntax { + ast.Inspect(file, func(n ast.Node) bool { + if found { + return false // skip rest of traversal + } + ts, ok := n.(*ast.TypeSpec) + if !ok { + return true + } + iface, ok := ts.Type.(*ast.InterfaceType) + if !ok { + return true + } + if ts.Name.Name == client.grpcClientInterface { + packageName := path.Base(client.grpcPackagePath) + code := generateWrapper(ts.Name.Name, iface, client.grpcPackagePath, packageName) + var currentDir string + currentDir, err = getCurrentFileDir() + outputPath := filepath.Join(currentDir, "..", "..", "..", "sdkconnect", packageName+".go") + err = os.WriteFile(outputPath, []byte(code), 0o644) //nolint:gosec // ignore G306 + found = true + return false // stop traversal + } + return true + }) + if found { + break + } + } + if found { + break + } + } + if !found { + return fmt.Errorf("interface %q not found in package %s", client.grpcClientInterface, client.grpcPackagePath) + } + if err != nil { + return fmt.Errorf("error writing file: %w", err) + } + } + return nil +} + +func getCurrentFileDir() (string, error) { + _, filename, _, ok := runtime.Caller(0) + if !ok { + return "", errors.New("could not get caller information") + } + return filepath.Dir(filename), nil +} + +// Helper function to get the method names of an interface +func getMethodNames(interfaceType *ast.InterfaceType) []string { + methodNames := []string{} + for _, method := range interfaceType.Methods.List { + if len(method.Names) > 0 { + methodNames = append(methodNames, method.Names[0].Name) + } + } + return methodNames +} + +// Generate wrapper code for the Connect RPC client interface +func generateWrapper(interfaceName string, interfaceType *ast.InterfaceType, packagePath string, packageName string) string { + // Get method names dynamically from the interface + methods := getMethodNames(interfaceType) + connectPackageName := packageName + "connect" + + // Start generating the wrapper code + wrapperCode := fmt.Sprintf(`// Wrapper for %s (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "%s" + "%s" + "google.golang.org/grpc" +) + +type %sConnectWrapper struct { + %s.%s +} + +func New%sConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *%sConnectWrapper { + return &%sConnectWrapper{%s: %s.New%s(httpClient, baseURL, opts...)} +} +`, + interfaceName, + packagePath, + packagePath+"/"+connectPackageName, + interfaceName, + connectPackageName, + interfaceName, + interfaceName, + interfaceName, + interfaceName, + interfaceName, + connectPackageName, + interfaceName) + + // Now generate a wrapper function for each method in the interface + for _, method := range methods { + wrapperCode += generateWrapperMethod(interfaceName, method, packageName) + } + + // Output the generated wrapper code + return wrapperCode +} + +// Generate the wrapper method for a specific method in the interface +func generateWrapperMethod(interfaceName, methodName, packageName string) string { + return fmt.Sprintf(` +func (w *%sConnectWrapper) %s(ctx context.Context, req *%s.%sRequest, _ ...grpc.CallOption) (*%s.%sResponse, error) { + // Wrap Connect RPC client request + res, err := w.%s.%s(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} +`, interfaceName, methodName, packageName, methodName, packageName, methodName, interfaceName, methodName) +} diff --git a/sdk/sdk.go b/sdk/sdk.go index 00749f0efc..81bf550117 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -15,24 +15,25 @@ import ( "connectrpc.com/connect" "github.com/opentdf/platform/lib/ocrypto" - "github.com/opentdf/platform/protocol/go/authorization/authorizationconnect" - "github.com/opentdf/platform/protocol/go/entityresolution/entityresolutionconnect" + "github.com/opentdf/platform/protocol/go/authorization" + "github.com/opentdf/platform/protocol/go/entityresolution" "github.com/opentdf/platform/protocol/go/policy" - "github.com/opentdf/platform/protocol/go/policy/actions/actionsconnect" - "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" - "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" - "github.com/opentdf/platform/protocol/go/policy/keymanagement/keymanagementconnect" - "github.com/opentdf/platform/protocol/go/policy/namespaces/namespacesconnect" - "github.com/opentdf/platform/protocol/go/policy/registeredresources/registeredresourcesconnect" - "github.com/opentdf/platform/protocol/go/policy/resourcemapping/resourcemappingconnect" - "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" - "github.com/opentdf/platform/protocol/go/policy/unsafe/unsafeconnect" + "github.com/opentdf/platform/protocol/go/policy/actions" + "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/protocol/go/policy/kasregistry" + "github.com/opentdf/platform/protocol/go/policy/keymanagement" + "github.com/opentdf/platform/protocol/go/policy/namespaces" + "github.com/opentdf/platform/protocol/go/policy/registeredresources" + "github.com/opentdf/platform/protocol/go/policy/resourcemapping" + "github.com/opentdf/platform/protocol/go/policy/subjectmapping" + "github.com/opentdf/platform/protocol/go/policy/unsafe" "github.com/opentdf/platform/protocol/go/wellknownconfiguration" "github.com/opentdf/platform/protocol/go/wellknownconfiguration/wellknownconfigurationconnect" "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/sdk/httputil" "github.com/opentdf/platform/sdk/internal/archive" + "github.com/opentdf/platform/sdk/sdkconnect" "github.com/xeipuuv/gojsonschema" healthpb "google.golang.org/grpc/health/grpc_health_v1" ) @@ -65,18 +66,18 @@ type SDK struct { *collectionStore conn *ConnectRPCConnection tokenSource auth.AccessTokenSource - Actions actionsconnect.ActionServiceClient - Attributes attributesconnect.AttributesServiceClient - Authorization authorizationconnect.AuthorizationServiceClient - EntityResoution entityresolutionconnect.EntityResolutionServiceClient - KeyAccessServerRegistry kasregistryconnect.KeyAccessServerRegistryServiceClient - Namespaces namespacesconnect.NamespaceServiceClient - RegisteredResources registeredresourcesconnect.RegisteredResourcesServiceClient - ResourceMapping resourcemappingconnect.ResourceMappingServiceClient - SubjectMapping subjectmappingconnect.SubjectMappingServiceClient - Unsafe unsafeconnect.UnsafeServiceClient - KeyManagement keymanagementconnect.KeyManagementServiceClient - wellknownConfiguration wellknownconfigurationconnect.WellKnownServiceClient + Actions actions.ActionServiceClient + Attributes attributes.AttributesServiceClient + Authorization authorization.AuthorizationServiceClient + EntityResoution entityresolution.EntityResolutionServiceClient + KeyAccessServerRegistry kasregistry.KeyAccessServerRegistryServiceClient + Namespaces namespaces.NamespaceServiceClient + RegisteredResources registeredresources.RegisteredResourcesServiceClient + ResourceMapping resourcemapping.ResourceMappingServiceClient + SubjectMapping subjectmapping.SubjectMappingServiceClient + Unsafe unsafe.UnsafeServiceClient + KeyManagement keymanagement.KeyManagementServiceClient + wellknownConfiguration wellknownconfiguration.WellKnownServiceClient } func New(platformEndpoint string, opts ...Option) (*SDK, error) { @@ -194,18 +195,18 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { kasKeyCache: newKasKeyCache(), conn: &ConnectRPCConnection{Client: platformConn.Client, Endpoint: platformConn.Endpoint, Options: platformConn.Options}, tokenSource: accessTokenSource, - Actions: actionsconnect.NewActionServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - Attributes: attributesconnect.NewAttributesServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - Namespaces: namespacesconnect.NewNamespaceServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - RegisteredResources: registeredresourcesconnect.NewRegisteredResourcesServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - ResourceMapping: resourcemappingconnect.NewResourceMappingServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - SubjectMapping: subjectmappingconnect.NewSubjectMappingServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - Unsafe: unsafeconnect.NewUnsafeServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - KeyAccessServerRegistry: kasregistryconnect.NewKeyAccessServerRegistryServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - Authorization: authorizationconnect.NewAuthorizationServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - EntityResoution: entityresolutionconnect.NewEntityResolutionServiceClient(ersConn.Client, ersConn.Endpoint, ersConn.Options...), - KeyManagement: keymanagementconnect.NewKeyManagementServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - wellknownConfiguration: wellknownconfigurationconnect.NewWellKnownServiceClient(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Actions: sdkconnect.NewActionServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Attributes: sdkconnect.NewAttributesServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Namespaces: sdkconnect.NewNamespaceServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + RegisteredResources: sdkconnect.NewRegisteredResourcesServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + ResourceMapping: sdkconnect.NewResourceMappingServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + SubjectMapping: sdkconnect.NewSubjectMappingServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Unsafe: sdkconnect.NewUnsafeServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + KeyAccessServerRegistry: sdkconnect.NewKeyAccessServerRegistryServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + Authorization: sdkconnect.NewAuthorizationServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + EntityResoution: sdkconnect.NewEntityResolutionServiceClientConnectWrapper(ersConn.Client, ersConn.Endpoint, ersConn.Options...), + KeyManagement: sdkconnect.NewKeyManagementServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + wellknownConfiguration: sdkconnect.NewWellKnownServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), }, nil } diff --git a/sdk/sdkconnect/actions.go b/sdk/sdkconnect/actions.go new file mode 100644 index 0000000000..b60f7a8be9 --- /dev/null +++ b/sdk/sdkconnect/actions.go @@ -0,0 +1,63 @@ +// Wrapper for ActionServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/actions" + "github.com/opentdf/platform/protocol/go/policy/actions/actionsconnect" + "google.golang.org/grpc" +) + +type ActionServiceClientConnectWrapper struct { + actionsconnect.ActionServiceClient +} + +func NewActionServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *ActionServiceClientConnectWrapper { + return &ActionServiceClientConnectWrapper{ActionServiceClient: actionsconnect.NewActionServiceClient(httpClient, baseURL, opts...)} +} + +func (w *ActionServiceClientConnectWrapper) GetAction(ctx context.Context, req *actions.GetActionRequest, _ ...grpc.CallOption) (*actions.GetActionResponse, error) { + // Wrap Connect RPC client request + res, err := w.ActionServiceClient.GetAction(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ActionServiceClientConnectWrapper) ListActions(ctx context.Context, req *actions.ListActionsRequest, _ ...grpc.CallOption) (*actions.ListActionsResponse, error) { + // Wrap Connect RPC client request + res, err := w.ActionServiceClient.ListActions(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ActionServiceClientConnectWrapper) CreateAction(ctx context.Context, req *actions.CreateActionRequest, _ ...grpc.CallOption) (*actions.CreateActionResponse, error) { + // Wrap Connect RPC client request + res, err := w.ActionServiceClient.CreateAction(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ActionServiceClientConnectWrapper) UpdateAction(ctx context.Context, req *actions.UpdateActionRequest, _ ...grpc.CallOption) (*actions.UpdateActionResponse, error) { + // Wrap Connect RPC client request + res, err := w.ActionServiceClient.UpdateAction(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ActionServiceClientConnectWrapper) DeleteAction(ctx context.Context, req *actions.DeleteActionRequest, _ ...grpc.CallOption) (*actions.DeleteActionResponse, error) { + // Wrap Connect RPC client request + res, err := w.ActionServiceClient.DeleteAction(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/attributes.go b/sdk/sdkconnect/attributes.go new file mode 100644 index 0000000000..05027dff31 --- /dev/null +++ b/sdk/sdkconnect/attributes.go @@ -0,0 +1,189 @@ +// Wrapper for AttributesServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" + "google.golang.org/grpc" +) + +type AttributesServiceClientConnectWrapper struct { + attributesconnect.AttributesServiceClient +} + +func NewAttributesServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *AttributesServiceClientConnectWrapper { + return &AttributesServiceClientConnectWrapper{AttributesServiceClient: attributesconnect.NewAttributesServiceClient(httpClient, baseURL, opts...)} +} + +func (w *AttributesServiceClientConnectWrapper) ListAttributes(ctx context.Context, req *attributes.ListAttributesRequest, _ ...grpc.CallOption) (*attributes.ListAttributesResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.ListAttributes(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) ListAttributeValues(ctx context.Context, req *attributes.ListAttributeValuesRequest, _ ...grpc.CallOption) (*attributes.ListAttributeValuesResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.ListAttributeValues(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) GetAttribute(ctx context.Context, req *attributes.GetAttributeRequest, _ ...grpc.CallOption) (*attributes.GetAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.GetAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) GetAttributeValuesByFqns(ctx context.Context, req *attributes.GetAttributeValuesByFqnsRequest, _ ...grpc.CallOption) (*attributes.GetAttributeValuesByFqnsResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.GetAttributeValuesByFqns(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) CreateAttribute(ctx context.Context, req *attributes.CreateAttributeRequest, _ ...grpc.CallOption) (*attributes.CreateAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.CreateAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) UpdateAttribute(ctx context.Context, req *attributes.UpdateAttributeRequest, _ ...grpc.CallOption) (*attributes.UpdateAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.UpdateAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) DeactivateAttribute(ctx context.Context, req *attributes.DeactivateAttributeRequest, _ ...grpc.CallOption) (*attributes.DeactivateAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.DeactivateAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) GetAttributeValue(ctx context.Context, req *attributes.GetAttributeValueRequest, _ ...grpc.CallOption) (*attributes.GetAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.GetAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) CreateAttributeValue(ctx context.Context, req *attributes.CreateAttributeValueRequest, _ ...grpc.CallOption) (*attributes.CreateAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.CreateAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) UpdateAttributeValue(ctx context.Context, req *attributes.UpdateAttributeValueRequest, _ ...grpc.CallOption) (*attributes.UpdateAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.UpdateAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) DeactivateAttributeValue(ctx context.Context, req *attributes.DeactivateAttributeValueRequest, _ ...grpc.CallOption) (*attributes.DeactivateAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.DeactivateAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) AssignKeyAccessServerToAttribute(ctx context.Context, req *attributes.AssignKeyAccessServerToAttributeRequest, _ ...grpc.CallOption) (*attributes.AssignKeyAccessServerToAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.AssignKeyAccessServerToAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) RemoveKeyAccessServerFromAttribute(ctx context.Context, req *attributes.RemoveKeyAccessServerFromAttributeRequest, _ ...grpc.CallOption) (*attributes.RemoveKeyAccessServerFromAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.RemoveKeyAccessServerFromAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) AssignKeyAccessServerToValue(ctx context.Context, req *attributes.AssignKeyAccessServerToValueRequest, _ ...grpc.CallOption) (*attributes.AssignKeyAccessServerToValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.AssignKeyAccessServerToValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) RemoveKeyAccessServerFromValue(ctx context.Context, req *attributes.RemoveKeyAccessServerFromValueRequest, _ ...grpc.CallOption) (*attributes.RemoveKeyAccessServerFromValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.RemoveKeyAccessServerFromValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) AssignPublicKeyToAttribute(ctx context.Context, req *attributes.AssignPublicKeyToAttributeRequest, _ ...grpc.CallOption) (*attributes.AssignPublicKeyToAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.AssignPublicKeyToAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) RemovePublicKeyFromAttribute(ctx context.Context, req *attributes.RemovePublicKeyFromAttributeRequest, _ ...grpc.CallOption) (*attributes.RemovePublicKeyFromAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.RemovePublicKeyFromAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) AssignPublicKeyToValue(ctx context.Context, req *attributes.AssignPublicKeyToValueRequest, _ ...grpc.CallOption) (*attributes.AssignPublicKeyToValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.AssignPublicKeyToValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AttributesServiceClientConnectWrapper) RemovePublicKeyFromValue(ctx context.Context, req *attributes.RemovePublicKeyFromValueRequest, _ ...grpc.CallOption) (*attributes.RemovePublicKeyFromValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.AttributesServiceClient.RemovePublicKeyFromValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/authorization.go b/sdk/sdkconnect/authorization.go new file mode 100644 index 0000000000..5ef2d6666e --- /dev/null +++ b/sdk/sdkconnect/authorization.go @@ -0,0 +1,45 @@ +// Wrapper for AuthorizationServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/authorization" + "github.com/opentdf/platform/protocol/go/authorization/authorizationconnect" + "google.golang.org/grpc" +) + +type AuthorizationServiceClientConnectWrapper struct { + authorizationconnect.AuthorizationServiceClient +} + +func NewAuthorizationServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *AuthorizationServiceClientConnectWrapper { + return &AuthorizationServiceClientConnectWrapper{AuthorizationServiceClient: authorizationconnect.NewAuthorizationServiceClient(httpClient, baseURL, opts...)} +} + +func (w *AuthorizationServiceClientConnectWrapper) GetDecisions(ctx context.Context, req *authorization.GetDecisionsRequest, _ ...grpc.CallOption) (*authorization.GetDecisionsResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetDecisions(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AuthorizationServiceClientConnectWrapper) GetDecisionsByToken(ctx context.Context, req *authorization.GetDecisionsByTokenRequest, _ ...grpc.CallOption) (*authorization.GetDecisionsByTokenResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetDecisionsByToken(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AuthorizationServiceClientConnectWrapper) GetEntitlements(ctx context.Context, req *authorization.GetEntitlementsRequest, _ ...grpc.CallOption) (*authorization.GetEntitlementsResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetEntitlements(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/entityresolution.go b/sdk/sdkconnect/entityresolution.go new file mode 100644 index 0000000000..af3132934f --- /dev/null +++ b/sdk/sdkconnect/entityresolution.go @@ -0,0 +1,36 @@ +// Wrapper for EntityResolutionServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/entityresolution" + "github.com/opentdf/platform/protocol/go/entityresolution/entityresolutionconnect" + "google.golang.org/grpc" +) + +type EntityResolutionServiceClientConnectWrapper struct { + entityresolutionconnect.EntityResolutionServiceClient +} + +func NewEntityResolutionServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *EntityResolutionServiceClientConnectWrapper { + return &EntityResolutionServiceClientConnectWrapper{EntityResolutionServiceClient: entityresolutionconnect.NewEntityResolutionServiceClient(httpClient, baseURL, opts...)} +} + +func (w *EntityResolutionServiceClientConnectWrapper) ResolveEntities(ctx context.Context, req *entityresolution.ResolveEntitiesRequest, _ ...grpc.CallOption) (*entityresolution.ResolveEntitiesResponse, error) { + // Wrap Connect RPC client request + res, err := w.EntityResolutionServiceClient.ResolveEntities(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *EntityResolutionServiceClientConnectWrapper) CreateEntityChainFromJwt(ctx context.Context, req *entityresolution.CreateEntityChainFromJwtRequest, _ ...grpc.CallOption) (*entityresolution.CreateEntityChainFromJwtResponse, error) { + // Wrap Connect RPC client request + res, err := w.EntityResolutionServiceClient.CreateEntityChainFromJwt(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/kasregistry.go b/sdk/sdkconnect/kasregistry.go new file mode 100644 index 0000000000..8dc14f344a --- /dev/null +++ b/sdk/sdkconnect/kasregistry.go @@ -0,0 +1,117 @@ +// Wrapper for KeyAccessServerRegistryServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/kasregistry" + "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" + "google.golang.org/grpc" +) + +type KeyAccessServerRegistryServiceClientConnectWrapper struct { + kasregistryconnect.KeyAccessServerRegistryServiceClient +} + +func NewKeyAccessServerRegistryServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *KeyAccessServerRegistryServiceClientConnectWrapper { + return &KeyAccessServerRegistryServiceClientConnectWrapper{KeyAccessServerRegistryServiceClient: kasregistryconnect.NewKeyAccessServerRegistryServiceClient(httpClient, baseURL, opts...)} +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeyAccessServers(ctx context.Context, req *kasregistry.ListKeyAccessServersRequest, _ ...grpc.CallOption) (*kasregistry.ListKeyAccessServersResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.ListKeyAccessServers(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) GetKeyAccessServer(ctx context.Context, req *kasregistry.GetKeyAccessServerRequest, _ ...grpc.CallOption) (*kasregistry.GetKeyAccessServerResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.GetKeyAccessServer(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) CreateKeyAccessServer(ctx context.Context, req *kasregistry.CreateKeyAccessServerRequest, _ ...grpc.CallOption) (*kasregistry.CreateKeyAccessServerResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.CreateKeyAccessServer(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) UpdateKeyAccessServer(ctx context.Context, req *kasregistry.UpdateKeyAccessServerRequest, _ ...grpc.CallOption) (*kasregistry.UpdateKeyAccessServerResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.UpdateKeyAccessServer(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) DeleteKeyAccessServer(ctx context.Context, req *kasregistry.DeleteKeyAccessServerRequest, _ ...grpc.CallOption) (*kasregistry.DeleteKeyAccessServerResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.DeleteKeyAccessServer(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeyAccessServerGrants(ctx context.Context, req *kasregistry.ListKeyAccessServerGrantsRequest, _ ...grpc.CallOption) (*kasregistry.ListKeyAccessServerGrantsResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.ListKeyAccessServerGrants(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) CreateKey(ctx context.Context, req *kasregistry.CreateKeyRequest, _ ...grpc.CallOption) (*kasregistry.CreateKeyResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.CreateKey(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) GetKey(ctx context.Context, req *kasregistry.GetKeyRequest, _ ...grpc.CallOption) (*kasregistry.GetKeyResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.GetKey(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeys(ctx context.Context, req *kasregistry.ListKeysRequest, _ ...grpc.CallOption) (*kasregistry.ListKeysResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.ListKeys(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) UpdateKey(ctx context.Context, req *kasregistry.UpdateKeyRequest, _ ...grpc.CallOption) (*kasregistry.UpdateKeyResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.UpdateKey(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) RotateKey(ctx context.Context, req *kasregistry.RotateKeyRequest, _ ...grpc.CallOption) (*kasregistry.RotateKeyResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyAccessServerRegistryServiceClient.RotateKey(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/keymanagement.go b/sdk/sdkconnect/keymanagement.go new file mode 100644 index 0000000000..cb2c2ad8a4 --- /dev/null +++ b/sdk/sdkconnect/keymanagement.go @@ -0,0 +1,63 @@ +// Wrapper for KeyManagementServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/keymanagement" + "github.com/opentdf/platform/protocol/go/policy/keymanagement/keymanagementconnect" + "google.golang.org/grpc" +) + +type KeyManagementServiceClientConnectWrapper struct { + keymanagementconnect.KeyManagementServiceClient +} + +func NewKeyManagementServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *KeyManagementServiceClientConnectWrapper { + return &KeyManagementServiceClientConnectWrapper{KeyManagementServiceClient: keymanagementconnect.NewKeyManagementServiceClient(httpClient, baseURL, opts...)} +} + +func (w *KeyManagementServiceClientConnectWrapper) CreateProviderConfig(ctx context.Context, req *keymanagement.CreateProviderConfigRequest, _ ...grpc.CallOption) (*keymanagement.CreateProviderConfigResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyManagementServiceClient.CreateProviderConfig(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyManagementServiceClientConnectWrapper) GetProviderConfig(ctx context.Context, req *keymanagement.GetProviderConfigRequest, _ ...grpc.CallOption) (*keymanagement.GetProviderConfigResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyManagementServiceClient.GetProviderConfig(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyManagementServiceClientConnectWrapper) ListProviderConfigs(ctx context.Context, req *keymanagement.ListProviderConfigsRequest, _ ...grpc.CallOption) (*keymanagement.ListProviderConfigsResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyManagementServiceClient.ListProviderConfigs(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyManagementServiceClientConnectWrapper) UpdateProviderConfig(ctx context.Context, req *keymanagement.UpdateProviderConfigRequest, _ ...grpc.CallOption) (*keymanagement.UpdateProviderConfigResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyManagementServiceClient.UpdateProviderConfig(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *KeyManagementServiceClientConnectWrapper) DeleteProviderConfig(ctx context.Context, req *keymanagement.DeleteProviderConfigRequest, _ ...grpc.CallOption) (*keymanagement.DeleteProviderConfigResponse, error) { + // Wrap Connect RPC client request + res, err := w.KeyManagementServiceClient.DeleteProviderConfig(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/namespaces.go b/sdk/sdkconnect/namespaces.go new file mode 100644 index 0000000000..1b01938c9c --- /dev/null +++ b/sdk/sdkconnect/namespaces.go @@ -0,0 +1,99 @@ +// Wrapper for NamespaceServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/namespaces" + "github.com/opentdf/platform/protocol/go/policy/namespaces/namespacesconnect" + "google.golang.org/grpc" +) + +type NamespaceServiceClientConnectWrapper struct { + namespacesconnect.NamespaceServiceClient +} + +func NewNamespaceServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *NamespaceServiceClientConnectWrapper { + return &NamespaceServiceClientConnectWrapper{NamespaceServiceClient: namespacesconnect.NewNamespaceServiceClient(httpClient, baseURL, opts...)} +} + +func (w *NamespaceServiceClientConnectWrapper) GetNamespace(ctx context.Context, req *namespaces.GetNamespaceRequest, _ ...grpc.CallOption) (*namespaces.GetNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.GetNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) ListNamespaces(ctx context.Context, req *namespaces.ListNamespacesRequest, _ ...grpc.CallOption) (*namespaces.ListNamespacesResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.ListNamespaces(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) CreateNamespace(ctx context.Context, req *namespaces.CreateNamespaceRequest, _ ...grpc.CallOption) (*namespaces.CreateNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.CreateNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) UpdateNamespace(ctx context.Context, req *namespaces.UpdateNamespaceRequest, _ ...grpc.CallOption) (*namespaces.UpdateNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.UpdateNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) DeactivateNamespace(ctx context.Context, req *namespaces.DeactivateNamespaceRequest, _ ...grpc.CallOption) (*namespaces.DeactivateNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.DeactivateNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) AssignKeyAccessServerToNamespace(ctx context.Context, req *namespaces.AssignKeyAccessServerToNamespaceRequest, _ ...grpc.CallOption) (*namespaces.AssignKeyAccessServerToNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.AssignKeyAccessServerToNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) RemoveKeyAccessServerFromNamespace(ctx context.Context, req *namespaces.RemoveKeyAccessServerFromNamespaceRequest, _ ...grpc.CallOption) (*namespaces.RemoveKeyAccessServerFromNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.RemoveKeyAccessServerFromNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) AssignPublicKeyToNamespace(ctx context.Context, req *namespaces.AssignPublicKeyToNamespaceRequest, _ ...grpc.CallOption) (*namespaces.AssignPublicKeyToNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.AssignPublicKeyToNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *NamespaceServiceClientConnectWrapper) RemovePublicKeyFromNamespace(ctx context.Context, req *namespaces.RemovePublicKeyFromNamespaceRequest, _ ...grpc.CallOption) (*namespaces.RemovePublicKeyFromNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.NamespaceServiceClient.RemovePublicKeyFromNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/registeredresources.go b/sdk/sdkconnect/registeredresources.go new file mode 100644 index 0000000000..7974522ede --- /dev/null +++ b/sdk/sdkconnect/registeredresources.go @@ -0,0 +1,117 @@ +// Wrapper for RegisteredResourcesServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/registeredresources" + "github.com/opentdf/platform/protocol/go/policy/registeredresources/registeredresourcesconnect" + "google.golang.org/grpc" +) + +type RegisteredResourcesServiceClientConnectWrapper struct { + registeredresourcesconnect.RegisteredResourcesServiceClient +} + +func NewRegisteredResourcesServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *RegisteredResourcesServiceClientConnectWrapper { + return &RegisteredResourcesServiceClientConnectWrapper{RegisteredResourcesServiceClient: registeredresourcesconnect.NewRegisteredResourcesServiceClient(httpClient, baseURL, opts...)} +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) CreateRegisteredResource(ctx context.Context, req *registeredresources.CreateRegisteredResourceRequest, _ ...grpc.CallOption) (*registeredresources.CreateRegisteredResourceResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.CreateRegisteredResource(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResource(ctx context.Context, req *registeredresources.GetRegisteredResourceRequest, _ ...grpc.CallOption) (*registeredresources.GetRegisteredResourceResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.GetRegisteredResource(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) ListRegisteredResources(ctx context.Context, req *registeredresources.ListRegisteredResourcesRequest, _ ...grpc.CallOption) (*registeredresources.ListRegisteredResourcesResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.ListRegisteredResources(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) UpdateRegisteredResource(ctx context.Context, req *registeredresources.UpdateRegisteredResourceRequest, _ ...grpc.CallOption) (*registeredresources.UpdateRegisteredResourceResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.UpdateRegisteredResource(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) DeleteRegisteredResource(ctx context.Context, req *registeredresources.DeleteRegisteredResourceRequest, _ ...grpc.CallOption) (*registeredresources.DeleteRegisteredResourceResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.DeleteRegisteredResource(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) CreateRegisteredResourceValue(ctx context.Context, req *registeredresources.CreateRegisteredResourceValueRequest, _ ...grpc.CallOption) (*registeredresources.CreateRegisteredResourceValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.CreateRegisteredResourceValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResourceValue(ctx context.Context, req *registeredresources.GetRegisteredResourceValueRequest, _ ...grpc.CallOption) (*registeredresources.GetRegisteredResourceValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.GetRegisteredResourceValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResourceValuesByFQNs(ctx context.Context, req *registeredresources.GetRegisteredResourceValuesByFQNsRequest, _ ...grpc.CallOption) (*registeredresources.GetRegisteredResourceValuesByFQNsResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.GetRegisteredResourceValuesByFQNs(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) ListRegisteredResourceValues(ctx context.Context, req *registeredresources.ListRegisteredResourceValuesRequest, _ ...grpc.CallOption) (*registeredresources.ListRegisteredResourceValuesResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.ListRegisteredResourceValues(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) UpdateRegisteredResourceValue(ctx context.Context, req *registeredresources.UpdateRegisteredResourceValueRequest, _ ...grpc.CallOption) (*registeredresources.UpdateRegisteredResourceValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.UpdateRegisteredResourceValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) DeleteRegisteredResourceValue(ctx context.Context, req *registeredresources.DeleteRegisteredResourceValueRequest, _ ...grpc.CallOption) (*registeredresources.DeleteRegisteredResourceValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.RegisteredResourcesServiceClient.DeleteRegisteredResourceValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/resourcemapping.go b/sdk/sdkconnect/resourcemapping.go new file mode 100644 index 0000000000..17b04ce236 --- /dev/null +++ b/sdk/sdkconnect/resourcemapping.go @@ -0,0 +1,117 @@ +// Wrapper for ResourceMappingServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/resourcemapping" + "github.com/opentdf/platform/protocol/go/policy/resourcemapping/resourcemappingconnect" + "google.golang.org/grpc" +) + +type ResourceMappingServiceClientConnectWrapper struct { + resourcemappingconnect.ResourceMappingServiceClient +} + +func NewResourceMappingServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *ResourceMappingServiceClientConnectWrapper { + return &ResourceMappingServiceClientConnectWrapper{ResourceMappingServiceClient: resourcemappingconnect.NewResourceMappingServiceClient(httpClient, baseURL, opts...)} +} + +func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappingGroups(ctx context.Context, req *resourcemapping.ListResourceMappingGroupsRequest, _ ...grpc.CallOption) (*resourcemapping.ListResourceMappingGroupsResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.ListResourceMappingGroups(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) GetResourceMappingGroup(ctx context.Context, req *resourcemapping.GetResourceMappingGroupRequest, _ ...grpc.CallOption) (*resourcemapping.GetResourceMappingGroupResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.GetResourceMappingGroup(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) CreateResourceMappingGroup(ctx context.Context, req *resourcemapping.CreateResourceMappingGroupRequest, _ ...grpc.CallOption) (*resourcemapping.CreateResourceMappingGroupResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.CreateResourceMappingGroup(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) UpdateResourceMappingGroup(ctx context.Context, req *resourcemapping.UpdateResourceMappingGroupRequest, _ ...grpc.CallOption) (*resourcemapping.UpdateResourceMappingGroupResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.UpdateResourceMappingGroup(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) DeleteResourceMappingGroup(ctx context.Context, req *resourcemapping.DeleteResourceMappingGroupRequest, _ ...grpc.CallOption) (*resourcemapping.DeleteResourceMappingGroupResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.DeleteResourceMappingGroup(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappings(ctx context.Context, req *resourcemapping.ListResourceMappingsRequest, _ ...grpc.CallOption) (*resourcemapping.ListResourceMappingsResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.ListResourceMappings(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappingsByGroupFqns(ctx context.Context, req *resourcemapping.ListResourceMappingsByGroupFqnsRequest, _ ...grpc.CallOption) (*resourcemapping.ListResourceMappingsByGroupFqnsResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.ListResourceMappingsByGroupFqns(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) GetResourceMapping(ctx context.Context, req *resourcemapping.GetResourceMappingRequest, _ ...grpc.CallOption) (*resourcemapping.GetResourceMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.GetResourceMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) CreateResourceMapping(ctx context.Context, req *resourcemapping.CreateResourceMappingRequest, _ ...grpc.CallOption) (*resourcemapping.CreateResourceMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.CreateResourceMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) UpdateResourceMapping(ctx context.Context, req *resourcemapping.UpdateResourceMappingRequest, _ ...grpc.CallOption) (*resourcemapping.UpdateResourceMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.UpdateResourceMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *ResourceMappingServiceClientConnectWrapper) DeleteResourceMapping(ctx context.Context, req *resourcemapping.DeleteResourceMappingRequest, _ ...grpc.CallOption) (*resourcemapping.DeleteResourceMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.ResourceMappingServiceClient.DeleteResourceMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/subjectmapping.go b/sdk/sdkconnect/subjectmapping.go new file mode 100644 index 0000000000..9345fbf442 --- /dev/null +++ b/sdk/sdkconnect/subjectmapping.go @@ -0,0 +1,126 @@ +// Wrapper for SubjectMappingServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/subjectmapping" + "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" + "google.golang.org/grpc" +) + +type SubjectMappingServiceClientConnectWrapper struct { + subjectmappingconnect.SubjectMappingServiceClient +} + +func NewSubjectMappingServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *SubjectMappingServiceClientConnectWrapper { + return &SubjectMappingServiceClientConnectWrapper{SubjectMappingServiceClient: subjectmappingconnect.NewSubjectMappingServiceClient(httpClient, baseURL, opts...)} +} + +func (w *SubjectMappingServiceClientConnectWrapper) MatchSubjectMappings(ctx context.Context, req *subjectmapping.MatchSubjectMappingsRequest, _ ...grpc.CallOption) (*subjectmapping.MatchSubjectMappingsResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.MatchSubjectMappings(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) ListSubjectMappings(ctx context.Context, req *subjectmapping.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*subjectmapping.ListSubjectMappingsResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.ListSubjectMappings(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) GetSubjectMapping(ctx context.Context, req *subjectmapping.GetSubjectMappingRequest, _ ...grpc.CallOption) (*subjectmapping.GetSubjectMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.GetSubjectMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) CreateSubjectMapping(ctx context.Context, req *subjectmapping.CreateSubjectMappingRequest, _ ...grpc.CallOption) (*subjectmapping.CreateSubjectMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.CreateSubjectMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) UpdateSubjectMapping(ctx context.Context, req *subjectmapping.UpdateSubjectMappingRequest, _ ...grpc.CallOption) (*subjectmapping.UpdateSubjectMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.UpdateSubjectMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) DeleteSubjectMapping(ctx context.Context, req *subjectmapping.DeleteSubjectMappingRequest, _ ...grpc.CallOption) (*subjectmapping.DeleteSubjectMappingResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.DeleteSubjectMapping(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) ListSubjectConditionSets(ctx context.Context, req *subjectmapping.ListSubjectConditionSetsRequest, _ ...grpc.CallOption) (*subjectmapping.ListSubjectConditionSetsResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.ListSubjectConditionSets(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) GetSubjectConditionSet(ctx context.Context, req *subjectmapping.GetSubjectConditionSetRequest, _ ...grpc.CallOption) (*subjectmapping.GetSubjectConditionSetResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.GetSubjectConditionSet(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) CreateSubjectConditionSet(ctx context.Context, req *subjectmapping.CreateSubjectConditionSetRequest, _ ...grpc.CallOption) (*subjectmapping.CreateSubjectConditionSetResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.CreateSubjectConditionSet(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) UpdateSubjectConditionSet(ctx context.Context, req *subjectmapping.UpdateSubjectConditionSetRequest, _ ...grpc.CallOption) (*subjectmapping.UpdateSubjectConditionSetResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.UpdateSubjectConditionSet(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) DeleteSubjectConditionSet(ctx context.Context, req *subjectmapping.DeleteSubjectConditionSetRequest, _ ...grpc.CallOption) (*subjectmapping.DeleteSubjectConditionSetResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.DeleteSubjectConditionSet(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *SubjectMappingServiceClientConnectWrapper) DeleteAllUnmappedSubjectConditionSets(ctx context.Context, req *subjectmapping.DeleteAllUnmappedSubjectConditionSetsRequest, _ ...grpc.CallOption) (*subjectmapping.DeleteAllUnmappedSubjectConditionSetsResponse, error) { + // Wrap Connect RPC client request + res, err := w.SubjectMappingServiceClient.DeleteAllUnmappedSubjectConditionSets(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/unsafe.go b/sdk/sdkconnect/unsafe.go new file mode 100644 index 0000000000..792cb80836 --- /dev/null +++ b/sdk/sdkconnect/unsafe.go @@ -0,0 +1,108 @@ +// Wrapper for UnsafeServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/policy/unsafe" + "github.com/opentdf/platform/protocol/go/policy/unsafe/unsafeconnect" + "google.golang.org/grpc" +) + +type UnsafeServiceClientConnectWrapper struct { + unsafeconnect.UnsafeServiceClient +} + +func NewUnsafeServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *UnsafeServiceClientConnectWrapper { + return &UnsafeServiceClientConnectWrapper{UnsafeServiceClient: unsafeconnect.NewUnsafeServiceClient(httpClient, baseURL, opts...)} +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateNamespace(ctx context.Context, req *unsafe.UnsafeUpdateNamespaceRequest, _ ...grpc.CallOption) (*unsafe.UnsafeUpdateNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeUpdateNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateNamespace(ctx context.Context, req *unsafe.UnsafeReactivateNamespaceRequest, _ ...grpc.CallOption) (*unsafe.UnsafeReactivateNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeReactivateNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteNamespace(ctx context.Context, req *unsafe.UnsafeDeleteNamespaceRequest, _ ...grpc.CallOption) (*unsafe.UnsafeDeleteNamespaceResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeDeleteNamespace(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateAttribute(ctx context.Context, req *unsafe.UnsafeUpdateAttributeRequest, _ ...grpc.CallOption) (*unsafe.UnsafeUpdateAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeUpdateAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateAttribute(ctx context.Context, req *unsafe.UnsafeReactivateAttributeRequest, _ ...grpc.CallOption) (*unsafe.UnsafeReactivateAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeReactivateAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteAttribute(ctx context.Context, req *unsafe.UnsafeDeleteAttributeRequest, _ ...grpc.CallOption) (*unsafe.UnsafeDeleteAttributeResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeDeleteAttribute(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateAttributeValue(ctx context.Context, req *unsafe.UnsafeUpdateAttributeValueRequest, _ ...grpc.CallOption) (*unsafe.UnsafeUpdateAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeUpdateAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateAttributeValue(ctx context.Context, req *unsafe.UnsafeReactivateAttributeValueRequest, _ ...grpc.CallOption) (*unsafe.UnsafeReactivateAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeReactivateAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteAttributeValue(ctx context.Context, req *unsafe.UnsafeDeleteAttributeValueRequest, _ ...grpc.CallOption) (*unsafe.UnsafeDeleteAttributeValueResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeDeleteAttributeValue(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteKasKey(ctx context.Context, req *unsafe.UnsafeDeleteKasKeyRequest, _ ...grpc.CallOption) (*unsafe.UnsafeDeleteKasKeyResponse, error) { + // Wrap Connect RPC client request + res, err := w.UnsafeServiceClient.UnsafeDeleteKasKey(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/sdkconnect/wellknownconfiguration.go b/sdk/sdkconnect/wellknownconfiguration.go new file mode 100644 index 0000000000..3b7f4822a8 --- /dev/null +++ b/sdk/sdkconnect/wellknownconfiguration.go @@ -0,0 +1,27 @@ +// Wrapper for WellKnownServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/wellknownconfiguration" + "github.com/opentdf/platform/protocol/go/wellknownconfiguration/wellknownconfigurationconnect" + "google.golang.org/grpc" +) + +type WellKnownServiceClientConnectWrapper struct { + wellknownconfigurationconnect.WellKnownServiceClient +} + +func NewWellKnownServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *WellKnownServiceClientConnectWrapper { + return &WellKnownServiceClientConnectWrapper{WellKnownServiceClient: wellknownconfigurationconnect.NewWellKnownServiceClient(httpClient, baseURL, opts...)} +} + +func (w *WellKnownServiceClientConnectWrapper) GetWellKnownConfiguration(ctx context.Context, req *wellknownconfiguration.GetWellKnownConfigurationRequest, _ ...grpc.CallOption) (*wellknownconfiguration.GetWellKnownConfigurationResponse, error) { + // Wrap Connect RPC client request + res, err := w.WellKnownServiceClient.GetWellKnownConfiguration(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} diff --git a/sdk/tdf.go b/sdk/tdf.go index 049611b0df..a718ba7272 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -19,7 +19,6 @@ import ( "github.com/Masterminds/semver/v3" "github.com/opentdf/platform/protocol/go/kas" "github.com/opentdf/platform/protocol/go/policy/kasregistry" - "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" "github.com/google/uuid" "github.com/opentdf/platform/lib/ocrypto" @@ -657,13 +656,13 @@ func createPolicyObject(attributes []AttributeValueFQN) (PolicyObject, error) { return policyObj, nil } -func allowListFromKASRegistry(ctx context.Context, kasRegistryClient kasregistryconnect.KeyAccessServerRegistryServiceClient, platformURL string) (AllowList, error) { - kases, err := kasRegistryClient.ListKeyAccessServers(ctx, connect.NewRequest(&kasregistry.ListKeyAccessServersRequest{})) +func allowListFromKASRegistry(ctx context.Context, kasRegistryClient kasregistry.KeyAccessServerRegistryServiceClient, platformURL string) (AllowList, error) { + kases, err := kasRegistryClient.ListKeyAccessServers(ctx, &kasregistry.ListKeyAccessServersRequest{}) if err != nil { return nil, fmt.Errorf("kasregistry.ListKeyAccessServers failed: %w", err) } kasAllowlist := AllowList{} - for _, kas := range kases.Msg.GetKeyAccessServers() { + for _, kas := range kases.GetKeyAccessServers() { err = kasAllowlist.Add(kas.GetUri()) if err != nil { return nil, fmt.Errorf("kasAllowlist.Add failed: %w", err) diff --git a/service/authorization/authorization.go b/service/authorization/authorization.go index cb3620003e..3ed2e67650 100644 --- a/service/authorization/authorization.go +++ b/service/authorization/authorization.go @@ -165,7 +165,7 @@ func (as *AuthorizationService) GetDecisionsByToken(ctx context.Context, req *co // for each token decision request for _, tdr := range req.Msg.GetDecisionRequests() { - ecResp, err := as.sdk.EntityResoution.CreateEntityChainFromJwt(ctx, connect.NewRequest(&entityresolution.CreateEntityChainFromJwtRequest{Tokens: tdr.GetTokens()})) + ecResp, err := as.sdk.EntityResoution.CreateEntityChainFromJwt(ctx, &entityresolution.CreateEntityChainFromJwtRequest{Tokens: tdr.GetTokens()}) if err != nil { as.logger.Error("Error calling ERS to get entity chains from jwts") return nil, err @@ -174,7 +174,7 @@ func (as *AuthorizationService) GetDecisionsByToken(ctx context.Context, req *co // form a decision request for the token decision request decisionsRequests = append(decisionsRequests, &authorization.DecisionRequest{ Actions: tdr.GetActions(), - EntityChains: ecResp.Msg.GetEntityChains(), + EntityChains: ecResp.GetEntityChains(), ResourceAttributes: tdr.GetResourceAttributes(), }) } @@ -303,19 +303,19 @@ func (as *AuthorizationService) GetEntitlements(ctx context.Context, req *connec // If quantity of attributes exceeds maximum list pagination, all are needed to determine entitlements for { - listed, err := as.sdk.Attributes.ListAttributes(ctx, connect.NewRequest(&attr.ListAttributesRequest{ + listed, err := as.sdk.Attributes.ListAttributes(ctx, &attr.ListAttributesRequest{ State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, Pagination: &policy.PageRequest{ Offset: nextOffset, }, - })) + }) if err != nil { as.logger.ErrorContext(ctx, "failed to list attributes", slog.String("error", err.Error())) return nil, connect.NewError(connect.CodeInternal, errors.New("failed to list attributes")) } - nextOffset = listed.Msg.GetPagination().GetNextOffset() - attrsList = append(attrsList, listed.Msg.GetAttributes()...) + nextOffset = listed.GetPagination().GetNextOffset() + attrsList = append(attrsList, listed.GetAttributes()...) // offset becomes zero when list is exhausted if nextOffset <= 0 { @@ -326,18 +326,18 @@ func (as *AuthorizationService) GetEntitlements(ctx context.Context, req *connec // If quantity of subject mappings exceeds maximum list pagination, all are needed to determine entitlements nextOffset = 0 for { - listed, err := as.sdk.SubjectMapping.ListSubjectMappings(ctx, connect.NewRequest(&subjectmapping.ListSubjectMappingsRequest{ + listed, err := as.sdk.SubjectMapping.ListSubjectMappings(ctx, &subjectmapping.ListSubjectMappingsRequest{ Pagination: &policy.PageRequest{ Offset: nextOffset, }, - })) + }) if err != nil { as.logger.ErrorContext(ctx, "failed to list subject mappings", slog.String("error", err.Error())) return nil, connect.NewError(connect.CodeInternal, errors.New("failed to list subject mappings")) } - nextOffset = listed.Msg.GetPagination().GetNextOffset() - subjectMappingsList = append(subjectMappingsList, listed.Msg.GetSubjectMappings()...) + nextOffset = listed.GetPagination().GetNextOffset() + subjectMappingsList = append(subjectMappingsList, listed.GetSubjectMappings()...) // offset becomes zero when list is exhausted if nextOffset <= 0 { @@ -365,14 +365,14 @@ func (as *AuthorizationService) GetEntitlements(ctx context.Context, req *connec } // call ERS on all entities - ersResp, err := as.sdk.EntityResoution.ResolveEntities(ctx, connect.NewRequest(&entityresolution.ResolveEntitiesRequest{Entities: req.Msg.GetEntities()})) + ersResp, err := as.sdk.EntityResoution.ResolveEntities(ctx, &entityresolution.ResolveEntitiesRequest{Entities: req.Msg.GetEntities()}) if err != nil { as.logger.ErrorContext(ctx, "error calling ERS to resolve entities", "entities", req.Msg.GetEntities()) return nil, err } // call rego on all entities - in, err := entitlements.OpaInput(subjectMappings, ersResp.Msg) + in, err := entitlements.OpaInput(subjectMappings, ersResp) if err != nil { as.logger.ErrorContext(ctx, "failed to build rego input", slog.String("error", err.Error())) return nil, connect.NewError(connect.CodeInternal, errors.New("failed to build rego input")) @@ -726,16 +726,16 @@ func retrieveAttributeDefinitions(ctx context.Context, attrFqns []string, sdk *o return make(map[string]*attr.GetAttributeValuesByFqnsResponse_AttributeAndValue), nil } - resp, err := sdk.Attributes.GetAttributeValuesByFqns(ctx, connect.NewRequest(&attr.GetAttributeValuesByFqnsRequest{ + resp, err := sdk.Attributes.GetAttributeValuesByFqns(ctx, &attr.GetAttributeValuesByFqnsRequest{ WithValue: &policy.AttributeValueSelector{ WithSubjectMaps: false, }, Fqns: attrFqns, - })) + }) if err != nil { return nil, err } - return resp.Msg.GetFqnAttributeValues(), nil + return resp.GetFqnAttributeValues(), nil } func getComprehensiveHierarchy(attributesMap map[string]*policy.Attribute, avf *attr.GetAttributeValuesByFqnsResponse, entitlement string, as *AuthorizationService, entitlements []string) []string { diff --git a/service/authorization/authorization_test.go b/service/authorization/authorization_test.go index 737fca31ac..26e8152cb8 100644 --- a/service/authorization/authorization_test.go +++ b/service/authorization/authorization_test.go @@ -16,14 +16,13 @@ import ( "github.com/opentdf/platform/protocol/go/entityresolution" "github.com/opentdf/platform/protocol/go/policy" attr "github.com/opentdf/platform/protocol/go/policy/attributes" - attrconnect "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" sm "github.com/opentdf/platform/protocol/go/policy/subjectmapping" - smconnect "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" otdf "github.com/opentdf/platform/sdk" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/pkg/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/structpb" @@ -48,15 +47,15 @@ var ( ) type myAttributesClient struct { - attrconnect.AttributesServiceClient + attr.AttributesServiceClient } -func (*myAttributesClient) ListAttributes(_ context.Context, _ *connect.Request[attr.ListAttributesRequest]) (*connect.Response[attr.ListAttributesResponse], error) { - return connect.NewResponse(&listAttributeResp), errListAttributes +func (*myAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest, _ ...grpc.CallOption) (*attr.ListAttributesResponse, error) { + return &listAttributeResp, errListAttributes } -func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *connect.Request[attr.GetAttributeValuesByFqnsRequest]) (*connect.Response[attr.GetAttributeValuesByFqnsResponse], error) { - return connect.NewResponse(&getAttributesByValueFqnsResponse), errGetAttributesByValueFqns +func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.GetAttributeValuesByFqnsRequest, _ ...grpc.CallOption) (*attr.GetAttributeValuesByFqnsResponse, error) { + return &getAttributesByValueFqnsResponse, errGetAttributesByValueFqns } type myERSClient struct { @@ -64,23 +63,23 @@ type myERSClient struct { } type mySubjectMappingClient struct { - smconnect.SubjectMappingServiceClient + sm.SubjectMappingServiceClient } type paginatedMockSubjectMappingClient struct { - smconnect.SubjectMappingServiceClient + sm.SubjectMappingServiceClient } -func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *connect.Request[sm.ListSubjectMappingsRequest]) (*connect.Response[sm.ListSubjectMappingsResponse], error) { - return connect.NewResponse(&listSubjectMappings), nil +func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*sm.ListSubjectMappingsResponse, error) { + return &listSubjectMappings, nil } -func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *connect.Request[entityresolution.CreateEntityChainFromJwtRequest]) (*connect.Response[entityresolution.CreateEntityChainFromJwtResponse], error) { - return connect.NewResponse(&createEntityChainResp), nil +func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *entityresolution.CreateEntityChainFromJwtRequest, _ ...grpc.CallOption) (*entityresolution.CreateEntityChainFromJwtResponse, error) { + return &createEntityChainResp, nil } -func (*myERSClient) ResolveEntities(_ context.Context, _ *connect.Request[entityresolution.ResolveEntitiesRequest]) (*connect.Response[entityresolution.ResolveEntitiesResponse], error) { - return connect.NewResponse(&resolveEntitiesResp), nil +func (*myERSClient) ResolveEntities(_ context.Context, _ *entityresolution.ResolveEntitiesRequest, _ ...grpc.CallOption) (*entityresolution.ResolveEntitiesResponse, error) { + return &resolveEntitiesResp, nil } var ( @@ -88,7 +87,7 @@ var ( smListCallCount = 0 ) -func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *connect.Request[sm.ListSubjectMappingsRequest]) (*connect.Response[sm.ListSubjectMappingsResponse], error) { +func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*sm.ListSubjectMappingsResponse, error) { smListCallCount++ // simulate paginated list and policy LIST behavior if smPaginationOffset > 0 { @@ -99,13 +98,13 @@ func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, }, } smPaginationOffset = 0 - return connect.NewResponse(rsp), nil + return rsp, nil } - return connect.NewResponse(&listSubjectMappings), nil + return &listSubjectMappings, nil } type paginatedMockAttributesClient struct { - attrconnect.AttributesServiceClient + attr.AttributesServiceClient } var ( @@ -113,7 +112,7 @@ var ( attrListCallCount = 0 ) -func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *connect.Request[attr.ListAttributesRequest]) (*connect.Response[attr.ListAttributesResponse], error) { +func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest, _ ...grpc.CallOption) (*attr.ListAttributesResponse, error) { attrListCallCount++ // simulate paginated list and policy LIST behavior if attrPaginationOffset > 0 { @@ -124,9 +123,9 @@ func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *conne }, } attrPaginationOffset = 0 - return connect.NewResponse(rsp), nil + return rsp, nil } - return connect.NewResponse(&listAttributeResp), nil + return &listAttributeResp, nil } func TestGetComprehensiveHierarchy(t *testing.T) { @@ -449,9 +448,11 @@ func Test_GetDecisions_AllOf_Fail(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -547,9 +548,11 @@ func Test_GetDecisionsAllOfWithEnvironmental_Pass(t *testing.T) { } as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -642,9 +645,11 @@ func Test_GetDecisionsAllOfWithEnvironmental_Fail(t *testing.T) { } as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -715,9 +720,11 @@ func Test_GetEntitlementsSimple(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -787,9 +794,11 @@ func Test_GetEntitlementsFqnCasing(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -864,7 +873,8 @@ func Test_GetEntitlements_HandlesPagination(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ + logger: logger, + sdk: &otdf.SDK{ SubjectMapping: &paginatedMockSubjectMappingClient{}, Attributes: &paginatedMockAttributesClient{}, EntityResoution: &myERSClient{}, @@ -955,9 +965,11 @@ func Test_GetEntitlementsWithComprehensiveHierarchy(t *testing.T) { prepared, err := rego.PrepareForEval(t.Context()) require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -1195,9 +1207,11 @@ func Test_GetDecisions_RA_FQN_Edge_Cases(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -1401,9 +1415,11 @@ func Test_GetDecisionsAllOf_Pass_EC_RA_Length_Mismatch(t *testing.T) { } as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), @@ -1678,9 +1694,11 @@ func Test_GetDecisions_Empty_EC_RA(t *testing.T) { require.NoError(t, err) as := AuthorizationService{ - logger: logger, sdk: &otdf.SDK{ - SubjectMapping: &mySubjectMappingClient{}, - Attributes: &myAttributesClient{}, EntityResoution: &myERSClient{}, + logger: logger, + sdk: &otdf.SDK{ + SubjectMapping: &mySubjectMappingClient{}, + Attributes: &myAttributesClient{}, + EntityResoution: &myERSClient{}, }, eval: prepared, Tracer: noop.NewTracerProvider().Tracer(""), diff --git a/service/kas/access/accessPdp.go b/service/kas/access/accessPdp.go index 8a684d5ace..218ddf39ee 100644 --- a/service/kas/access/accessPdp.go +++ b/service/kas/access/accessPdp.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" - "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/authorization" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/service/tracing" @@ -78,10 +77,10 @@ func (p *Provider) checkAttributes(ctx context.Context, ras []*authorization.Res } ctx = tracing.InjectTraceContext(ctx) - dr, err := p.SDK.Authorization.GetDecisionsByToken(ctx, connect.NewRequest(&in)) + dr, err := p.SDK.Authorization.GetDecisionsByToken(ctx, &in) if err != nil { p.Logger.ErrorContext(ctx, "Error received from GetDecisionsByToken", "err", err) return nil, errors.Join(ErrDecisionUnexpected, err) } - return dr.Msg, nil + return dr, nil } diff --git a/service/rttests/rt_test.go b/service/rttests/rt_test.go index ca45939226..64ef90331c 100644 --- a/service/rttests/rt_test.go +++ b/service/rttests/rt_test.go @@ -10,7 +10,6 @@ import ( "strings" "testing" - "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" "github.com/opentdf/platform/protocol/go/policy/namespaces" @@ -171,12 +170,12 @@ func (s *RoundtripSuite) CreateTestData() error { // create namespace example.com var exampleNamespace *policy.Namespace slog.Info("listing namespaces") - listResp, err := client.Namespaces.ListNamespaces(context.Background(), connect.NewRequest(&namespaces.ListNamespacesRequest{})) + listResp, err := client.Namespaces.ListNamespaces(context.Background(), &namespaces.ListNamespacesRequest{}) if err != nil { return err } - slog.Info(fmt.Sprintf("found %d namespaces", len(listResp.Msg.GetNamespaces()))) - for _, ns := range listResp.Msg.GetNamespaces() { + slog.Info(fmt.Sprintf("found %d namespaces", len(listResp.GetNamespaces()))) + for _, ns := range listResp.GetNamespaces() { slog.Info(fmt.Sprintf("existing namespace; name: %s, id: %s", ns.GetName(), ns.GetId())) if ns.GetName() == "example.com" { exampleNamespace = ns @@ -185,20 +184,20 @@ func (s *RoundtripSuite) CreateTestData() error { if exampleNamespace == nil { slog.Info("creating new namespace") - resp, err := client.Namespaces.CreateNamespace(context.Background(), connect.NewRequest(&namespaces.CreateNamespaceRequest{ + resp, err := client.Namespaces.CreateNamespace(context.Background(), &namespaces.CreateNamespaceRequest{ Name: "example.com", - })) + }) if err != nil { return err } - exampleNamespace = resp.Msg.GetNamespace() + exampleNamespace = resp.GetNamespace() } slog.Info("##################################\n#######################################") // Create the attributes slog.Info("creating attribute language with allOf rule") - _, err = client.Attributes.CreateAttribute(context.Background(), connect.NewRequest(&attributes.CreateAttributeRequest{ + _, err = client.Attributes.CreateAttribute(context.Background(), &attributes.CreateAttributeRequest{ Name: "language", NamespaceId: exampleNamespace.GetId(), Rule: *policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ALL_OF.Enum(), @@ -207,7 +206,7 @@ func (s *RoundtripSuite) CreateTestData() error { "french", "spanish", }, - })) + }) if err != nil { if returnStatus, ok := status.FromError(err); ok && returnStatus.Code() == codes.AlreadyExists { slog.Info("attribute already exists") @@ -220,7 +219,7 @@ func (s *RoundtripSuite) CreateTestData() error { } slog.Info("creating attribute color with anyOf rule") - _, err = client.Attributes.CreateAttribute(context.Background(), connect.NewRequest(&attributes.CreateAttributeRequest{ + _, err = client.Attributes.CreateAttribute(context.Background(), &attributes.CreateAttributeRequest{ Name: "color", NamespaceId: exampleNamespace.GetId(), Rule: *policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF.Enum(), @@ -229,7 +228,7 @@ func (s *RoundtripSuite) CreateTestData() error { "green", "blue", }, - })) + }) if err != nil { if returnStatus, ok := status.FromError(err); ok && returnStatus.Code() == codes.AlreadyExists { slog.Info("attribute already exists") @@ -242,7 +241,7 @@ func (s *RoundtripSuite) CreateTestData() error { } slog.Info("creating attribute cards with hierarchy rule") - _, err = client.Attributes.CreateAttribute(context.Background(), connect.NewRequest(&attributes.CreateAttributeRequest{ + _, err = client.Attributes.CreateAttribute(context.Background(), &attributes.CreateAttributeRequest{ Name: "cards", NamespaceId: exampleNamespace.GetId(), Rule: *policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_HIERARCHY.Enum(), @@ -251,7 +250,7 @@ func (s *RoundtripSuite) CreateTestData() error { "queen", "jack", }, - })) + }) if err != nil { if returnStatus, ok := status.FromError(err); ok && returnStatus.Code() == codes.AlreadyExists { slog.Info("attribute already exists") @@ -265,33 +264,33 @@ func (s *RoundtripSuite) CreateTestData() error { slog.Info("##################################\n#######################################") - allAttr, err := client.Attributes.ListAttributes(context.Background(), connect.NewRequest(&attributes.ListAttributesRequest{})) + allAttr, err := client.Attributes.ListAttributes(context.Background(), &attributes.ListAttributesRequest{}) if err != nil { slog.Error("could not list attributes", slog.String("error", err.Error())) return err } - slog.Info("list attributes response: " + protojson.Format(allAttr.Msg)) + slog.Info("list attributes response: " + protojson.Format(allAttr)) slog.Info("##################################\n#######################################") // get the attribute ids for the values were mapping to the client var attributeValueIDs []string - fqnResp, err := client.Attributes.GetAttributeValuesByFqns(context.Background(), connect.NewRequest(&attributes.GetAttributeValuesByFqnsRequest{ + fqnResp, err := client.Attributes.GetAttributeValuesByFqns(context.Background(), &attributes.GetAttributeValuesByFqnsRequest{ Fqns: attributesToMap, WithValue: &policy.AttributeValueSelector{}, - })) + }) if err != nil { slog.Error("get attribute values by fqn ", slog.String("error", err.Error())) return err } for _, attribute := range attributesToMap { - attributeValueIDs = append(attributeValueIDs, fqnResp.Msg.GetFqnAttributeValues()[attribute].GetValue().GetId()) + attributeValueIDs = append(attributeValueIDs, fqnResp.GetFqnAttributeValues()[attribute].GetValue().GetId()) } // create subject mappings slog.Info("creating subject mappings for client " + s.TestConfig.ClientID) for _, attributeID := range attributeValueIDs { - _, err = client.SubjectMapping.CreateSubjectMapping(context.Background(), connect.NewRequest(&subjectmapping.CreateSubjectMappingRequest{ + _, err = client.SubjectMapping.CreateSubjectMapping(context.Background(), &subjectmapping.CreateSubjectMappingRequest{ AttributeValueId: attributeID, Actions: []*policy.Action{ {Name: actions.ActionNameCreate}, @@ -311,7 +310,7 @@ func (s *RoundtripSuite) CreateTestData() error { }}, }, }, - })) + }) if err != nil { if returnStatus, ok := status.FromError(err); ok && returnStatus.Code() == codes.AlreadyExists { slog.Info("subject mapping already exists") @@ -324,12 +323,12 @@ func (s *RoundtripSuite) CreateTestData() error { } } - allSubMaps, err := client.SubjectMapping.ListSubjectMappings(context.Background(), connect.NewRequest(&subjectmapping.ListSubjectMappingsRequest{})) + allSubMaps, err := client.SubjectMapping.ListSubjectMappings(context.Background(), &subjectmapping.ListSubjectMappingsRequest{}) if err != nil { slog.Error("could not list subject mappings", slog.String("error", err.Error())) return err } - slog.Info("list subject mappings response: " + protojson.Format(allSubMaps.Msg)) + slog.Info("list subject mappings response: " + protojson.Format(allSubMaps)) return nil } From 1e0558dc563dc3b3e0ef9d489d7653991aa00978 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Thu, 15 May 2025 01:53:13 -0400 Subject: [PATCH 13/31] unit tests for connect interceptors --- sdk/audit/metadata_adding_interceptor_test.go | 149 +++++++++++++---- sdk/auth/token_adding_interceptor_test.go | 155 ++++++++++++------ 2 files changed, 220 insertions(+), 84 deletions(-) diff --git a/sdk/audit/metadata_adding_interceptor_test.go b/sdk/audit/metadata_adding_interceptor_test.go index 8c0b44dda3..687c661245 100644 --- a/sdk/audit/metadata_adding_interceptor_test.go +++ b/sdk/audit/metadata_adding_interceptor_test.go @@ -3,47 +3,52 @@ package audit import ( "context" "net" + "net/http" + "net/http/httptest" "testing" + "connectrpc.com/connect" "github.com/google/uuid" "github.com/opentdf/platform/protocol/go/kas" + "github.com/opentdf/platform/protocol/go/kas/kasconnect" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" "google.golang.org/grpc/test/bufconn" ) -type FakeAccessServiceServer struct { +type FakeAccessServiceServerConnect struct { requestID uuid.UUID requestIP string actorID string - kas.UnimplementedAccessServiceServer + kasconnect.UnimplementedAccessServiceHandler } -func (f *FakeAccessServiceServer) PublicKey(ctx context.Context, _ *kas.PublicKeyRequest) (*kas.PublicKeyResponse, error) { - if md, ok := metadata.FromIncomingContext(ctx); ok { - requestIDFromMetadata := md.Get(string(RequestIDHeaderKey)) - if len(requestIDFromMetadata) > 0 { - f.requestID, _ = uuid.Parse(requestIDFromMetadata[0]) - } +func (f *FakeAccessServiceServerConnect) PublicKey(ctx context.Context, req *connect.Request[kas.PublicKeyRequest]) (*connect.Response[kas.PublicKeyResponse], error) { + requestIDFromHeader := req.Header().Get(string(RequestIDHeaderKey)) + if requestIDFromHeader != "" { + f.requestID, _ = uuid.Parse(requestIDFromHeader) + } - requestIPFromMetadata := md.Get(string(RequestIPHeaderKey)) - if len(requestIPFromMetadata) > 0 { - f.requestIP = requestIPFromMetadata[0] - } + requestIPFromHeader := req.Header().Get(string(RequestIPHeaderKey)) + if requestIPFromHeader != "" { + f.requestIP = requestIPFromHeader + } - actorIDFromMetadata := md.Get(string(ActorIDHeaderKey)) - if len(actorIDFromMetadata) > 0 { - f.actorID = actorIDFromMetadata[0] - } + actorIDFromHeader := req.Header().Get(string(ActorIDHeaderKey)) + if actorIDFromHeader != "" { + f.actorID = actorIDFromHeader } - return &kas.PublicKeyResponse{}, nil + return connect.NewResponse(&kas.PublicKeyResponse{}), nil } func TestAddingAuditMetadataToOutgoingRequest(t *testing.T) { - server := FakeAccessServiceServer{} - client, stop := runServer(&server) - defer stop() + serverConnect := FakeAccessServiceServerConnect{} + serverGrpc := FakeAccessServiceServer{} + clientConnect, stopC := runConnectServer(&serverConnect) + defer stopC() + clientGrpc, stopG := runServer(&serverGrpc) + defer stopG() contextRequestID := uuid.New() contextActorID := "actorID" @@ -51,38 +56,110 @@ func TestAddingAuditMetadataToOutgoingRequest(t *testing.T) { ctx = context.WithValue(ctx, RequestIDContextKey, contextRequestID) ctx = context.WithValue(ctx, ActorIDContextKey, contextActorID) - _, err := client.PublicKey(ctx, &kas.PublicKeyRequest{}) + _, err := clientConnect.PublicKey(ctx, connect.NewRequest(&kas.PublicKeyRequest{})) if err != nil { t.Fatalf("error making call: %v", err) } - - if server.requestID != contextRequestID { - t.Fatalf("request ID did not match: %v", server.requestID) + _, err = clientGrpc.PublicKey(ctx, &kas.PublicKeyRequest{}) + if err != nil { + t.Fatalf("error making call: %v", err) } - if server.actorID != contextActorID { - t.Fatalf("actor ID did not match: %v", server.actorID) + for _, ids := range []struct { + actorID string + requestID uuid.UUID + }{ + {requestID: serverConnect.requestID, actorID: serverConnect.actorID}, + {requestID: serverGrpc.requestID, actorID: serverGrpc.actorID}, + } { + if ids.requestID != contextRequestID { + t.Fatalf("request ID did not match: %v", serverConnect.requestID) + } + if ids.requestID != contextRequestID { + t.Fatalf("request ID did not match: %v", serverGrpc.requestID) + } } } func TestIsOKWithNoContextValues(t *testing.T) { - server := FakeAccessServiceServer{} - client, stop := runServer(&server) - defer stop() - - _, err := client.PublicKey(t.Context(), &kas.PublicKeyRequest{}) + serverConnect := FakeAccessServiceServerConnect{} + serverGrpc := FakeAccessServiceServer{} + clientConnect, stopC := runConnectServer(&serverConnect) + defer stopC() + clientGrpc, stopG := runServer(&serverGrpc) + defer stopG() + + _, err := clientConnect.PublicKey(t.Context(), connect.NewRequest(&kas.PublicKeyRequest{})) + if err != nil { + t.Fatalf("error making call: %v", err) + } + _, err = clientGrpc.PublicKey(t.Context(), &kas.PublicKeyRequest{}) if err != nil { t.Fatalf("error making call: %v", err) } - generatedRequestID, err := uuid.Parse(server.requestID.String()) - if err != nil || generatedRequestID == uuid.Nil { - t.Fatalf("did not generate request ID: %v", err) + for _, ids := range []struct { + actorID string + requestID uuid.UUID + }{ + {requestID: serverConnect.requestID, actorID: serverConnect.actorID}, + {requestID: serverGrpc.requestID, actorID: serverGrpc.actorID}, + } { + generatedRequestIDConnect, err := uuid.Parse(ids.requestID.String()) + if err != nil || generatedRequestIDConnect == uuid.Nil { + t.Fatalf("did not generate request ID: %v", err) + } + + if ids.actorID != "" { + t.Fatalf("actor ID not defaulted correctly: %v", ids.actorID) + } + } +} + +func runConnectServer( + f *FakeAccessServiceServerConnect) (kasconnect.AccessServiceClient, func()) { + mux := http.NewServeMux() + path, handler := kasconnect.NewAccessServiceHandler(f) + mux.Handle(path, handler) + + server := httptest.NewServer(mux) + + client := kasconnect.NewAccessServiceClient( + server.Client(), + server.URL, + connect.WithInterceptors(MetadataAddingConnectInterceptor()), + ) + + return client, func() { + server.Close() } +} + +type FakeAccessServiceServer struct { + requestID uuid.UUID + requestIP string + actorID string + kas.UnimplementedAccessServiceServer +} - if server.actorID != "" { - t.Fatalf("actor ID not defaulted correctly: %v", server.actorID) +func (f *FakeAccessServiceServer) PublicKey(ctx context.Context, _ *kas.PublicKeyRequest) (*kas.PublicKeyResponse, error) { + if md, ok := metadata.FromIncomingContext(ctx); ok { + requestIDFromMetadata := md.Get(string(RequestIDHeaderKey)) + if len(requestIDFromMetadata) > 0 { + f.requestID, _ = uuid.Parse(requestIDFromMetadata[0]) + } + + requestIPFromMetadata := md.Get(string(RequestIPHeaderKey)) + if len(requestIPFromMetadata) > 0 { + f.requestIP = requestIPFromMetadata[0] + } + + actorIDFromMetadata := md.Get(string(ActorIDHeaderKey)) + if len(actorIDFromMetadata) > 0 { + f.actorID = actorIDFromMetadata[0] + } } + return &kas.PublicKeyResponse{}, nil } func runServer(f *FakeAccessServiceServer) (kas.AccessServiceClient, func()) { diff --git a/sdk/auth/token_adding_interceptor_test.go b/sdk/auth/token_adding_interceptor_test.go index 8d99eed54f..43e1397de5 100644 --- a/sdk/auth/token_adding_interceptor_test.go +++ b/sdk/auth/token_adding_interceptor_test.go @@ -11,13 +11,16 @@ import ( "errors" "net" "net/http" + "net/http/httptest" "testing" + "connectrpc.com/connect" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/protocol/go/kas" + "github.com/opentdf/platform/protocol/go/kas/kasconnect" "github.com/opentdf/platform/sdk/httputil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -42,73 +45,109 @@ func TestAddingTokensToOutgoingRequest(t *testing.T) { key: key, accessToken: "thisisafakeaccesstoken", } - server := FakeAccessServiceServer{} + serverConnect := FakeAccessServiceServerConnect{} + serverGrpc := FakeAccessServiceServer{} oo := NewTokenAddingInterceptorWithClient(&ts, httputil.SafeHTTPClientWithTLSConfig(&tls.Config{ MinVersion: tls.VersionTLS12, })) - client, stop := runServer(&server, oo) - defer stop() + clientConnect, stopC := runConnectServer(&serverConnect, oo) + defer stopC() - _, err = client.PublicKey(t.Context(), &kas.PublicKeyRequest{}) - require.NoError(t, err, "error making call") - - assert.ElementsMatch(t, server.accessToken, []string{"DPoP thisisafakeaccesstoken"}) - require.Len(t, server.dpopToken, 1, "incorrect dpop token headers") - - dpopToken := server.dpopToken[0] - alg, ok := key.Algorithm().(jwa.SignatureAlgorithm) - assert.True(t, ok, "got a bad signing algorithm") - - _, err = jws.Verify([]byte(dpopToken), jws.WithKey(alg, key)) - require.NoError(t, err, "error verifying signature") - - parsedSignature, _ := jws.Parse([]byte(dpopToken)) - require.Len(t, parsedSignature.Signatures(), 1, "incorrect number of signatures") - - sig := parsedSignature.Signatures()[0] - tokenKey, ok := sig.ProtectedHeaders().Get("jwk") - require.True(t, ok, "didn't get jwk token key") - tkkey, ok := tokenKey.(jwk.Key) - require.True(t, ok, "wrong type for jwk token key", tokenKey) - - tp, _ := tkkey.Thumbprint(crypto.SHA256) - ktp, _ := key.Thumbprint(crypto.SHA256) - assert.Equal(t, tp, ktp, "got the wrong key from the token") - - parsedToken, _ := jwt.Parse([]byte(dpopToken), jwt.WithVerify(false)) + clientGrpc, stopG := runServer(&serverGrpc, oo) + defer stopG() - method, ok := parsedToken.Get("htm") - require.True(t, ok, "error getting htm claim") - assert.Equal(t, http.MethodPost, method, "got a bad method") - - path, ok := parsedToken.Get("htu") - require.True(t, ok, "error getting htu claim") - assert.Equal(t, "/kas.AccessService/PublicKey", path, "got a bad path") - - h := sha256.New() - h.Write([]byte("thisisafakeaccesstoken")) - expectedHash := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(h.Sum(nil)) + _, err = clientConnect.PublicKey(context.Background(), connect.NewRequest(&kas.PublicKeyRequest{})) + require.NoError(t, err, "error making call") + _, err = clientGrpc.PublicKey(context.Background(), &kas.PublicKeyRequest{}) + require.NoError(t, err, "error making call") - ath, ok := parsedToken.Get("ath") - require.True(t, ok, "error getting ath claim") - assert.Equal(t, expectedHash, ath, "invalid ath claim in token") + for _, server := range []struct { + accessToken []string + dpopToken []string + }{ + {accessToken: serverConnect.accessToken, dpopToken: serverConnect.dpopToken}, + {accessToken: serverGrpc.accessToken, dpopToken: serverGrpc.dpopToken}, + } { + assert.ElementsMatch(t, server.accessToken, []string{"DPoP thisisafakeaccesstoken"}) + require.Len(t, server.dpopToken, 1, "incorrect dpop token headers") + alg, ok := key.Algorithm().(jwa.SignatureAlgorithm) + assert.True(t, ok, "got a bad signing algorithm") + + dpopToken := server.dpopToken[0] + _, err = jws.Verify([]byte(dpopToken), jws.WithKey(alg, key)) + require.NoError(t, err, "error verifying signature") + + parsedSignature, _ := jws.Parse([]byte(dpopToken)) + require.Len(t, parsedSignature.Signatures(), 1, "incorrect number of signatures") + + sig := parsedSignature.Signatures()[0] + tokenKey, ok := sig.ProtectedHeaders().Get("jwk") + require.True(t, ok, "didn't get jwk token key") + tkkey, ok := tokenKey.(jwk.Key) + require.True(t, ok, "wrong type for jwk token key", tokenKey) + + tp, _ := tkkey.Thumbprint(crypto.SHA256) + ktp, _ := key.Thumbprint(crypto.SHA256) + assert.Equal(t, tp, ktp, "got the wrong key from the token") + + parsedToken, _ := jwt.Parse([]byte(dpopToken), jwt.WithVerify(false)) + + method, ok := parsedToken.Get("htm") + require.True(t, ok, "error getting htm claim") + assert.Equal(t, http.MethodPost, method, "got a bad method") + + path, ok := parsedToken.Get("htu") + require.True(t, ok, "error getting htu claim") + assert.Equal(t, "/kas.AccessService/PublicKey", path, "got a bad path") + + h := sha256.New() + h.Write([]byte("thisisafakeaccesstoken")) + expectedHash := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(h.Sum(nil)) + + ath, ok := parsedToken.Get("ath") + require.True(t, ok, "error getting ath claim") + assert.Equal(t, expectedHash, ath, "invalid ath claim in token") + } } func Test_InvalidCredentials_DoesNotSendMessage(t *testing.T) { ts := FakeTokenSource{key: nil, accessToken: ""} - server := FakeAccessServiceServer{} + serverConnect := FakeAccessServiceServerConnect{} + serverGrpc := FakeAccessServiceServer{} oo := NewTokenAddingInterceptorWithClient(&ts, httputil.SafeHTTPClientWithTLSConfig(&tls.Config{ MinVersion: tls.VersionTLS12, })) - client, stop := runServer(&server, oo) - defer stop() + clientConnect, stopC := runConnectServer(&serverConnect, oo) + defer stopC() + clientGrpc, stopG := runServer(&serverGrpc, oo) + defer stopG() - _, err := client.PublicKey(t.Context(), &kas.PublicKeyRequest{}) + _, err := clientConnect.PublicKey(context.Background(), connect.NewRequest(&kas.PublicKeyRequest{})) + require.Error(t, err, "should not have sent message because the token source returned an error") + _, err = clientGrpc.PublicKey(context.Background(), &kas.PublicKeyRequest{}) require.Error(t, err, "should not have sent message because the token source returned an error") } +type FakeAccessServiceServerConnect struct { + accessToken []string + dpopToken []string + dpopKey jwk.Key + kasconnect.UnimplementedAccessServiceHandler +} + +func (f *FakeAccessServiceServerConnect) PublicKey(ctx context.Context, req *connect.Request[kas.PublicKeyRequest]) (*connect.Response[kas.PublicKeyResponse], error) { + f.accessToken = []string{req.Header().Get("authorization")} + f.dpopToken = []string{req.Header().Get("dpop")} + var ok bool + f.dpopKey, ok = ctx.Value("dpop-jwk").(jwk.Key) + if !ok { + f.dpopKey = nil + } + return connect.NewResponse(&kas.PublicKeyResponse{}), nil +} + type FakeAccessServiceServer struct { accessToken []string dpopToken []string @@ -156,6 +195,26 @@ func (fts *FakeTokenSource) MakeToken(f func(jwk.Key) ([]byte, error)) ([]byte, return f(fts.key) } +func runConnectServer( + f *FakeAccessServiceServerConnect, oo TokenAddingInterceptor, +) (kasconnect.AccessServiceClient, func()) { + mux := http.NewServeMux() + path, handler := kasconnect.NewAccessServiceHandler(f) + mux.Handle(path, handler) + + server := httptest.NewServer(mux) + + client := kasconnect.NewAccessServiceClient( + server.Client(), + server.URL, + connect.WithInterceptors(oo.AddCredentialsConnect()), + ) + + return client, func() { + server.Close() + } +} + func runServer( //nolint:ireturn // this is pretty concrete f *FakeAccessServiceServer, oo TokenAddingInterceptor, ) (kas.AccessServiceClient, func()) { From 744b6dccd34f5fa964326c3986e78ad55fbb289c Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Thu, 15 May 2025 01:58:04 -0400 Subject: [PATCH 14/31] remove granter test changes --- sdk/granter_test.go | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/sdk/granter_test.go b/sdk/granter_test.go index d9579613e4..f800c6aa9c 100644 --- a/sdk/granter_test.go +++ b/sdk/granter_test.go @@ -9,13 +9,11 @@ import ( "strings" "testing" - "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" - "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" - "github.com/opentdf/platform/sdk/sdkconnect" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" ) const ( @@ -505,16 +503,16 @@ func TestReasonerConstructAttributeBoolean(t *testing.T) { var listAttributeResp attributes.ListAttributesResponse type mockAttributesClient struct { - attributesconnect.AttributesServiceClient + attributes.AttributesServiceClient } -func (*mockAttributesClient) ListAttributes(_ context.Context, _ *connect.Request[attributes.ListAttributesRequest]) (*connect.Response[attributes.ListAttributesResponse], error) { - return connect.NewResponse(&listAttributeResp), nil +func (*mockAttributesClient) ListAttributes(_ context.Context, _ *attributes.ListAttributesRequest, _ ...grpc.CallOption) (*attributes.ListAttributesResponse, error) { + return &listAttributeResp, nil } -func (*mockAttributesClient) GetAttributeValuesByFqns(_ context.Context, req *connect.Request[attributes.GetAttributeValuesByFqnsRequest]) (*connect.Response[attributes.GetAttributeValuesByFqnsResponse], error) { +func (*mockAttributesClient) GetAttributeValuesByFqns(_ context.Context, req *attributes.GetAttributeValuesByFqnsRequest, _ ...grpc.CallOption) (*attributes.GetAttributeValuesByFqnsResponse, error) { av := make(map[string]*attributes.GetAttributeValuesByFqnsResponse_AttributeAndValue) - for _, v := range req.Msg.GetFqns() { + for _, v := range req.GetFqns() { vfqn, err := NewAttributeValueFQN(v) if err != nil { return nil, err @@ -526,9 +524,9 @@ func (*mockAttributesClient) GetAttributeValuesByFqns(_ context.Context, req *co } } - return connect.NewResponse(&attributes.GetAttributeValuesByFqnsResponse{ + return &attributes.GetAttributeValuesByFqnsResponse{ FqnAttributeValues: av, - }), nil + }, nil } // Tests titles are written in the form [{attr}.{value}] => [{resulting kas boolean exp}] @@ -613,7 +611,7 @@ func TestReasonerSpecificity(t *testing.T) { }, } { t.Run(tc.n, func(t *testing.T) { - reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &sdkconnect.AttributesServiceClientConnectWrapper{AttributesServiceClient: &mockAttributesClient{}}, tc.policy...) + reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...) require.NoError(t, err) i := 0 plan, err := reasoner.plan(tc.defaults, func() string { @@ -764,7 +762,7 @@ func TestReasonerSpecificityWithNamespaces(t *testing.T) { }, } { t.Run((tc.n + "\n" + tc.desc), func(t *testing.T) { - reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &sdkconnect.AttributesServiceClientConnectWrapper{AttributesServiceClient: &mockAttributesClient{}}, tc.policy...) + reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...) require.NoError(t, err) i := 0 plan, err := reasoner.plan(tc.defaults, func() string { From 27ca9ebc365c05ee618c0bf2e564a46da7267b1c Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Thu, 15 May 2025 02:10:11 -0400 Subject: [PATCH 15/31] linting --- sdk/audit/metadata_adding_interceptor_test.go | 5 ++--- sdk/auth/token_adding_interceptor_test.go | 8 ++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/sdk/audit/metadata_adding_interceptor_test.go b/sdk/audit/metadata_adding_interceptor_test.go index 687c661245..68756303ab 100644 --- a/sdk/audit/metadata_adding_interceptor_test.go +++ b/sdk/audit/metadata_adding_interceptor_test.go @@ -24,7 +24,7 @@ type FakeAccessServiceServerConnect struct { kasconnect.UnimplementedAccessServiceHandler } -func (f *FakeAccessServiceServerConnect) PublicKey(ctx context.Context, req *connect.Request[kas.PublicKeyRequest]) (*connect.Response[kas.PublicKeyResponse], error) { +func (f *FakeAccessServiceServerConnect) PublicKey(_ context.Context, req *connect.Request[kas.PublicKeyRequest]) (*connect.Response[kas.PublicKeyResponse], error) { requestIDFromHeader := req.Header().Get(string(RequestIDHeaderKey)) if requestIDFromHeader != "" { f.requestID, _ = uuid.Parse(requestIDFromHeader) @@ -116,8 +116,7 @@ func TestIsOKWithNoContextValues(t *testing.T) { } } -func runConnectServer( - f *FakeAccessServiceServerConnect) (kasconnect.AccessServiceClient, func()) { +func runConnectServer(f *FakeAccessServiceServerConnect) (kasconnect.AccessServiceClient, func()) { mux := http.NewServeMux() path, handler := kasconnect.NewAccessServiceHandler(f) mux.Handle(path, handler) diff --git a/sdk/auth/token_adding_interceptor_test.go b/sdk/auth/token_adding_interceptor_test.go index 43e1397de5..a27802e5b8 100644 --- a/sdk/auth/token_adding_interceptor_test.go +++ b/sdk/auth/token_adding_interceptor_test.go @@ -57,9 +57,9 @@ func TestAddingTokensToOutgoingRequest(t *testing.T) { clientGrpc, stopG := runServer(&serverGrpc, oo) defer stopG() - _, err = clientConnect.PublicKey(context.Background(), connect.NewRequest(&kas.PublicKeyRequest{})) + _, err = clientConnect.PublicKey(t.Context(), connect.NewRequest(&kas.PublicKeyRequest{})) require.NoError(t, err, "error making call") - _, err = clientGrpc.PublicKey(context.Background(), &kas.PublicKeyRequest{}) + _, err = clientGrpc.PublicKey(t.Context(), &kas.PublicKeyRequest{}) require.NoError(t, err, "error making call") for _, server := range []struct { @@ -124,9 +124,9 @@ func Test_InvalidCredentials_DoesNotSendMessage(t *testing.T) { clientGrpc, stopG := runServer(&serverGrpc, oo) defer stopG() - _, err := clientConnect.PublicKey(context.Background(), connect.NewRequest(&kas.PublicKeyRequest{})) + _, err := clientConnect.PublicKey(t.Context(), connect.NewRequest(&kas.PublicKeyRequest{})) require.Error(t, err, "should not have sent message because the token source returned an error") - _, err = clientGrpc.PublicKey(context.Background(), &kas.PublicKeyRequest{}) + _, err = clientGrpc.PublicKey(t.Context(), &kas.PublicKeyRequest{}) require.Error(t, err, "should not have sent message because the token source returned an error") } From ac257f37b4f1097733c42196ddc925c3a8a57cc7 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 16 May 2025 12:18:30 -0400 Subject: [PATCH 16/31] use sdkconnect interfaces --- sdk/granter.go | 3 +- sdk/granter_test.go | 8 +-- sdk/internal/codegen/runner/generate.go | 18 +++++- sdk/sdk.go | 35 ++++-------- sdk/sdkconnect/actions.go | 19 +++++-- sdk/sdkconnect/attributes.go | 61 ++++++++++++++------- sdk/sdkconnect/authorization.go | 13 +++-- sdk/sdkconnect/entityresolution.go | 10 +++- sdk/sdkconnect/kasregistry.go | 37 +++++++++---- sdk/sdkconnect/keymanagement.go | 19 +++++-- sdk/sdkconnect/namespaces.go | 31 +++++++---- sdk/sdkconnect/registeredresources.go | 37 +++++++++---- sdk/sdkconnect/resourcemapping.go | 37 +++++++++---- sdk/sdkconnect/subjectmapping.go | 40 +++++++++----- sdk/sdkconnect/unsafe.go | 34 ++++++++---- sdk/sdkconnect/wellknownconfiguration.go | 7 ++- sdk/tdf.go | 3 +- service/authorization/authorization_test.go | 26 ++++----- 18 files changed, 283 insertions(+), 155 deletions(-) diff --git a/sdk/granter.go b/sdk/granter.go index 5e0f32c5ab..b45c55b640 100644 --- a/sdk/granter.go +++ b/sdk/granter.go @@ -12,6 +12,7 @@ import ( "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/sdk/sdkconnect" ) var ErrInvalid = errors.New("invalid type") @@ -221,7 +222,7 @@ func (r granter) byAttribute(fqn AttributeValueFQN) *keyAccessGrant { } // Gets a list of directory of KAS grants for a list of attribute FQNs -func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as attributes.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) { +func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkconnect.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) { fqnsStr := make([]string, len(fqns)) for i, v := range fqns { fqnsStr[i] = v.String() diff --git a/sdk/granter_test.go b/sdk/granter_test.go index f800c6aa9c..e90b146055 100644 --- a/sdk/granter_test.go +++ b/sdk/granter_test.go @@ -11,9 +11,9 @@ import ( "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/policy/attributes" + "github.com/opentdf/platform/sdk/sdkconnect" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc" ) const ( @@ -503,14 +503,14 @@ func TestReasonerConstructAttributeBoolean(t *testing.T) { var listAttributeResp attributes.ListAttributesResponse type mockAttributesClient struct { - attributes.AttributesServiceClient + sdkconnect.AttributesServiceClient } -func (*mockAttributesClient) ListAttributes(_ context.Context, _ *attributes.ListAttributesRequest, _ ...grpc.CallOption) (*attributes.ListAttributesResponse, error) { +func (*mockAttributesClient) ListAttributes(_ context.Context, _ *attributes.ListAttributesRequest) (*attributes.ListAttributesResponse, error) { return &listAttributeResp, nil } -func (*mockAttributesClient) GetAttributeValuesByFqns(_ context.Context, req *attributes.GetAttributeValuesByFqnsRequest, _ ...grpc.CallOption) (*attributes.GetAttributeValuesByFqnsResponse, error) { +func (*mockAttributesClient) GetAttributeValuesByFqns(_ context.Context, req *attributes.GetAttributeValuesByFqnsRequest) (*attributes.GetAttributeValuesByFqnsResponse, error) { av := make(map[string]*attributes.GetAttributeValuesByFqnsResponse_AttributeAndValue) for _, v := range req.GetFqns() { vfqn, err := NewAttributeValueFQN(v) diff --git a/sdk/internal/codegen/runner/generate.go b/sdk/internal/codegen/runner/generate.go index e40a593dc3..5feae04e4d 100644 --- a/sdk/internal/codegen/runner/generate.go +++ b/sdk/internal/codegen/runner/generate.go @@ -168,7 +168,6 @@ import ( "context" "%s" "%s" - "google.golang.org/grpc" ) type %sConnectWrapper struct { @@ -192,6 +191,8 @@ func New%sConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ... connectPackageName, interfaceName) + // Generate the interface type definition + wrapperCode += generateInterfaceType(interfaceName, methods, packageName) // Now generate a wrapper function for each method in the interface for _, method := range methods { wrapperCode += generateWrapperMethod(interfaceName, method, packageName) @@ -201,10 +202,23 @@ func New%sConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ... return wrapperCode } +func generateInterfaceType(interfaceName string, methods []string, packageName string) string { + // Generate the interface type definition + interfaceType := fmt.Sprintf(` +type %s interface { +`, interfaceName) + for _, method := range methods { + interfaceType += fmt.Sprintf(` %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error) +`, method, packageName, method, packageName, method) + } + interfaceType += "}\n" + return interfaceType +} + // Generate the wrapper method for a specific method in the interface func generateWrapperMethod(interfaceName, methodName, packageName string) string { return fmt.Sprintf(` -func (w *%sConnectWrapper) %s(ctx context.Context, req *%s.%sRequest, _ ...grpc.CallOption) (*%s.%sResponse, error) { +func (w *%sConnectWrapper) %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error) { // Wrap Connect RPC client request res, err := w.%s.%s(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdk.go b/sdk/sdk.go index 81bf550117..be138137e8 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -15,18 +15,7 @@ import ( "connectrpc.com/connect" "github.com/opentdf/platform/lib/ocrypto" - "github.com/opentdf/platform/protocol/go/authorization" - "github.com/opentdf/platform/protocol/go/entityresolution" "github.com/opentdf/platform/protocol/go/policy" - "github.com/opentdf/platform/protocol/go/policy/actions" - "github.com/opentdf/platform/protocol/go/policy/attributes" - "github.com/opentdf/platform/protocol/go/policy/kasregistry" - "github.com/opentdf/platform/protocol/go/policy/keymanagement" - "github.com/opentdf/platform/protocol/go/policy/namespaces" - "github.com/opentdf/platform/protocol/go/policy/registeredresources" - "github.com/opentdf/platform/protocol/go/policy/resourcemapping" - "github.com/opentdf/platform/protocol/go/policy/subjectmapping" - "github.com/opentdf/platform/protocol/go/policy/unsafe" "github.com/opentdf/platform/protocol/go/wellknownconfiguration" "github.com/opentdf/platform/protocol/go/wellknownconfiguration/wellknownconfigurationconnect" "github.com/opentdf/platform/sdk/audit" @@ -66,18 +55,18 @@ type SDK struct { *collectionStore conn *ConnectRPCConnection tokenSource auth.AccessTokenSource - Actions actions.ActionServiceClient - Attributes attributes.AttributesServiceClient - Authorization authorization.AuthorizationServiceClient - EntityResoution entityresolution.EntityResolutionServiceClient - KeyAccessServerRegistry kasregistry.KeyAccessServerRegistryServiceClient - Namespaces namespaces.NamespaceServiceClient - RegisteredResources registeredresources.RegisteredResourcesServiceClient - ResourceMapping resourcemapping.ResourceMappingServiceClient - SubjectMapping subjectmapping.SubjectMappingServiceClient - Unsafe unsafe.UnsafeServiceClient - KeyManagement keymanagement.KeyManagementServiceClient - wellknownConfiguration wellknownconfiguration.WellKnownServiceClient + Actions sdkconnect.ActionServiceClient + Attributes sdkconnect.AttributesServiceClient + Authorization sdkconnect.AuthorizationServiceClient + EntityResoution sdkconnect.EntityResolutionServiceClient + KeyAccessServerRegistry sdkconnect.KeyAccessServerRegistryServiceClient + Namespaces sdkconnect.NamespaceServiceClient + RegisteredResources sdkconnect.RegisteredResourcesServiceClient + ResourceMapping sdkconnect.ResourceMappingServiceClient + SubjectMapping sdkconnect.SubjectMappingServiceClient + Unsafe sdkconnect.UnsafeServiceClient + KeyManagement sdkconnect.KeyManagementServiceClient + wellknownConfiguration sdkconnect.WellKnownServiceClient } func New(platformEndpoint string, opts ...Option) (*SDK, error) { diff --git a/sdk/sdkconnect/actions.go b/sdk/sdkconnect/actions.go index b60f7a8be9..1a8177c8f7 100644 --- a/sdk/sdkconnect/actions.go +++ b/sdk/sdkconnect/actions.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/policy/actions" "github.com/opentdf/platform/protocol/go/policy/actions/actionsconnect" - "google.golang.org/grpc" ) type ActionServiceClientConnectWrapper struct { @@ -17,7 +16,15 @@ func NewActionServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL return &ActionServiceClientConnectWrapper{ActionServiceClient: actionsconnect.NewActionServiceClient(httpClient, baseURL, opts...)} } -func (w *ActionServiceClientConnectWrapper) GetAction(ctx context.Context, req *actions.GetActionRequest, _ ...grpc.CallOption) (*actions.GetActionResponse, error) { +type ActionServiceClient interface { + GetAction(ctx context.Context, req *actions.GetActionRequest) (*actions.GetActionResponse, error) + ListActions(ctx context.Context, req *actions.ListActionsRequest) (*actions.ListActionsResponse, error) + CreateAction(ctx context.Context, req *actions.CreateActionRequest) (*actions.CreateActionResponse, error) + UpdateAction(ctx context.Context, req *actions.UpdateActionRequest) (*actions.UpdateActionResponse, error) + DeleteAction(ctx context.Context, req *actions.DeleteActionRequest) (*actions.DeleteActionResponse, error) +} + +func (w *ActionServiceClientConnectWrapper) GetAction(ctx context.Context, req *actions.GetActionRequest) (*actions.GetActionResponse, error) { // Wrap Connect RPC client request res, err := w.ActionServiceClient.GetAction(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +33,7 @@ func (w *ActionServiceClientConnectWrapper) GetAction(ctx context.Context, req * return res.Msg, err } -func (w *ActionServiceClientConnectWrapper) ListActions(ctx context.Context, req *actions.ListActionsRequest, _ ...grpc.CallOption) (*actions.ListActionsResponse, error) { +func (w *ActionServiceClientConnectWrapper) ListActions(ctx context.Context, req *actions.ListActionsRequest) (*actions.ListActionsResponse, error) { // Wrap Connect RPC client request res, err := w.ActionServiceClient.ListActions(ctx, connect.NewRequest(req)) if res == nil { @@ -35,7 +42,7 @@ func (w *ActionServiceClientConnectWrapper) ListActions(ctx context.Context, req return res.Msg, err } -func (w *ActionServiceClientConnectWrapper) CreateAction(ctx context.Context, req *actions.CreateActionRequest, _ ...grpc.CallOption) (*actions.CreateActionResponse, error) { +func (w *ActionServiceClientConnectWrapper) CreateAction(ctx context.Context, req *actions.CreateActionRequest) (*actions.CreateActionResponse, error) { // Wrap Connect RPC client request res, err := w.ActionServiceClient.CreateAction(ctx, connect.NewRequest(req)) if res == nil { @@ -44,7 +51,7 @@ func (w *ActionServiceClientConnectWrapper) CreateAction(ctx context.Context, re return res.Msg, err } -func (w *ActionServiceClientConnectWrapper) UpdateAction(ctx context.Context, req *actions.UpdateActionRequest, _ ...grpc.CallOption) (*actions.UpdateActionResponse, error) { +func (w *ActionServiceClientConnectWrapper) UpdateAction(ctx context.Context, req *actions.UpdateActionRequest) (*actions.UpdateActionResponse, error) { // Wrap Connect RPC client request res, err := w.ActionServiceClient.UpdateAction(ctx, connect.NewRequest(req)) if res == nil { @@ -53,7 +60,7 @@ func (w *ActionServiceClientConnectWrapper) UpdateAction(ctx context.Context, re return res.Msg, err } -func (w *ActionServiceClientConnectWrapper) DeleteAction(ctx context.Context, req *actions.DeleteActionRequest, _ ...grpc.CallOption) (*actions.DeleteActionResponse, error) { +func (w *ActionServiceClientConnectWrapper) DeleteAction(ctx context.Context, req *actions.DeleteActionRequest) (*actions.DeleteActionResponse, error) { // Wrap Connect RPC client request res, err := w.ActionServiceClient.DeleteAction(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/attributes.go b/sdk/sdkconnect/attributes.go index 05027dff31..080ef4df83 100644 --- a/sdk/sdkconnect/attributes.go +++ b/sdk/sdkconnect/attributes.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/policy/attributes" "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" - "google.golang.org/grpc" ) type AttributesServiceClientConnectWrapper struct { @@ -17,7 +16,29 @@ func NewAttributesServiceClientConnectWrapper(httpClient connect.HTTPClient, bas return &AttributesServiceClientConnectWrapper{AttributesServiceClient: attributesconnect.NewAttributesServiceClient(httpClient, baseURL, opts...)} } -func (w *AttributesServiceClientConnectWrapper) ListAttributes(ctx context.Context, req *attributes.ListAttributesRequest, _ ...grpc.CallOption) (*attributes.ListAttributesResponse, error) { +type AttributesServiceClient interface { + ListAttributes(ctx context.Context, req *attributes.ListAttributesRequest) (*attributes.ListAttributesResponse, error) + ListAttributeValues(ctx context.Context, req *attributes.ListAttributeValuesRequest) (*attributes.ListAttributeValuesResponse, error) + GetAttribute(ctx context.Context, req *attributes.GetAttributeRequest) (*attributes.GetAttributeResponse, error) + GetAttributeValuesByFqns(ctx context.Context, req *attributes.GetAttributeValuesByFqnsRequest) (*attributes.GetAttributeValuesByFqnsResponse, error) + CreateAttribute(ctx context.Context, req *attributes.CreateAttributeRequest) (*attributes.CreateAttributeResponse, error) + UpdateAttribute(ctx context.Context, req *attributes.UpdateAttributeRequest) (*attributes.UpdateAttributeResponse, error) + DeactivateAttribute(ctx context.Context, req *attributes.DeactivateAttributeRequest) (*attributes.DeactivateAttributeResponse, error) + GetAttributeValue(ctx context.Context, req *attributes.GetAttributeValueRequest) (*attributes.GetAttributeValueResponse, error) + CreateAttributeValue(ctx context.Context, req *attributes.CreateAttributeValueRequest) (*attributes.CreateAttributeValueResponse, error) + UpdateAttributeValue(ctx context.Context, req *attributes.UpdateAttributeValueRequest) (*attributes.UpdateAttributeValueResponse, error) + DeactivateAttributeValue(ctx context.Context, req *attributes.DeactivateAttributeValueRequest) (*attributes.DeactivateAttributeValueResponse, error) + AssignKeyAccessServerToAttribute(ctx context.Context, req *attributes.AssignKeyAccessServerToAttributeRequest) (*attributes.AssignKeyAccessServerToAttributeResponse, error) + RemoveKeyAccessServerFromAttribute(ctx context.Context, req *attributes.RemoveKeyAccessServerFromAttributeRequest) (*attributes.RemoveKeyAccessServerFromAttributeResponse, error) + AssignKeyAccessServerToValue(ctx context.Context, req *attributes.AssignKeyAccessServerToValueRequest) (*attributes.AssignKeyAccessServerToValueResponse, error) + RemoveKeyAccessServerFromValue(ctx context.Context, req *attributes.RemoveKeyAccessServerFromValueRequest) (*attributes.RemoveKeyAccessServerFromValueResponse, error) + AssignPublicKeyToAttribute(ctx context.Context, req *attributes.AssignPublicKeyToAttributeRequest) (*attributes.AssignPublicKeyToAttributeResponse, error) + RemovePublicKeyFromAttribute(ctx context.Context, req *attributes.RemovePublicKeyFromAttributeRequest) (*attributes.RemovePublicKeyFromAttributeResponse, error) + AssignPublicKeyToValue(ctx context.Context, req *attributes.AssignPublicKeyToValueRequest) (*attributes.AssignPublicKeyToValueResponse, error) + RemovePublicKeyFromValue(ctx context.Context, req *attributes.RemovePublicKeyFromValueRequest) (*attributes.RemovePublicKeyFromValueResponse, error) +} + +func (w *AttributesServiceClientConnectWrapper) ListAttributes(ctx context.Context, req *attributes.ListAttributesRequest) (*attributes.ListAttributesResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.ListAttributes(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +47,7 @@ func (w *AttributesServiceClientConnectWrapper) ListAttributes(ctx context.Conte return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) ListAttributeValues(ctx context.Context, req *attributes.ListAttributeValuesRequest, _ ...grpc.CallOption) (*attributes.ListAttributeValuesResponse, error) { +func (w *AttributesServiceClientConnectWrapper) ListAttributeValues(ctx context.Context, req *attributes.ListAttributeValuesRequest) (*attributes.ListAttributeValuesResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.ListAttributeValues(ctx, connect.NewRequest(req)) if res == nil { @@ -35,7 +56,7 @@ func (w *AttributesServiceClientConnectWrapper) ListAttributeValues(ctx context. return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) GetAttribute(ctx context.Context, req *attributes.GetAttributeRequest, _ ...grpc.CallOption) (*attributes.GetAttributeResponse, error) { +func (w *AttributesServiceClientConnectWrapper) GetAttribute(ctx context.Context, req *attributes.GetAttributeRequest) (*attributes.GetAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.GetAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -44,7 +65,7 @@ func (w *AttributesServiceClientConnectWrapper) GetAttribute(ctx context.Context return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) GetAttributeValuesByFqns(ctx context.Context, req *attributes.GetAttributeValuesByFqnsRequest, _ ...grpc.CallOption) (*attributes.GetAttributeValuesByFqnsResponse, error) { +func (w *AttributesServiceClientConnectWrapper) GetAttributeValuesByFqns(ctx context.Context, req *attributes.GetAttributeValuesByFqnsRequest) (*attributes.GetAttributeValuesByFqnsResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.GetAttributeValuesByFqns(ctx, connect.NewRequest(req)) if res == nil { @@ -53,7 +74,7 @@ func (w *AttributesServiceClientConnectWrapper) GetAttributeValuesByFqns(ctx con return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) CreateAttribute(ctx context.Context, req *attributes.CreateAttributeRequest, _ ...grpc.CallOption) (*attributes.CreateAttributeResponse, error) { +func (w *AttributesServiceClientConnectWrapper) CreateAttribute(ctx context.Context, req *attributes.CreateAttributeRequest) (*attributes.CreateAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.CreateAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -62,7 +83,7 @@ func (w *AttributesServiceClientConnectWrapper) CreateAttribute(ctx context.Cont return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) UpdateAttribute(ctx context.Context, req *attributes.UpdateAttributeRequest, _ ...grpc.CallOption) (*attributes.UpdateAttributeResponse, error) { +func (w *AttributesServiceClientConnectWrapper) UpdateAttribute(ctx context.Context, req *attributes.UpdateAttributeRequest) (*attributes.UpdateAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.UpdateAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -71,7 +92,7 @@ func (w *AttributesServiceClientConnectWrapper) UpdateAttribute(ctx context.Cont return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) DeactivateAttribute(ctx context.Context, req *attributes.DeactivateAttributeRequest, _ ...grpc.CallOption) (*attributes.DeactivateAttributeResponse, error) { +func (w *AttributesServiceClientConnectWrapper) DeactivateAttribute(ctx context.Context, req *attributes.DeactivateAttributeRequest) (*attributes.DeactivateAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.DeactivateAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -80,7 +101,7 @@ func (w *AttributesServiceClientConnectWrapper) DeactivateAttribute(ctx context. return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) GetAttributeValue(ctx context.Context, req *attributes.GetAttributeValueRequest, _ ...grpc.CallOption) (*attributes.GetAttributeValueResponse, error) { +func (w *AttributesServiceClientConnectWrapper) GetAttributeValue(ctx context.Context, req *attributes.GetAttributeValueRequest) (*attributes.GetAttributeValueResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.GetAttributeValue(ctx, connect.NewRequest(req)) if res == nil { @@ -89,7 +110,7 @@ func (w *AttributesServiceClientConnectWrapper) GetAttributeValue(ctx context.Co return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) CreateAttributeValue(ctx context.Context, req *attributes.CreateAttributeValueRequest, _ ...grpc.CallOption) (*attributes.CreateAttributeValueResponse, error) { +func (w *AttributesServiceClientConnectWrapper) CreateAttributeValue(ctx context.Context, req *attributes.CreateAttributeValueRequest) (*attributes.CreateAttributeValueResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.CreateAttributeValue(ctx, connect.NewRequest(req)) if res == nil { @@ -98,7 +119,7 @@ func (w *AttributesServiceClientConnectWrapper) CreateAttributeValue(ctx context return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) UpdateAttributeValue(ctx context.Context, req *attributes.UpdateAttributeValueRequest, _ ...grpc.CallOption) (*attributes.UpdateAttributeValueResponse, error) { +func (w *AttributesServiceClientConnectWrapper) UpdateAttributeValue(ctx context.Context, req *attributes.UpdateAttributeValueRequest) (*attributes.UpdateAttributeValueResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.UpdateAttributeValue(ctx, connect.NewRequest(req)) if res == nil { @@ -107,7 +128,7 @@ func (w *AttributesServiceClientConnectWrapper) UpdateAttributeValue(ctx context return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) DeactivateAttributeValue(ctx context.Context, req *attributes.DeactivateAttributeValueRequest, _ ...grpc.CallOption) (*attributes.DeactivateAttributeValueResponse, error) { +func (w *AttributesServiceClientConnectWrapper) DeactivateAttributeValue(ctx context.Context, req *attributes.DeactivateAttributeValueRequest) (*attributes.DeactivateAttributeValueResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.DeactivateAttributeValue(ctx, connect.NewRequest(req)) if res == nil { @@ -116,7 +137,7 @@ func (w *AttributesServiceClientConnectWrapper) DeactivateAttributeValue(ctx con return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) AssignKeyAccessServerToAttribute(ctx context.Context, req *attributes.AssignKeyAccessServerToAttributeRequest, _ ...grpc.CallOption) (*attributes.AssignKeyAccessServerToAttributeResponse, error) { +func (w *AttributesServiceClientConnectWrapper) AssignKeyAccessServerToAttribute(ctx context.Context, req *attributes.AssignKeyAccessServerToAttributeRequest) (*attributes.AssignKeyAccessServerToAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.AssignKeyAccessServerToAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -125,7 +146,7 @@ func (w *AttributesServiceClientConnectWrapper) AssignKeyAccessServerToAttribute return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) RemoveKeyAccessServerFromAttribute(ctx context.Context, req *attributes.RemoveKeyAccessServerFromAttributeRequest, _ ...grpc.CallOption) (*attributes.RemoveKeyAccessServerFromAttributeResponse, error) { +func (w *AttributesServiceClientConnectWrapper) RemoveKeyAccessServerFromAttribute(ctx context.Context, req *attributes.RemoveKeyAccessServerFromAttributeRequest) (*attributes.RemoveKeyAccessServerFromAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.RemoveKeyAccessServerFromAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -134,7 +155,7 @@ func (w *AttributesServiceClientConnectWrapper) RemoveKeyAccessServerFromAttribu return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) AssignKeyAccessServerToValue(ctx context.Context, req *attributes.AssignKeyAccessServerToValueRequest, _ ...grpc.CallOption) (*attributes.AssignKeyAccessServerToValueResponse, error) { +func (w *AttributesServiceClientConnectWrapper) AssignKeyAccessServerToValue(ctx context.Context, req *attributes.AssignKeyAccessServerToValueRequest) (*attributes.AssignKeyAccessServerToValueResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.AssignKeyAccessServerToValue(ctx, connect.NewRequest(req)) if res == nil { @@ -143,7 +164,7 @@ func (w *AttributesServiceClientConnectWrapper) AssignKeyAccessServerToValue(ctx return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) RemoveKeyAccessServerFromValue(ctx context.Context, req *attributes.RemoveKeyAccessServerFromValueRequest, _ ...grpc.CallOption) (*attributes.RemoveKeyAccessServerFromValueResponse, error) { +func (w *AttributesServiceClientConnectWrapper) RemoveKeyAccessServerFromValue(ctx context.Context, req *attributes.RemoveKeyAccessServerFromValueRequest) (*attributes.RemoveKeyAccessServerFromValueResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.RemoveKeyAccessServerFromValue(ctx, connect.NewRequest(req)) if res == nil { @@ -152,7 +173,7 @@ func (w *AttributesServiceClientConnectWrapper) RemoveKeyAccessServerFromValue(c return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) AssignPublicKeyToAttribute(ctx context.Context, req *attributes.AssignPublicKeyToAttributeRequest, _ ...grpc.CallOption) (*attributes.AssignPublicKeyToAttributeResponse, error) { +func (w *AttributesServiceClientConnectWrapper) AssignPublicKeyToAttribute(ctx context.Context, req *attributes.AssignPublicKeyToAttributeRequest) (*attributes.AssignPublicKeyToAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.AssignPublicKeyToAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -161,7 +182,7 @@ func (w *AttributesServiceClientConnectWrapper) AssignPublicKeyToAttribute(ctx c return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) RemovePublicKeyFromAttribute(ctx context.Context, req *attributes.RemovePublicKeyFromAttributeRequest, _ ...grpc.CallOption) (*attributes.RemovePublicKeyFromAttributeResponse, error) { +func (w *AttributesServiceClientConnectWrapper) RemovePublicKeyFromAttribute(ctx context.Context, req *attributes.RemovePublicKeyFromAttributeRequest) (*attributes.RemovePublicKeyFromAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.RemovePublicKeyFromAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -170,7 +191,7 @@ func (w *AttributesServiceClientConnectWrapper) RemovePublicKeyFromAttribute(ctx return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) AssignPublicKeyToValue(ctx context.Context, req *attributes.AssignPublicKeyToValueRequest, _ ...grpc.CallOption) (*attributes.AssignPublicKeyToValueResponse, error) { +func (w *AttributesServiceClientConnectWrapper) AssignPublicKeyToValue(ctx context.Context, req *attributes.AssignPublicKeyToValueRequest) (*attributes.AssignPublicKeyToValueResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.AssignPublicKeyToValue(ctx, connect.NewRequest(req)) if res == nil { @@ -179,7 +200,7 @@ func (w *AttributesServiceClientConnectWrapper) AssignPublicKeyToValue(ctx conte return res.Msg, err } -func (w *AttributesServiceClientConnectWrapper) RemovePublicKeyFromValue(ctx context.Context, req *attributes.RemovePublicKeyFromValueRequest, _ ...grpc.CallOption) (*attributes.RemovePublicKeyFromValueResponse, error) { +func (w *AttributesServiceClientConnectWrapper) RemovePublicKeyFromValue(ctx context.Context, req *attributes.RemovePublicKeyFromValueRequest) (*attributes.RemovePublicKeyFromValueResponse, error) { // Wrap Connect RPC client request res, err := w.AttributesServiceClient.RemovePublicKeyFromValue(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/authorization.go b/sdk/sdkconnect/authorization.go index 5ef2d6666e..a912aea95e 100644 --- a/sdk/sdkconnect/authorization.go +++ b/sdk/sdkconnect/authorization.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/authorization" "github.com/opentdf/platform/protocol/go/authorization/authorizationconnect" - "google.golang.org/grpc" ) type AuthorizationServiceClientConnectWrapper struct { @@ -17,7 +16,13 @@ func NewAuthorizationServiceClientConnectWrapper(httpClient connect.HTTPClient, return &AuthorizationServiceClientConnectWrapper{AuthorizationServiceClient: authorizationconnect.NewAuthorizationServiceClient(httpClient, baseURL, opts...)} } -func (w *AuthorizationServiceClientConnectWrapper) GetDecisions(ctx context.Context, req *authorization.GetDecisionsRequest, _ ...grpc.CallOption) (*authorization.GetDecisionsResponse, error) { +type AuthorizationServiceClient interface { + GetDecisions(ctx context.Context, req *authorization.GetDecisionsRequest) (*authorization.GetDecisionsResponse, error) + GetDecisionsByToken(ctx context.Context, req *authorization.GetDecisionsByTokenRequest) (*authorization.GetDecisionsByTokenResponse, error) + GetEntitlements(ctx context.Context, req *authorization.GetEntitlementsRequest) (*authorization.GetEntitlementsResponse, error) +} + +func (w *AuthorizationServiceClientConnectWrapper) GetDecisions(ctx context.Context, req *authorization.GetDecisionsRequest) (*authorization.GetDecisionsResponse, error) { // Wrap Connect RPC client request res, err := w.AuthorizationServiceClient.GetDecisions(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +31,7 @@ func (w *AuthorizationServiceClientConnectWrapper) GetDecisions(ctx context.Cont return res.Msg, err } -func (w *AuthorizationServiceClientConnectWrapper) GetDecisionsByToken(ctx context.Context, req *authorization.GetDecisionsByTokenRequest, _ ...grpc.CallOption) (*authorization.GetDecisionsByTokenResponse, error) { +func (w *AuthorizationServiceClientConnectWrapper) GetDecisionsByToken(ctx context.Context, req *authorization.GetDecisionsByTokenRequest) (*authorization.GetDecisionsByTokenResponse, error) { // Wrap Connect RPC client request res, err := w.AuthorizationServiceClient.GetDecisionsByToken(ctx, connect.NewRequest(req)) if res == nil { @@ -35,7 +40,7 @@ func (w *AuthorizationServiceClientConnectWrapper) GetDecisionsByToken(ctx conte return res.Msg, err } -func (w *AuthorizationServiceClientConnectWrapper) GetEntitlements(ctx context.Context, req *authorization.GetEntitlementsRequest, _ ...grpc.CallOption) (*authorization.GetEntitlementsResponse, error) { +func (w *AuthorizationServiceClientConnectWrapper) GetEntitlements(ctx context.Context, req *authorization.GetEntitlementsRequest) (*authorization.GetEntitlementsResponse, error) { // Wrap Connect RPC client request res, err := w.AuthorizationServiceClient.GetEntitlements(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/entityresolution.go b/sdk/sdkconnect/entityresolution.go index af3132934f..71c331a477 100644 --- a/sdk/sdkconnect/entityresolution.go +++ b/sdk/sdkconnect/entityresolution.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/entityresolution" "github.com/opentdf/platform/protocol/go/entityresolution/entityresolutionconnect" - "google.golang.org/grpc" ) type EntityResolutionServiceClientConnectWrapper struct { @@ -17,7 +16,12 @@ func NewEntityResolutionServiceClientConnectWrapper(httpClient connect.HTTPClien return &EntityResolutionServiceClientConnectWrapper{EntityResolutionServiceClient: entityresolutionconnect.NewEntityResolutionServiceClient(httpClient, baseURL, opts...)} } -func (w *EntityResolutionServiceClientConnectWrapper) ResolveEntities(ctx context.Context, req *entityresolution.ResolveEntitiesRequest, _ ...grpc.CallOption) (*entityresolution.ResolveEntitiesResponse, error) { +type EntityResolutionServiceClient interface { + ResolveEntities(ctx context.Context, req *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) + CreateEntityChainFromJwt(ctx context.Context, req *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) +} + +func (w *EntityResolutionServiceClientConnectWrapper) ResolveEntities(ctx context.Context, req *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) { // Wrap Connect RPC client request res, err := w.EntityResolutionServiceClient.ResolveEntities(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +30,7 @@ func (w *EntityResolutionServiceClientConnectWrapper) ResolveEntities(ctx contex return res.Msg, err } -func (w *EntityResolutionServiceClientConnectWrapper) CreateEntityChainFromJwt(ctx context.Context, req *entityresolution.CreateEntityChainFromJwtRequest, _ ...grpc.CallOption) (*entityresolution.CreateEntityChainFromJwtResponse, error) { +func (w *EntityResolutionServiceClientConnectWrapper) CreateEntityChainFromJwt(ctx context.Context, req *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) { // Wrap Connect RPC client request res, err := w.EntityResolutionServiceClient.CreateEntityChainFromJwt(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/kasregistry.go b/sdk/sdkconnect/kasregistry.go index 8dc14f344a..addd61099a 100644 --- a/sdk/sdkconnect/kasregistry.go +++ b/sdk/sdkconnect/kasregistry.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/policy/kasregistry" "github.com/opentdf/platform/protocol/go/policy/kasregistry/kasregistryconnect" - "google.golang.org/grpc" ) type KeyAccessServerRegistryServiceClientConnectWrapper struct { @@ -17,7 +16,21 @@ func NewKeyAccessServerRegistryServiceClientConnectWrapper(httpClient connect.HT return &KeyAccessServerRegistryServiceClientConnectWrapper{KeyAccessServerRegistryServiceClient: kasregistryconnect.NewKeyAccessServerRegistryServiceClient(httpClient, baseURL, opts...)} } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeyAccessServers(ctx context.Context, req *kasregistry.ListKeyAccessServersRequest, _ ...grpc.CallOption) (*kasregistry.ListKeyAccessServersResponse, error) { +type KeyAccessServerRegistryServiceClient interface { + ListKeyAccessServers(ctx context.Context, req *kasregistry.ListKeyAccessServersRequest) (*kasregistry.ListKeyAccessServersResponse, error) + GetKeyAccessServer(ctx context.Context, req *kasregistry.GetKeyAccessServerRequest) (*kasregistry.GetKeyAccessServerResponse, error) + CreateKeyAccessServer(ctx context.Context, req *kasregistry.CreateKeyAccessServerRequest) (*kasregistry.CreateKeyAccessServerResponse, error) + UpdateKeyAccessServer(ctx context.Context, req *kasregistry.UpdateKeyAccessServerRequest) (*kasregistry.UpdateKeyAccessServerResponse, error) + DeleteKeyAccessServer(ctx context.Context, req *kasregistry.DeleteKeyAccessServerRequest) (*kasregistry.DeleteKeyAccessServerResponse, error) + ListKeyAccessServerGrants(ctx context.Context, req *kasregistry.ListKeyAccessServerGrantsRequest) (*kasregistry.ListKeyAccessServerGrantsResponse, error) + CreateKey(ctx context.Context, req *kasregistry.CreateKeyRequest) (*kasregistry.CreateKeyResponse, error) + GetKey(ctx context.Context, req *kasregistry.GetKeyRequest) (*kasregistry.GetKeyResponse, error) + ListKeys(ctx context.Context, req *kasregistry.ListKeysRequest) (*kasregistry.ListKeysResponse, error) + UpdateKey(ctx context.Context, req *kasregistry.UpdateKeyRequest) (*kasregistry.UpdateKeyResponse, error) + RotateKey(ctx context.Context, req *kasregistry.RotateKeyRequest) (*kasregistry.RotateKeyResponse, error) +} + +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeyAccessServers(ctx context.Context, req *kasregistry.ListKeyAccessServersRequest) (*kasregistry.ListKeyAccessServersResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.ListKeyAccessServers(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +39,7 @@ func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeyAccessServer return res.Msg, err } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) GetKeyAccessServer(ctx context.Context, req *kasregistry.GetKeyAccessServerRequest, _ ...grpc.CallOption) (*kasregistry.GetKeyAccessServerResponse, error) { +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) GetKeyAccessServer(ctx context.Context, req *kasregistry.GetKeyAccessServerRequest) (*kasregistry.GetKeyAccessServerResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.GetKeyAccessServer(ctx, connect.NewRequest(req)) if res == nil { @@ -35,7 +48,7 @@ func (w *KeyAccessServerRegistryServiceClientConnectWrapper) GetKeyAccessServer( return res.Msg, err } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) CreateKeyAccessServer(ctx context.Context, req *kasregistry.CreateKeyAccessServerRequest, _ ...grpc.CallOption) (*kasregistry.CreateKeyAccessServerResponse, error) { +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) CreateKeyAccessServer(ctx context.Context, req *kasregistry.CreateKeyAccessServerRequest) (*kasregistry.CreateKeyAccessServerResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.CreateKeyAccessServer(ctx, connect.NewRequest(req)) if res == nil { @@ -44,7 +57,7 @@ func (w *KeyAccessServerRegistryServiceClientConnectWrapper) CreateKeyAccessServ return res.Msg, err } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) UpdateKeyAccessServer(ctx context.Context, req *kasregistry.UpdateKeyAccessServerRequest, _ ...grpc.CallOption) (*kasregistry.UpdateKeyAccessServerResponse, error) { +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) UpdateKeyAccessServer(ctx context.Context, req *kasregistry.UpdateKeyAccessServerRequest) (*kasregistry.UpdateKeyAccessServerResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.UpdateKeyAccessServer(ctx, connect.NewRequest(req)) if res == nil { @@ -53,7 +66,7 @@ func (w *KeyAccessServerRegistryServiceClientConnectWrapper) UpdateKeyAccessServ return res.Msg, err } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) DeleteKeyAccessServer(ctx context.Context, req *kasregistry.DeleteKeyAccessServerRequest, _ ...grpc.CallOption) (*kasregistry.DeleteKeyAccessServerResponse, error) { +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) DeleteKeyAccessServer(ctx context.Context, req *kasregistry.DeleteKeyAccessServerRequest) (*kasregistry.DeleteKeyAccessServerResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.DeleteKeyAccessServer(ctx, connect.NewRequest(req)) if res == nil { @@ -62,7 +75,7 @@ func (w *KeyAccessServerRegistryServiceClientConnectWrapper) DeleteKeyAccessServ return res.Msg, err } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeyAccessServerGrants(ctx context.Context, req *kasregistry.ListKeyAccessServerGrantsRequest, _ ...grpc.CallOption) (*kasregistry.ListKeyAccessServerGrantsResponse, error) { +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeyAccessServerGrants(ctx context.Context, req *kasregistry.ListKeyAccessServerGrantsRequest) (*kasregistry.ListKeyAccessServerGrantsResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.ListKeyAccessServerGrants(ctx, connect.NewRequest(req)) if res == nil { @@ -71,7 +84,7 @@ func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeyAccessServer return res.Msg, err } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) CreateKey(ctx context.Context, req *kasregistry.CreateKeyRequest, _ ...grpc.CallOption) (*kasregistry.CreateKeyResponse, error) { +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) CreateKey(ctx context.Context, req *kasregistry.CreateKeyRequest) (*kasregistry.CreateKeyResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.CreateKey(ctx, connect.NewRequest(req)) if res == nil { @@ -80,7 +93,7 @@ func (w *KeyAccessServerRegistryServiceClientConnectWrapper) CreateKey(ctx conte return res.Msg, err } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) GetKey(ctx context.Context, req *kasregistry.GetKeyRequest, _ ...grpc.CallOption) (*kasregistry.GetKeyResponse, error) { +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) GetKey(ctx context.Context, req *kasregistry.GetKeyRequest) (*kasregistry.GetKeyResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.GetKey(ctx, connect.NewRequest(req)) if res == nil { @@ -89,7 +102,7 @@ func (w *KeyAccessServerRegistryServiceClientConnectWrapper) GetKey(ctx context. return res.Msg, err } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeys(ctx context.Context, req *kasregistry.ListKeysRequest, _ ...grpc.CallOption) (*kasregistry.ListKeysResponse, error) { +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeys(ctx context.Context, req *kasregistry.ListKeysRequest) (*kasregistry.ListKeysResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.ListKeys(ctx, connect.NewRequest(req)) if res == nil { @@ -98,7 +111,7 @@ func (w *KeyAccessServerRegistryServiceClientConnectWrapper) ListKeys(ctx contex return res.Msg, err } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) UpdateKey(ctx context.Context, req *kasregistry.UpdateKeyRequest, _ ...grpc.CallOption) (*kasregistry.UpdateKeyResponse, error) { +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) UpdateKey(ctx context.Context, req *kasregistry.UpdateKeyRequest) (*kasregistry.UpdateKeyResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.UpdateKey(ctx, connect.NewRequest(req)) if res == nil { @@ -107,7 +120,7 @@ func (w *KeyAccessServerRegistryServiceClientConnectWrapper) UpdateKey(ctx conte return res.Msg, err } -func (w *KeyAccessServerRegistryServiceClientConnectWrapper) RotateKey(ctx context.Context, req *kasregistry.RotateKeyRequest, _ ...grpc.CallOption) (*kasregistry.RotateKeyResponse, error) { +func (w *KeyAccessServerRegistryServiceClientConnectWrapper) RotateKey(ctx context.Context, req *kasregistry.RotateKeyRequest) (*kasregistry.RotateKeyResponse, error) { // Wrap Connect RPC client request res, err := w.KeyAccessServerRegistryServiceClient.RotateKey(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/keymanagement.go b/sdk/sdkconnect/keymanagement.go index cb2c2ad8a4..c563ff9144 100644 --- a/sdk/sdkconnect/keymanagement.go +++ b/sdk/sdkconnect/keymanagement.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/policy/keymanagement" "github.com/opentdf/platform/protocol/go/policy/keymanagement/keymanagementconnect" - "google.golang.org/grpc" ) type KeyManagementServiceClientConnectWrapper struct { @@ -17,7 +16,15 @@ func NewKeyManagementServiceClientConnectWrapper(httpClient connect.HTTPClient, return &KeyManagementServiceClientConnectWrapper{KeyManagementServiceClient: keymanagementconnect.NewKeyManagementServiceClient(httpClient, baseURL, opts...)} } -func (w *KeyManagementServiceClientConnectWrapper) CreateProviderConfig(ctx context.Context, req *keymanagement.CreateProviderConfigRequest, _ ...grpc.CallOption) (*keymanagement.CreateProviderConfigResponse, error) { +type KeyManagementServiceClient interface { + CreateProviderConfig(ctx context.Context, req *keymanagement.CreateProviderConfigRequest) (*keymanagement.CreateProviderConfigResponse, error) + GetProviderConfig(ctx context.Context, req *keymanagement.GetProviderConfigRequest) (*keymanagement.GetProviderConfigResponse, error) + ListProviderConfigs(ctx context.Context, req *keymanagement.ListProviderConfigsRequest) (*keymanagement.ListProviderConfigsResponse, error) + UpdateProviderConfig(ctx context.Context, req *keymanagement.UpdateProviderConfigRequest) (*keymanagement.UpdateProviderConfigResponse, error) + DeleteProviderConfig(ctx context.Context, req *keymanagement.DeleteProviderConfigRequest) (*keymanagement.DeleteProviderConfigResponse, error) +} + +func (w *KeyManagementServiceClientConnectWrapper) CreateProviderConfig(ctx context.Context, req *keymanagement.CreateProviderConfigRequest) (*keymanagement.CreateProviderConfigResponse, error) { // Wrap Connect RPC client request res, err := w.KeyManagementServiceClient.CreateProviderConfig(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +33,7 @@ func (w *KeyManagementServiceClientConnectWrapper) CreateProviderConfig(ctx cont return res.Msg, err } -func (w *KeyManagementServiceClientConnectWrapper) GetProviderConfig(ctx context.Context, req *keymanagement.GetProviderConfigRequest, _ ...grpc.CallOption) (*keymanagement.GetProviderConfigResponse, error) { +func (w *KeyManagementServiceClientConnectWrapper) GetProviderConfig(ctx context.Context, req *keymanagement.GetProviderConfigRequest) (*keymanagement.GetProviderConfigResponse, error) { // Wrap Connect RPC client request res, err := w.KeyManagementServiceClient.GetProviderConfig(ctx, connect.NewRequest(req)) if res == nil { @@ -35,7 +42,7 @@ func (w *KeyManagementServiceClientConnectWrapper) GetProviderConfig(ctx context return res.Msg, err } -func (w *KeyManagementServiceClientConnectWrapper) ListProviderConfigs(ctx context.Context, req *keymanagement.ListProviderConfigsRequest, _ ...grpc.CallOption) (*keymanagement.ListProviderConfigsResponse, error) { +func (w *KeyManagementServiceClientConnectWrapper) ListProviderConfigs(ctx context.Context, req *keymanagement.ListProviderConfigsRequest) (*keymanagement.ListProviderConfigsResponse, error) { // Wrap Connect RPC client request res, err := w.KeyManagementServiceClient.ListProviderConfigs(ctx, connect.NewRequest(req)) if res == nil { @@ -44,7 +51,7 @@ func (w *KeyManagementServiceClientConnectWrapper) ListProviderConfigs(ctx conte return res.Msg, err } -func (w *KeyManagementServiceClientConnectWrapper) UpdateProviderConfig(ctx context.Context, req *keymanagement.UpdateProviderConfigRequest, _ ...grpc.CallOption) (*keymanagement.UpdateProviderConfigResponse, error) { +func (w *KeyManagementServiceClientConnectWrapper) UpdateProviderConfig(ctx context.Context, req *keymanagement.UpdateProviderConfigRequest) (*keymanagement.UpdateProviderConfigResponse, error) { // Wrap Connect RPC client request res, err := w.KeyManagementServiceClient.UpdateProviderConfig(ctx, connect.NewRequest(req)) if res == nil { @@ -53,7 +60,7 @@ func (w *KeyManagementServiceClientConnectWrapper) UpdateProviderConfig(ctx cont return res.Msg, err } -func (w *KeyManagementServiceClientConnectWrapper) DeleteProviderConfig(ctx context.Context, req *keymanagement.DeleteProviderConfigRequest, _ ...grpc.CallOption) (*keymanagement.DeleteProviderConfigResponse, error) { +func (w *KeyManagementServiceClientConnectWrapper) DeleteProviderConfig(ctx context.Context, req *keymanagement.DeleteProviderConfigRequest) (*keymanagement.DeleteProviderConfigResponse, error) { // Wrap Connect RPC client request res, err := w.KeyManagementServiceClient.DeleteProviderConfig(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/namespaces.go b/sdk/sdkconnect/namespaces.go index 1b01938c9c..b1d02386b8 100644 --- a/sdk/sdkconnect/namespaces.go +++ b/sdk/sdkconnect/namespaces.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/policy/namespaces" "github.com/opentdf/platform/protocol/go/policy/namespaces/namespacesconnect" - "google.golang.org/grpc" ) type NamespaceServiceClientConnectWrapper struct { @@ -17,7 +16,19 @@ func NewNamespaceServiceClientConnectWrapper(httpClient connect.HTTPClient, base return &NamespaceServiceClientConnectWrapper{NamespaceServiceClient: namespacesconnect.NewNamespaceServiceClient(httpClient, baseURL, opts...)} } -func (w *NamespaceServiceClientConnectWrapper) GetNamespace(ctx context.Context, req *namespaces.GetNamespaceRequest, _ ...grpc.CallOption) (*namespaces.GetNamespaceResponse, error) { +type NamespaceServiceClient interface { + GetNamespace(ctx context.Context, req *namespaces.GetNamespaceRequest) (*namespaces.GetNamespaceResponse, error) + ListNamespaces(ctx context.Context, req *namespaces.ListNamespacesRequest) (*namespaces.ListNamespacesResponse, error) + CreateNamespace(ctx context.Context, req *namespaces.CreateNamespaceRequest) (*namespaces.CreateNamespaceResponse, error) + UpdateNamespace(ctx context.Context, req *namespaces.UpdateNamespaceRequest) (*namespaces.UpdateNamespaceResponse, error) + DeactivateNamespace(ctx context.Context, req *namespaces.DeactivateNamespaceRequest) (*namespaces.DeactivateNamespaceResponse, error) + AssignKeyAccessServerToNamespace(ctx context.Context, req *namespaces.AssignKeyAccessServerToNamespaceRequest) (*namespaces.AssignKeyAccessServerToNamespaceResponse, error) + RemoveKeyAccessServerFromNamespace(ctx context.Context, req *namespaces.RemoveKeyAccessServerFromNamespaceRequest) (*namespaces.RemoveKeyAccessServerFromNamespaceResponse, error) + AssignPublicKeyToNamespace(ctx context.Context, req *namespaces.AssignPublicKeyToNamespaceRequest) (*namespaces.AssignPublicKeyToNamespaceResponse, error) + RemovePublicKeyFromNamespace(ctx context.Context, req *namespaces.RemovePublicKeyFromNamespaceRequest) (*namespaces.RemovePublicKeyFromNamespaceResponse, error) +} + +func (w *NamespaceServiceClientConnectWrapper) GetNamespace(ctx context.Context, req *namespaces.GetNamespaceRequest) (*namespaces.GetNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.NamespaceServiceClient.GetNamespace(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +37,7 @@ func (w *NamespaceServiceClientConnectWrapper) GetNamespace(ctx context.Context, return res.Msg, err } -func (w *NamespaceServiceClientConnectWrapper) ListNamespaces(ctx context.Context, req *namespaces.ListNamespacesRequest, _ ...grpc.CallOption) (*namespaces.ListNamespacesResponse, error) { +func (w *NamespaceServiceClientConnectWrapper) ListNamespaces(ctx context.Context, req *namespaces.ListNamespacesRequest) (*namespaces.ListNamespacesResponse, error) { // Wrap Connect RPC client request res, err := w.NamespaceServiceClient.ListNamespaces(ctx, connect.NewRequest(req)) if res == nil { @@ -35,7 +46,7 @@ func (w *NamespaceServiceClientConnectWrapper) ListNamespaces(ctx context.Contex return res.Msg, err } -func (w *NamespaceServiceClientConnectWrapper) CreateNamespace(ctx context.Context, req *namespaces.CreateNamespaceRequest, _ ...grpc.CallOption) (*namespaces.CreateNamespaceResponse, error) { +func (w *NamespaceServiceClientConnectWrapper) CreateNamespace(ctx context.Context, req *namespaces.CreateNamespaceRequest) (*namespaces.CreateNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.NamespaceServiceClient.CreateNamespace(ctx, connect.NewRequest(req)) if res == nil { @@ -44,7 +55,7 @@ func (w *NamespaceServiceClientConnectWrapper) CreateNamespace(ctx context.Conte return res.Msg, err } -func (w *NamespaceServiceClientConnectWrapper) UpdateNamespace(ctx context.Context, req *namespaces.UpdateNamespaceRequest, _ ...grpc.CallOption) (*namespaces.UpdateNamespaceResponse, error) { +func (w *NamespaceServiceClientConnectWrapper) UpdateNamespace(ctx context.Context, req *namespaces.UpdateNamespaceRequest) (*namespaces.UpdateNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.NamespaceServiceClient.UpdateNamespace(ctx, connect.NewRequest(req)) if res == nil { @@ -53,7 +64,7 @@ func (w *NamespaceServiceClientConnectWrapper) UpdateNamespace(ctx context.Conte return res.Msg, err } -func (w *NamespaceServiceClientConnectWrapper) DeactivateNamespace(ctx context.Context, req *namespaces.DeactivateNamespaceRequest, _ ...grpc.CallOption) (*namespaces.DeactivateNamespaceResponse, error) { +func (w *NamespaceServiceClientConnectWrapper) DeactivateNamespace(ctx context.Context, req *namespaces.DeactivateNamespaceRequest) (*namespaces.DeactivateNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.NamespaceServiceClient.DeactivateNamespace(ctx, connect.NewRequest(req)) if res == nil { @@ -62,7 +73,7 @@ func (w *NamespaceServiceClientConnectWrapper) DeactivateNamespace(ctx context.C return res.Msg, err } -func (w *NamespaceServiceClientConnectWrapper) AssignKeyAccessServerToNamespace(ctx context.Context, req *namespaces.AssignKeyAccessServerToNamespaceRequest, _ ...grpc.CallOption) (*namespaces.AssignKeyAccessServerToNamespaceResponse, error) { +func (w *NamespaceServiceClientConnectWrapper) AssignKeyAccessServerToNamespace(ctx context.Context, req *namespaces.AssignKeyAccessServerToNamespaceRequest) (*namespaces.AssignKeyAccessServerToNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.NamespaceServiceClient.AssignKeyAccessServerToNamespace(ctx, connect.NewRequest(req)) if res == nil { @@ -71,7 +82,7 @@ func (w *NamespaceServiceClientConnectWrapper) AssignKeyAccessServerToNamespace( return res.Msg, err } -func (w *NamespaceServiceClientConnectWrapper) RemoveKeyAccessServerFromNamespace(ctx context.Context, req *namespaces.RemoveKeyAccessServerFromNamespaceRequest, _ ...grpc.CallOption) (*namespaces.RemoveKeyAccessServerFromNamespaceResponse, error) { +func (w *NamespaceServiceClientConnectWrapper) RemoveKeyAccessServerFromNamespace(ctx context.Context, req *namespaces.RemoveKeyAccessServerFromNamespaceRequest) (*namespaces.RemoveKeyAccessServerFromNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.NamespaceServiceClient.RemoveKeyAccessServerFromNamespace(ctx, connect.NewRequest(req)) if res == nil { @@ -80,7 +91,7 @@ func (w *NamespaceServiceClientConnectWrapper) RemoveKeyAccessServerFromNamespac return res.Msg, err } -func (w *NamespaceServiceClientConnectWrapper) AssignPublicKeyToNamespace(ctx context.Context, req *namespaces.AssignPublicKeyToNamespaceRequest, _ ...grpc.CallOption) (*namespaces.AssignPublicKeyToNamespaceResponse, error) { +func (w *NamespaceServiceClientConnectWrapper) AssignPublicKeyToNamespace(ctx context.Context, req *namespaces.AssignPublicKeyToNamespaceRequest) (*namespaces.AssignPublicKeyToNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.NamespaceServiceClient.AssignPublicKeyToNamespace(ctx, connect.NewRequest(req)) if res == nil { @@ -89,7 +100,7 @@ func (w *NamespaceServiceClientConnectWrapper) AssignPublicKeyToNamespace(ctx co return res.Msg, err } -func (w *NamespaceServiceClientConnectWrapper) RemovePublicKeyFromNamespace(ctx context.Context, req *namespaces.RemovePublicKeyFromNamespaceRequest, _ ...grpc.CallOption) (*namespaces.RemovePublicKeyFromNamespaceResponse, error) { +func (w *NamespaceServiceClientConnectWrapper) RemovePublicKeyFromNamespace(ctx context.Context, req *namespaces.RemovePublicKeyFromNamespaceRequest) (*namespaces.RemovePublicKeyFromNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.NamespaceServiceClient.RemovePublicKeyFromNamespace(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/registeredresources.go b/sdk/sdkconnect/registeredresources.go index 7974522ede..4f8274d568 100644 --- a/sdk/sdkconnect/registeredresources.go +++ b/sdk/sdkconnect/registeredresources.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/policy/registeredresources" "github.com/opentdf/platform/protocol/go/policy/registeredresources/registeredresourcesconnect" - "google.golang.org/grpc" ) type RegisteredResourcesServiceClientConnectWrapper struct { @@ -17,7 +16,21 @@ func NewRegisteredResourcesServiceClientConnectWrapper(httpClient connect.HTTPCl return &RegisteredResourcesServiceClientConnectWrapper{RegisteredResourcesServiceClient: registeredresourcesconnect.NewRegisteredResourcesServiceClient(httpClient, baseURL, opts...)} } -func (w *RegisteredResourcesServiceClientConnectWrapper) CreateRegisteredResource(ctx context.Context, req *registeredresources.CreateRegisteredResourceRequest, _ ...grpc.CallOption) (*registeredresources.CreateRegisteredResourceResponse, error) { +type RegisteredResourcesServiceClient interface { + CreateRegisteredResource(ctx context.Context, req *registeredresources.CreateRegisteredResourceRequest) (*registeredresources.CreateRegisteredResourceResponse, error) + GetRegisteredResource(ctx context.Context, req *registeredresources.GetRegisteredResourceRequest) (*registeredresources.GetRegisteredResourceResponse, error) + ListRegisteredResources(ctx context.Context, req *registeredresources.ListRegisteredResourcesRequest) (*registeredresources.ListRegisteredResourcesResponse, error) + UpdateRegisteredResource(ctx context.Context, req *registeredresources.UpdateRegisteredResourceRequest) (*registeredresources.UpdateRegisteredResourceResponse, error) + DeleteRegisteredResource(ctx context.Context, req *registeredresources.DeleteRegisteredResourceRequest) (*registeredresources.DeleteRegisteredResourceResponse, error) + CreateRegisteredResourceValue(ctx context.Context, req *registeredresources.CreateRegisteredResourceValueRequest) (*registeredresources.CreateRegisteredResourceValueResponse, error) + GetRegisteredResourceValue(ctx context.Context, req *registeredresources.GetRegisteredResourceValueRequest) (*registeredresources.GetRegisteredResourceValueResponse, error) + GetRegisteredResourceValuesByFQNs(ctx context.Context, req *registeredresources.GetRegisteredResourceValuesByFQNsRequest) (*registeredresources.GetRegisteredResourceValuesByFQNsResponse, error) + ListRegisteredResourceValues(ctx context.Context, req *registeredresources.ListRegisteredResourceValuesRequest) (*registeredresources.ListRegisteredResourceValuesResponse, error) + UpdateRegisteredResourceValue(ctx context.Context, req *registeredresources.UpdateRegisteredResourceValueRequest) (*registeredresources.UpdateRegisteredResourceValueResponse, error) + DeleteRegisteredResourceValue(ctx context.Context, req *registeredresources.DeleteRegisteredResourceValueRequest) (*registeredresources.DeleteRegisteredResourceValueResponse, error) +} + +func (w *RegisteredResourcesServiceClientConnectWrapper) CreateRegisteredResource(ctx context.Context, req *registeredresources.CreateRegisteredResourceRequest) (*registeredresources.CreateRegisteredResourceResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.CreateRegisteredResource(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +39,7 @@ func (w *RegisteredResourcesServiceClientConnectWrapper) CreateRegisteredResourc return res.Msg, err } -func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResource(ctx context.Context, req *registeredresources.GetRegisteredResourceRequest, _ ...grpc.CallOption) (*registeredresources.GetRegisteredResourceResponse, error) { +func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResource(ctx context.Context, req *registeredresources.GetRegisteredResourceRequest) (*registeredresources.GetRegisteredResourceResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.GetRegisteredResource(ctx, connect.NewRequest(req)) if res == nil { @@ -35,7 +48,7 @@ func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResource(c return res.Msg, err } -func (w *RegisteredResourcesServiceClientConnectWrapper) ListRegisteredResources(ctx context.Context, req *registeredresources.ListRegisteredResourcesRequest, _ ...grpc.CallOption) (*registeredresources.ListRegisteredResourcesResponse, error) { +func (w *RegisteredResourcesServiceClientConnectWrapper) ListRegisteredResources(ctx context.Context, req *registeredresources.ListRegisteredResourcesRequest) (*registeredresources.ListRegisteredResourcesResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.ListRegisteredResources(ctx, connect.NewRequest(req)) if res == nil { @@ -44,7 +57,7 @@ func (w *RegisteredResourcesServiceClientConnectWrapper) ListRegisteredResources return res.Msg, err } -func (w *RegisteredResourcesServiceClientConnectWrapper) UpdateRegisteredResource(ctx context.Context, req *registeredresources.UpdateRegisteredResourceRequest, _ ...grpc.CallOption) (*registeredresources.UpdateRegisteredResourceResponse, error) { +func (w *RegisteredResourcesServiceClientConnectWrapper) UpdateRegisteredResource(ctx context.Context, req *registeredresources.UpdateRegisteredResourceRequest) (*registeredresources.UpdateRegisteredResourceResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.UpdateRegisteredResource(ctx, connect.NewRequest(req)) if res == nil { @@ -53,7 +66,7 @@ func (w *RegisteredResourcesServiceClientConnectWrapper) UpdateRegisteredResourc return res.Msg, err } -func (w *RegisteredResourcesServiceClientConnectWrapper) DeleteRegisteredResource(ctx context.Context, req *registeredresources.DeleteRegisteredResourceRequest, _ ...grpc.CallOption) (*registeredresources.DeleteRegisteredResourceResponse, error) { +func (w *RegisteredResourcesServiceClientConnectWrapper) DeleteRegisteredResource(ctx context.Context, req *registeredresources.DeleteRegisteredResourceRequest) (*registeredresources.DeleteRegisteredResourceResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.DeleteRegisteredResource(ctx, connect.NewRequest(req)) if res == nil { @@ -62,7 +75,7 @@ func (w *RegisteredResourcesServiceClientConnectWrapper) DeleteRegisteredResourc return res.Msg, err } -func (w *RegisteredResourcesServiceClientConnectWrapper) CreateRegisteredResourceValue(ctx context.Context, req *registeredresources.CreateRegisteredResourceValueRequest, _ ...grpc.CallOption) (*registeredresources.CreateRegisteredResourceValueResponse, error) { +func (w *RegisteredResourcesServiceClientConnectWrapper) CreateRegisteredResourceValue(ctx context.Context, req *registeredresources.CreateRegisteredResourceValueRequest) (*registeredresources.CreateRegisteredResourceValueResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.CreateRegisteredResourceValue(ctx, connect.NewRequest(req)) if res == nil { @@ -71,7 +84,7 @@ func (w *RegisteredResourcesServiceClientConnectWrapper) CreateRegisteredResourc return res.Msg, err } -func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResourceValue(ctx context.Context, req *registeredresources.GetRegisteredResourceValueRequest, _ ...grpc.CallOption) (*registeredresources.GetRegisteredResourceValueResponse, error) { +func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResourceValue(ctx context.Context, req *registeredresources.GetRegisteredResourceValueRequest) (*registeredresources.GetRegisteredResourceValueResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.GetRegisteredResourceValue(ctx, connect.NewRequest(req)) if res == nil { @@ -80,7 +93,7 @@ func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResourceVa return res.Msg, err } -func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResourceValuesByFQNs(ctx context.Context, req *registeredresources.GetRegisteredResourceValuesByFQNsRequest, _ ...grpc.CallOption) (*registeredresources.GetRegisteredResourceValuesByFQNsResponse, error) { +func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResourceValuesByFQNs(ctx context.Context, req *registeredresources.GetRegisteredResourceValuesByFQNsRequest) (*registeredresources.GetRegisteredResourceValuesByFQNsResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.GetRegisteredResourceValuesByFQNs(ctx, connect.NewRequest(req)) if res == nil { @@ -89,7 +102,7 @@ func (w *RegisteredResourcesServiceClientConnectWrapper) GetRegisteredResourceVa return res.Msg, err } -func (w *RegisteredResourcesServiceClientConnectWrapper) ListRegisteredResourceValues(ctx context.Context, req *registeredresources.ListRegisteredResourceValuesRequest, _ ...grpc.CallOption) (*registeredresources.ListRegisteredResourceValuesResponse, error) { +func (w *RegisteredResourcesServiceClientConnectWrapper) ListRegisteredResourceValues(ctx context.Context, req *registeredresources.ListRegisteredResourceValuesRequest) (*registeredresources.ListRegisteredResourceValuesResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.ListRegisteredResourceValues(ctx, connect.NewRequest(req)) if res == nil { @@ -98,7 +111,7 @@ func (w *RegisteredResourcesServiceClientConnectWrapper) ListRegisteredResourceV return res.Msg, err } -func (w *RegisteredResourcesServiceClientConnectWrapper) UpdateRegisteredResourceValue(ctx context.Context, req *registeredresources.UpdateRegisteredResourceValueRequest, _ ...grpc.CallOption) (*registeredresources.UpdateRegisteredResourceValueResponse, error) { +func (w *RegisteredResourcesServiceClientConnectWrapper) UpdateRegisteredResourceValue(ctx context.Context, req *registeredresources.UpdateRegisteredResourceValueRequest) (*registeredresources.UpdateRegisteredResourceValueResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.UpdateRegisteredResourceValue(ctx, connect.NewRequest(req)) if res == nil { @@ -107,7 +120,7 @@ func (w *RegisteredResourcesServiceClientConnectWrapper) UpdateRegisteredResourc return res.Msg, err } -func (w *RegisteredResourcesServiceClientConnectWrapper) DeleteRegisteredResourceValue(ctx context.Context, req *registeredresources.DeleteRegisteredResourceValueRequest, _ ...grpc.CallOption) (*registeredresources.DeleteRegisteredResourceValueResponse, error) { +func (w *RegisteredResourcesServiceClientConnectWrapper) DeleteRegisteredResourceValue(ctx context.Context, req *registeredresources.DeleteRegisteredResourceValueRequest) (*registeredresources.DeleteRegisteredResourceValueResponse, error) { // Wrap Connect RPC client request res, err := w.RegisteredResourcesServiceClient.DeleteRegisteredResourceValue(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/resourcemapping.go b/sdk/sdkconnect/resourcemapping.go index 17b04ce236..047168bf48 100644 --- a/sdk/sdkconnect/resourcemapping.go +++ b/sdk/sdkconnect/resourcemapping.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/policy/resourcemapping" "github.com/opentdf/platform/protocol/go/policy/resourcemapping/resourcemappingconnect" - "google.golang.org/grpc" ) type ResourceMappingServiceClientConnectWrapper struct { @@ -17,7 +16,21 @@ func NewResourceMappingServiceClientConnectWrapper(httpClient connect.HTTPClient return &ResourceMappingServiceClientConnectWrapper{ResourceMappingServiceClient: resourcemappingconnect.NewResourceMappingServiceClient(httpClient, baseURL, opts...)} } -func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappingGroups(ctx context.Context, req *resourcemapping.ListResourceMappingGroupsRequest, _ ...grpc.CallOption) (*resourcemapping.ListResourceMappingGroupsResponse, error) { +type ResourceMappingServiceClient interface { + ListResourceMappingGroups(ctx context.Context, req *resourcemapping.ListResourceMappingGroupsRequest) (*resourcemapping.ListResourceMappingGroupsResponse, error) + GetResourceMappingGroup(ctx context.Context, req *resourcemapping.GetResourceMappingGroupRequest) (*resourcemapping.GetResourceMappingGroupResponse, error) + CreateResourceMappingGroup(ctx context.Context, req *resourcemapping.CreateResourceMappingGroupRequest) (*resourcemapping.CreateResourceMappingGroupResponse, error) + UpdateResourceMappingGroup(ctx context.Context, req *resourcemapping.UpdateResourceMappingGroupRequest) (*resourcemapping.UpdateResourceMappingGroupResponse, error) + DeleteResourceMappingGroup(ctx context.Context, req *resourcemapping.DeleteResourceMappingGroupRequest) (*resourcemapping.DeleteResourceMappingGroupResponse, error) + ListResourceMappings(ctx context.Context, req *resourcemapping.ListResourceMappingsRequest) (*resourcemapping.ListResourceMappingsResponse, error) + ListResourceMappingsByGroupFqns(ctx context.Context, req *resourcemapping.ListResourceMappingsByGroupFqnsRequest) (*resourcemapping.ListResourceMappingsByGroupFqnsResponse, error) + GetResourceMapping(ctx context.Context, req *resourcemapping.GetResourceMappingRequest) (*resourcemapping.GetResourceMappingResponse, error) + CreateResourceMapping(ctx context.Context, req *resourcemapping.CreateResourceMappingRequest) (*resourcemapping.CreateResourceMappingResponse, error) + UpdateResourceMapping(ctx context.Context, req *resourcemapping.UpdateResourceMappingRequest) (*resourcemapping.UpdateResourceMappingResponse, error) + DeleteResourceMapping(ctx context.Context, req *resourcemapping.DeleteResourceMappingRequest) (*resourcemapping.DeleteResourceMappingResponse, error) +} + +func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappingGroups(ctx context.Context, req *resourcemapping.ListResourceMappingGroupsRequest) (*resourcemapping.ListResourceMappingGroupsResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.ListResourceMappingGroups(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +39,7 @@ func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappingGroups(c return res.Msg, err } -func (w *ResourceMappingServiceClientConnectWrapper) GetResourceMappingGroup(ctx context.Context, req *resourcemapping.GetResourceMappingGroupRequest, _ ...grpc.CallOption) (*resourcemapping.GetResourceMappingGroupResponse, error) { +func (w *ResourceMappingServiceClientConnectWrapper) GetResourceMappingGroup(ctx context.Context, req *resourcemapping.GetResourceMappingGroupRequest) (*resourcemapping.GetResourceMappingGroupResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.GetResourceMappingGroup(ctx, connect.NewRequest(req)) if res == nil { @@ -35,7 +48,7 @@ func (w *ResourceMappingServiceClientConnectWrapper) GetResourceMappingGroup(ctx return res.Msg, err } -func (w *ResourceMappingServiceClientConnectWrapper) CreateResourceMappingGroup(ctx context.Context, req *resourcemapping.CreateResourceMappingGroupRequest, _ ...grpc.CallOption) (*resourcemapping.CreateResourceMappingGroupResponse, error) { +func (w *ResourceMappingServiceClientConnectWrapper) CreateResourceMappingGroup(ctx context.Context, req *resourcemapping.CreateResourceMappingGroupRequest) (*resourcemapping.CreateResourceMappingGroupResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.CreateResourceMappingGroup(ctx, connect.NewRequest(req)) if res == nil { @@ -44,7 +57,7 @@ func (w *ResourceMappingServiceClientConnectWrapper) CreateResourceMappingGroup( return res.Msg, err } -func (w *ResourceMappingServiceClientConnectWrapper) UpdateResourceMappingGroup(ctx context.Context, req *resourcemapping.UpdateResourceMappingGroupRequest, _ ...grpc.CallOption) (*resourcemapping.UpdateResourceMappingGroupResponse, error) { +func (w *ResourceMappingServiceClientConnectWrapper) UpdateResourceMappingGroup(ctx context.Context, req *resourcemapping.UpdateResourceMappingGroupRequest) (*resourcemapping.UpdateResourceMappingGroupResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.UpdateResourceMappingGroup(ctx, connect.NewRequest(req)) if res == nil { @@ -53,7 +66,7 @@ func (w *ResourceMappingServiceClientConnectWrapper) UpdateResourceMappingGroup( return res.Msg, err } -func (w *ResourceMappingServiceClientConnectWrapper) DeleteResourceMappingGroup(ctx context.Context, req *resourcemapping.DeleteResourceMappingGroupRequest, _ ...grpc.CallOption) (*resourcemapping.DeleteResourceMappingGroupResponse, error) { +func (w *ResourceMappingServiceClientConnectWrapper) DeleteResourceMappingGroup(ctx context.Context, req *resourcemapping.DeleteResourceMappingGroupRequest) (*resourcemapping.DeleteResourceMappingGroupResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.DeleteResourceMappingGroup(ctx, connect.NewRequest(req)) if res == nil { @@ -62,7 +75,7 @@ func (w *ResourceMappingServiceClientConnectWrapper) DeleteResourceMappingGroup( return res.Msg, err } -func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappings(ctx context.Context, req *resourcemapping.ListResourceMappingsRequest, _ ...grpc.CallOption) (*resourcemapping.ListResourceMappingsResponse, error) { +func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappings(ctx context.Context, req *resourcemapping.ListResourceMappingsRequest) (*resourcemapping.ListResourceMappingsResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.ListResourceMappings(ctx, connect.NewRequest(req)) if res == nil { @@ -71,7 +84,7 @@ func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappings(ctx co return res.Msg, err } -func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappingsByGroupFqns(ctx context.Context, req *resourcemapping.ListResourceMappingsByGroupFqnsRequest, _ ...grpc.CallOption) (*resourcemapping.ListResourceMappingsByGroupFqnsResponse, error) { +func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappingsByGroupFqns(ctx context.Context, req *resourcemapping.ListResourceMappingsByGroupFqnsRequest) (*resourcemapping.ListResourceMappingsByGroupFqnsResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.ListResourceMappingsByGroupFqns(ctx, connect.NewRequest(req)) if res == nil { @@ -80,7 +93,7 @@ func (w *ResourceMappingServiceClientConnectWrapper) ListResourceMappingsByGroup return res.Msg, err } -func (w *ResourceMappingServiceClientConnectWrapper) GetResourceMapping(ctx context.Context, req *resourcemapping.GetResourceMappingRequest, _ ...grpc.CallOption) (*resourcemapping.GetResourceMappingResponse, error) { +func (w *ResourceMappingServiceClientConnectWrapper) GetResourceMapping(ctx context.Context, req *resourcemapping.GetResourceMappingRequest) (*resourcemapping.GetResourceMappingResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.GetResourceMapping(ctx, connect.NewRequest(req)) if res == nil { @@ -89,7 +102,7 @@ func (w *ResourceMappingServiceClientConnectWrapper) GetResourceMapping(ctx cont return res.Msg, err } -func (w *ResourceMappingServiceClientConnectWrapper) CreateResourceMapping(ctx context.Context, req *resourcemapping.CreateResourceMappingRequest, _ ...grpc.CallOption) (*resourcemapping.CreateResourceMappingResponse, error) { +func (w *ResourceMappingServiceClientConnectWrapper) CreateResourceMapping(ctx context.Context, req *resourcemapping.CreateResourceMappingRequest) (*resourcemapping.CreateResourceMappingResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.CreateResourceMapping(ctx, connect.NewRequest(req)) if res == nil { @@ -98,7 +111,7 @@ func (w *ResourceMappingServiceClientConnectWrapper) CreateResourceMapping(ctx c return res.Msg, err } -func (w *ResourceMappingServiceClientConnectWrapper) UpdateResourceMapping(ctx context.Context, req *resourcemapping.UpdateResourceMappingRequest, _ ...grpc.CallOption) (*resourcemapping.UpdateResourceMappingResponse, error) { +func (w *ResourceMappingServiceClientConnectWrapper) UpdateResourceMapping(ctx context.Context, req *resourcemapping.UpdateResourceMappingRequest) (*resourcemapping.UpdateResourceMappingResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.UpdateResourceMapping(ctx, connect.NewRequest(req)) if res == nil { @@ -107,7 +120,7 @@ func (w *ResourceMappingServiceClientConnectWrapper) UpdateResourceMapping(ctx c return res.Msg, err } -func (w *ResourceMappingServiceClientConnectWrapper) DeleteResourceMapping(ctx context.Context, req *resourcemapping.DeleteResourceMappingRequest, _ ...grpc.CallOption) (*resourcemapping.DeleteResourceMappingResponse, error) { +func (w *ResourceMappingServiceClientConnectWrapper) DeleteResourceMapping(ctx context.Context, req *resourcemapping.DeleteResourceMappingRequest) (*resourcemapping.DeleteResourceMappingResponse, error) { // Wrap Connect RPC client request res, err := w.ResourceMappingServiceClient.DeleteResourceMapping(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/subjectmapping.go b/sdk/sdkconnect/subjectmapping.go index 9345fbf442..90640a1d72 100644 --- a/sdk/sdkconnect/subjectmapping.go +++ b/sdk/sdkconnect/subjectmapping.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/policy/subjectmapping" "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" - "google.golang.org/grpc" ) type SubjectMappingServiceClientConnectWrapper struct { @@ -17,7 +16,22 @@ func NewSubjectMappingServiceClientConnectWrapper(httpClient connect.HTTPClient, return &SubjectMappingServiceClientConnectWrapper{SubjectMappingServiceClient: subjectmappingconnect.NewSubjectMappingServiceClient(httpClient, baseURL, opts...)} } -func (w *SubjectMappingServiceClientConnectWrapper) MatchSubjectMappings(ctx context.Context, req *subjectmapping.MatchSubjectMappingsRequest, _ ...grpc.CallOption) (*subjectmapping.MatchSubjectMappingsResponse, error) { +type SubjectMappingServiceClient interface { + MatchSubjectMappings(ctx context.Context, req *subjectmapping.MatchSubjectMappingsRequest) (*subjectmapping.MatchSubjectMappingsResponse, error) + ListSubjectMappings(ctx context.Context, req *subjectmapping.ListSubjectMappingsRequest) (*subjectmapping.ListSubjectMappingsResponse, error) + GetSubjectMapping(ctx context.Context, req *subjectmapping.GetSubjectMappingRequest) (*subjectmapping.GetSubjectMappingResponse, error) + CreateSubjectMapping(ctx context.Context, req *subjectmapping.CreateSubjectMappingRequest) (*subjectmapping.CreateSubjectMappingResponse, error) + UpdateSubjectMapping(ctx context.Context, req *subjectmapping.UpdateSubjectMappingRequest) (*subjectmapping.UpdateSubjectMappingResponse, error) + DeleteSubjectMapping(ctx context.Context, req *subjectmapping.DeleteSubjectMappingRequest) (*subjectmapping.DeleteSubjectMappingResponse, error) + ListSubjectConditionSets(ctx context.Context, req *subjectmapping.ListSubjectConditionSetsRequest) (*subjectmapping.ListSubjectConditionSetsResponse, error) + GetSubjectConditionSet(ctx context.Context, req *subjectmapping.GetSubjectConditionSetRequest) (*subjectmapping.GetSubjectConditionSetResponse, error) + CreateSubjectConditionSet(ctx context.Context, req *subjectmapping.CreateSubjectConditionSetRequest) (*subjectmapping.CreateSubjectConditionSetResponse, error) + UpdateSubjectConditionSet(ctx context.Context, req *subjectmapping.UpdateSubjectConditionSetRequest) (*subjectmapping.UpdateSubjectConditionSetResponse, error) + DeleteSubjectConditionSet(ctx context.Context, req *subjectmapping.DeleteSubjectConditionSetRequest) (*subjectmapping.DeleteSubjectConditionSetResponse, error) + DeleteAllUnmappedSubjectConditionSets(ctx context.Context, req *subjectmapping.DeleteAllUnmappedSubjectConditionSetsRequest) (*subjectmapping.DeleteAllUnmappedSubjectConditionSetsResponse, error) +} + +func (w *SubjectMappingServiceClientConnectWrapper) MatchSubjectMappings(ctx context.Context, req *subjectmapping.MatchSubjectMappingsRequest) (*subjectmapping.MatchSubjectMappingsResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.MatchSubjectMappings(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +40,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) MatchSubjectMappings(ctx con return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) ListSubjectMappings(ctx context.Context, req *subjectmapping.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*subjectmapping.ListSubjectMappingsResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) ListSubjectMappings(ctx context.Context, req *subjectmapping.ListSubjectMappingsRequest) (*subjectmapping.ListSubjectMappingsResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.ListSubjectMappings(ctx, connect.NewRequest(req)) if res == nil { @@ -35,7 +49,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) ListSubjectMappings(ctx cont return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) GetSubjectMapping(ctx context.Context, req *subjectmapping.GetSubjectMappingRequest, _ ...grpc.CallOption) (*subjectmapping.GetSubjectMappingResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) GetSubjectMapping(ctx context.Context, req *subjectmapping.GetSubjectMappingRequest) (*subjectmapping.GetSubjectMappingResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.GetSubjectMapping(ctx, connect.NewRequest(req)) if res == nil { @@ -44,7 +58,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) GetSubjectMapping(ctx contex return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) CreateSubjectMapping(ctx context.Context, req *subjectmapping.CreateSubjectMappingRequest, _ ...grpc.CallOption) (*subjectmapping.CreateSubjectMappingResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) CreateSubjectMapping(ctx context.Context, req *subjectmapping.CreateSubjectMappingRequest) (*subjectmapping.CreateSubjectMappingResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.CreateSubjectMapping(ctx, connect.NewRequest(req)) if res == nil { @@ -53,7 +67,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) CreateSubjectMapping(ctx con return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) UpdateSubjectMapping(ctx context.Context, req *subjectmapping.UpdateSubjectMappingRequest, _ ...grpc.CallOption) (*subjectmapping.UpdateSubjectMappingResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) UpdateSubjectMapping(ctx context.Context, req *subjectmapping.UpdateSubjectMappingRequest) (*subjectmapping.UpdateSubjectMappingResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.UpdateSubjectMapping(ctx, connect.NewRequest(req)) if res == nil { @@ -62,7 +76,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) UpdateSubjectMapping(ctx con return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) DeleteSubjectMapping(ctx context.Context, req *subjectmapping.DeleteSubjectMappingRequest, _ ...grpc.CallOption) (*subjectmapping.DeleteSubjectMappingResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) DeleteSubjectMapping(ctx context.Context, req *subjectmapping.DeleteSubjectMappingRequest) (*subjectmapping.DeleteSubjectMappingResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.DeleteSubjectMapping(ctx, connect.NewRequest(req)) if res == nil { @@ -71,7 +85,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) DeleteSubjectMapping(ctx con return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) ListSubjectConditionSets(ctx context.Context, req *subjectmapping.ListSubjectConditionSetsRequest, _ ...grpc.CallOption) (*subjectmapping.ListSubjectConditionSetsResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) ListSubjectConditionSets(ctx context.Context, req *subjectmapping.ListSubjectConditionSetsRequest) (*subjectmapping.ListSubjectConditionSetsResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.ListSubjectConditionSets(ctx, connect.NewRequest(req)) if res == nil { @@ -80,7 +94,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) ListSubjectConditionSets(ctx return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) GetSubjectConditionSet(ctx context.Context, req *subjectmapping.GetSubjectConditionSetRequest, _ ...grpc.CallOption) (*subjectmapping.GetSubjectConditionSetResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) GetSubjectConditionSet(ctx context.Context, req *subjectmapping.GetSubjectConditionSetRequest) (*subjectmapping.GetSubjectConditionSetResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.GetSubjectConditionSet(ctx, connect.NewRequest(req)) if res == nil { @@ -89,7 +103,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) GetSubjectConditionSet(ctx c return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) CreateSubjectConditionSet(ctx context.Context, req *subjectmapping.CreateSubjectConditionSetRequest, _ ...grpc.CallOption) (*subjectmapping.CreateSubjectConditionSetResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) CreateSubjectConditionSet(ctx context.Context, req *subjectmapping.CreateSubjectConditionSetRequest) (*subjectmapping.CreateSubjectConditionSetResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.CreateSubjectConditionSet(ctx, connect.NewRequest(req)) if res == nil { @@ -98,7 +112,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) CreateSubjectConditionSet(ct return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) UpdateSubjectConditionSet(ctx context.Context, req *subjectmapping.UpdateSubjectConditionSetRequest, _ ...grpc.CallOption) (*subjectmapping.UpdateSubjectConditionSetResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) UpdateSubjectConditionSet(ctx context.Context, req *subjectmapping.UpdateSubjectConditionSetRequest) (*subjectmapping.UpdateSubjectConditionSetResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.UpdateSubjectConditionSet(ctx, connect.NewRequest(req)) if res == nil { @@ -107,7 +121,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) UpdateSubjectConditionSet(ct return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) DeleteSubjectConditionSet(ctx context.Context, req *subjectmapping.DeleteSubjectConditionSetRequest, _ ...grpc.CallOption) (*subjectmapping.DeleteSubjectConditionSetResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) DeleteSubjectConditionSet(ctx context.Context, req *subjectmapping.DeleteSubjectConditionSetRequest) (*subjectmapping.DeleteSubjectConditionSetResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.DeleteSubjectConditionSet(ctx, connect.NewRequest(req)) if res == nil { @@ -116,7 +130,7 @@ func (w *SubjectMappingServiceClientConnectWrapper) DeleteSubjectConditionSet(ct return res.Msg, err } -func (w *SubjectMappingServiceClientConnectWrapper) DeleteAllUnmappedSubjectConditionSets(ctx context.Context, req *subjectmapping.DeleteAllUnmappedSubjectConditionSetsRequest, _ ...grpc.CallOption) (*subjectmapping.DeleteAllUnmappedSubjectConditionSetsResponse, error) { +func (w *SubjectMappingServiceClientConnectWrapper) DeleteAllUnmappedSubjectConditionSets(ctx context.Context, req *subjectmapping.DeleteAllUnmappedSubjectConditionSetsRequest) (*subjectmapping.DeleteAllUnmappedSubjectConditionSetsResponse, error) { // Wrap Connect RPC client request res, err := w.SubjectMappingServiceClient.DeleteAllUnmappedSubjectConditionSets(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/unsafe.go b/sdk/sdkconnect/unsafe.go index 792cb80836..b01b4344d5 100644 --- a/sdk/sdkconnect/unsafe.go +++ b/sdk/sdkconnect/unsafe.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/policy/unsafe" "github.com/opentdf/platform/protocol/go/policy/unsafe/unsafeconnect" - "google.golang.org/grpc" ) type UnsafeServiceClientConnectWrapper struct { @@ -17,7 +16,20 @@ func NewUnsafeServiceClientConnectWrapper(httpClient connect.HTTPClient, baseURL return &UnsafeServiceClientConnectWrapper{UnsafeServiceClient: unsafeconnect.NewUnsafeServiceClient(httpClient, baseURL, opts...)} } -func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateNamespace(ctx context.Context, req *unsafe.UnsafeUpdateNamespaceRequest, _ ...grpc.CallOption) (*unsafe.UnsafeUpdateNamespaceResponse, error) { +type UnsafeServiceClient interface { + UnsafeUpdateNamespace(ctx context.Context, req *unsafe.UnsafeUpdateNamespaceRequest) (*unsafe.UnsafeUpdateNamespaceResponse, error) + UnsafeReactivateNamespace(ctx context.Context, req *unsafe.UnsafeReactivateNamespaceRequest) (*unsafe.UnsafeReactivateNamespaceResponse, error) + UnsafeDeleteNamespace(ctx context.Context, req *unsafe.UnsafeDeleteNamespaceRequest) (*unsafe.UnsafeDeleteNamespaceResponse, error) + UnsafeUpdateAttribute(ctx context.Context, req *unsafe.UnsafeUpdateAttributeRequest) (*unsafe.UnsafeUpdateAttributeResponse, error) + UnsafeReactivateAttribute(ctx context.Context, req *unsafe.UnsafeReactivateAttributeRequest) (*unsafe.UnsafeReactivateAttributeResponse, error) + UnsafeDeleteAttribute(ctx context.Context, req *unsafe.UnsafeDeleteAttributeRequest) (*unsafe.UnsafeDeleteAttributeResponse, error) + UnsafeUpdateAttributeValue(ctx context.Context, req *unsafe.UnsafeUpdateAttributeValueRequest) (*unsafe.UnsafeUpdateAttributeValueResponse, error) + UnsafeReactivateAttributeValue(ctx context.Context, req *unsafe.UnsafeReactivateAttributeValueRequest) (*unsafe.UnsafeReactivateAttributeValueResponse, error) + UnsafeDeleteAttributeValue(ctx context.Context, req *unsafe.UnsafeDeleteAttributeValueRequest) (*unsafe.UnsafeDeleteAttributeValueResponse, error) + UnsafeDeleteKasKey(ctx context.Context, req *unsafe.UnsafeDeleteKasKeyRequest) (*unsafe.UnsafeDeleteKasKeyResponse, error) +} + +func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateNamespace(ctx context.Context, req *unsafe.UnsafeUpdateNamespaceRequest) (*unsafe.UnsafeUpdateNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.UnsafeServiceClient.UnsafeUpdateNamespace(ctx, connect.NewRequest(req)) if res == nil { @@ -26,7 +38,7 @@ func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateNamespace(ctx context.Co return res.Msg, err } -func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateNamespace(ctx context.Context, req *unsafe.UnsafeReactivateNamespaceRequest, _ ...grpc.CallOption) (*unsafe.UnsafeReactivateNamespaceResponse, error) { +func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateNamespace(ctx context.Context, req *unsafe.UnsafeReactivateNamespaceRequest) (*unsafe.UnsafeReactivateNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.UnsafeServiceClient.UnsafeReactivateNamespace(ctx, connect.NewRequest(req)) if res == nil { @@ -35,7 +47,7 @@ func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateNamespace(ctx contex return res.Msg, err } -func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteNamespace(ctx context.Context, req *unsafe.UnsafeDeleteNamespaceRequest, _ ...grpc.CallOption) (*unsafe.UnsafeDeleteNamespaceResponse, error) { +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteNamespace(ctx context.Context, req *unsafe.UnsafeDeleteNamespaceRequest) (*unsafe.UnsafeDeleteNamespaceResponse, error) { // Wrap Connect RPC client request res, err := w.UnsafeServiceClient.UnsafeDeleteNamespace(ctx, connect.NewRequest(req)) if res == nil { @@ -44,7 +56,7 @@ func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteNamespace(ctx context.Co return res.Msg, err } -func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateAttribute(ctx context.Context, req *unsafe.UnsafeUpdateAttributeRequest, _ ...grpc.CallOption) (*unsafe.UnsafeUpdateAttributeResponse, error) { +func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateAttribute(ctx context.Context, req *unsafe.UnsafeUpdateAttributeRequest) (*unsafe.UnsafeUpdateAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.UnsafeServiceClient.UnsafeUpdateAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -53,7 +65,7 @@ func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateAttribute(ctx context.Co return res.Msg, err } -func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateAttribute(ctx context.Context, req *unsafe.UnsafeReactivateAttributeRequest, _ ...grpc.CallOption) (*unsafe.UnsafeReactivateAttributeResponse, error) { +func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateAttribute(ctx context.Context, req *unsafe.UnsafeReactivateAttributeRequest) (*unsafe.UnsafeReactivateAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.UnsafeServiceClient.UnsafeReactivateAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -62,7 +74,7 @@ func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateAttribute(ctx contex return res.Msg, err } -func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteAttribute(ctx context.Context, req *unsafe.UnsafeDeleteAttributeRequest, _ ...grpc.CallOption) (*unsafe.UnsafeDeleteAttributeResponse, error) { +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteAttribute(ctx context.Context, req *unsafe.UnsafeDeleteAttributeRequest) (*unsafe.UnsafeDeleteAttributeResponse, error) { // Wrap Connect RPC client request res, err := w.UnsafeServiceClient.UnsafeDeleteAttribute(ctx, connect.NewRequest(req)) if res == nil { @@ -71,7 +83,7 @@ func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteAttribute(ctx context.Co return res.Msg, err } -func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateAttributeValue(ctx context.Context, req *unsafe.UnsafeUpdateAttributeValueRequest, _ ...grpc.CallOption) (*unsafe.UnsafeUpdateAttributeValueResponse, error) { +func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateAttributeValue(ctx context.Context, req *unsafe.UnsafeUpdateAttributeValueRequest) (*unsafe.UnsafeUpdateAttributeValueResponse, error) { // Wrap Connect RPC client request res, err := w.UnsafeServiceClient.UnsafeUpdateAttributeValue(ctx, connect.NewRequest(req)) if res == nil { @@ -80,7 +92,7 @@ func (w *UnsafeServiceClientConnectWrapper) UnsafeUpdateAttributeValue(ctx conte return res.Msg, err } -func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateAttributeValue(ctx context.Context, req *unsafe.UnsafeReactivateAttributeValueRequest, _ ...grpc.CallOption) (*unsafe.UnsafeReactivateAttributeValueResponse, error) { +func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateAttributeValue(ctx context.Context, req *unsafe.UnsafeReactivateAttributeValueRequest) (*unsafe.UnsafeReactivateAttributeValueResponse, error) { // Wrap Connect RPC client request res, err := w.UnsafeServiceClient.UnsafeReactivateAttributeValue(ctx, connect.NewRequest(req)) if res == nil { @@ -89,7 +101,7 @@ func (w *UnsafeServiceClientConnectWrapper) UnsafeReactivateAttributeValue(ctx c return res.Msg, err } -func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteAttributeValue(ctx context.Context, req *unsafe.UnsafeDeleteAttributeValueRequest, _ ...grpc.CallOption) (*unsafe.UnsafeDeleteAttributeValueResponse, error) { +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteAttributeValue(ctx context.Context, req *unsafe.UnsafeDeleteAttributeValueRequest) (*unsafe.UnsafeDeleteAttributeValueResponse, error) { // Wrap Connect RPC client request res, err := w.UnsafeServiceClient.UnsafeDeleteAttributeValue(ctx, connect.NewRequest(req)) if res == nil { @@ -98,7 +110,7 @@ func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteAttributeValue(ctx conte return res.Msg, err } -func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteKasKey(ctx context.Context, req *unsafe.UnsafeDeleteKasKeyRequest, _ ...grpc.CallOption) (*unsafe.UnsafeDeleteKasKeyResponse, error) { +func (w *UnsafeServiceClientConnectWrapper) UnsafeDeleteKasKey(ctx context.Context, req *unsafe.UnsafeDeleteKasKeyRequest) (*unsafe.UnsafeDeleteKasKeyResponse, error) { // Wrap Connect RPC client request res, err := w.UnsafeServiceClient.UnsafeDeleteKasKey(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/sdkconnect/wellknownconfiguration.go b/sdk/sdkconnect/wellknownconfiguration.go index 3b7f4822a8..e635d1e6e0 100644 --- a/sdk/sdkconnect/wellknownconfiguration.go +++ b/sdk/sdkconnect/wellknownconfiguration.go @@ -6,7 +6,6 @@ import ( "context" "github.com/opentdf/platform/protocol/go/wellknownconfiguration" "github.com/opentdf/platform/protocol/go/wellknownconfiguration/wellknownconfigurationconnect" - "google.golang.org/grpc" ) type WellKnownServiceClientConnectWrapper struct { @@ -17,7 +16,11 @@ func NewWellKnownServiceClientConnectWrapper(httpClient connect.HTTPClient, base return &WellKnownServiceClientConnectWrapper{WellKnownServiceClient: wellknownconfigurationconnect.NewWellKnownServiceClient(httpClient, baseURL, opts...)} } -func (w *WellKnownServiceClientConnectWrapper) GetWellKnownConfiguration(ctx context.Context, req *wellknownconfiguration.GetWellKnownConfigurationRequest, _ ...grpc.CallOption) (*wellknownconfiguration.GetWellKnownConfigurationResponse, error) { +type WellKnownServiceClient interface { + GetWellKnownConfiguration(ctx context.Context, req *wellknownconfiguration.GetWellKnownConfigurationRequest) (*wellknownconfiguration.GetWellKnownConfigurationResponse, error) +} + +func (w *WellKnownServiceClientConnectWrapper) GetWellKnownConfiguration(ctx context.Context, req *wellknownconfiguration.GetWellKnownConfigurationRequest) (*wellknownconfiguration.GetWellKnownConfigurationResponse, error) { // Wrap Connect RPC client request res, err := w.WellKnownServiceClient.GetWellKnownConfiguration(ctx, connect.NewRequest(req)) if res == nil { diff --git a/sdk/tdf.go b/sdk/tdf.go index a718ba7272..890ca551a1 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -24,6 +24,7 @@ import ( "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/sdk/internal/archive" + "github.com/opentdf/platform/sdk/sdkconnect" "google.golang.org/grpc/codes" ) @@ -656,7 +657,7 @@ func createPolicyObject(attributes []AttributeValueFQN) (PolicyObject, error) { return policyObj, nil } -func allowListFromKASRegistry(ctx context.Context, kasRegistryClient kasregistry.KeyAccessServerRegistryServiceClient, platformURL string) (AllowList, error) { +func allowListFromKASRegistry(ctx context.Context, kasRegistryClient sdkconnect.KeyAccessServerRegistryServiceClient, platformURL string) (AllowList, error) { kases, err := kasRegistryClient.ListKeyAccessServers(ctx, &kasregistry.ListKeyAccessServersRequest{}) if err != nil { return nil, fmt.Errorf("kasregistry.ListKeyAccessServers failed: %w", err) diff --git a/service/authorization/authorization_test.go b/service/authorization/authorization_test.go index 26e8152cb8..14d0e18849 100644 --- a/service/authorization/authorization_test.go +++ b/service/authorization/authorization_test.go @@ -18,11 +18,11 @@ import ( attr "github.com/opentdf/platform/protocol/go/policy/attributes" sm "github.com/opentdf/platform/protocol/go/policy/subjectmapping" otdf "github.com/opentdf/platform/sdk" + "github.com/opentdf/platform/sdk/sdkconnect" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/pkg/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/structpb" @@ -47,38 +47,38 @@ var ( ) type myAttributesClient struct { - attr.AttributesServiceClient + sdkconnect.AttributesServiceClient } -func (*myAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest, _ ...grpc.CallOption) (*attr.ListAttributesResponse, error) { +func (*myAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest) (*attr.ListAttributesResponse, error) { return &listAttributeResp, errListAttributes } -func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.GetAttributeValuesByFqnsRequest, _ ...grpc.CallOption) (*attr.GetAttributeValuesByFqnsResponse, error) { +func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.GetAttributeValuesByFqnsRequest) (*attr.GetAttributeValuesByFqnsResponse, error) { return &getAttributesByValueFqnsResponse, errGetAttributesByValueFqns } type myERSClient struct { - entityresolution.EntityResolutionServiceClient + sdkconnect.EntityResolutionServiceClient } type mySubjectMappingClient struct { - sm.SubjectMappingServiceClient + sdkconnect.SubjectMappingServiceClient } type paginatedMockSubjectMappingClient struct { - sm.SubjectMappingServiceClient + sdkconnect.SubjectMappingServiceClient } -func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*sm.ListSubjectMappingsResponse, error) { +func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest) (*sm.ListSubjectMappingsResponse, error) { return &listSubjectMappings, nil } -func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *entityresolution.CreateEntityChainFromJwtRequest, _ ...grpc.CallOption) (*entityresolution.CreateEntityChainFromJwtResponse, error) { +func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) { return &createEntityChainResp, nil } -func (*myERSClient) ResolveEntities(_ context.Context, _ *entityresolution.ResolveEntitiesRequest, _ ...grpc.CallOption) (*entityresolution.ResolveEntitiesResponse, error) { +func (*myERSClient) ResolveEntities(_ context.Context, _ *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) { return &resolveEntitiesResp, nil } @@ -87,7 +87,7 @@ var ( smListCallCount = 0 ) -func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*sm.ListSubjectMappingsResponse, error) { +func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest) (*sm.ListSubjectMappingsResponse, error) { smListCallCount++ // simulate paginated list and policy LIST behavior if smPaginationOffset > 0 { @@ -104,7 +104,7 @@ func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, } type paginatedMockAttributesClient struct { - attr.AttributesServiceClient + sdkconnect.AttributesServiceClient } var ( @@ -112,7 +112,7 @@ var ( attrListCallCount = 0 ) -func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest, _ ...grpc.CallOption) (*attr.ListAttributesResponse, error) { +func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest) (*attr.ListAttributesResponse, error) { attrListCallCount++ // simulate paginated list and policy LIST behavior if attrPaginationOffset > 0 { From caa697cd84b614e8b2218fe5d218db182e812a60 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 16 May 2025 12:35:27 -0400 Subject: [PATCH 17/31] generate ersv2 code --- sdk/internal/codegen/runner/generate.go | 41 +++++++++++++++++-------- sdk/sdkconnect/entityresolutionv2.go | 40 ++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 13 deletions(-) create mode 100644 sdk/sdkconnect/entityresolutionv2.go diff --git a/sdk/internal/codegen/runner/generate.go b/sdk/internal/codegen/runner/generate.go index 5feae04e4d..803ebf5e7e 100644 --- a/sdk/internal/codegen/runner/generate.go +++ b/sdk/internal/codegen/runner/generate.go @@ -15,6 +15,8 @@ import ( type clientsToGenerate struct { grpcClientInterface string + suffix string + packageNameOverride string grpcPackagePath string } @@ -35,6 +37,12 @@ var clientsToGenerateList = []clientsToGenerate{ grpcClientInterface: "EntityResolutionServiceClient", grpcPackagePath: "github.com/opentdf/platform/protocol/go/entityresolution", }, + { + grpcClientInterface: "EntityResolutionServiceClient", + suffix: "V2", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/entityresolution/v2", + packageNameOverride: "entityresolutionv2", + }, { grpcClientInterface: "KeyAccessServerRegistryServiceClient", grpcPackagePath: "github.com/opentdf/platform/protocol/go/policy/kasregistry", @@ -106,7 +114,10 @@ func Generate() error { } if ts.Name.Name == client.grpcClientInterface { packageName := path.Base(client.grpcPackagePath) - code := generateWrapper(ts.Name.Name, iface, client.grpcPackagePath, packageName) + if client.packageNameOverride != "" { + packageName = client.packageNameOverride + } + code := generateWrapper(ts.Name.Name, iface, client.grpcPackagePath, packageName, client.suffix) var currentDir string currentDir, err = getCurrentFileDir() outputPath := filepath.Join(currentDir, "..", "..", "..", "sdkconnect", packageName+".go") @@ -154,7 +165,7 @@ func getMethodNames(interfaceType *ast.InterfaceType) []string { } // Generate wrapper code for the Connect RPC client interface -func generateWrapper(interfaceName string, interfaceType *ast.InterfaceType, packagePath string, packageName string) string { +func generateWrapper(interfaceName string, interfaceType *ast.InterfaceType, packagePath string, packageName string, suffix string) string { // Get method names dynamically from the interface methods := getMethodNames(interfaceType) connectPackageName := packageName + "connect" @@ -170,43 +181,47 @@ import ( "%s" ) -type %sConnectWrapper struct { +type %s%sConnectWrapper struct { %s.%s } -func New%sConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *%sConnectWrapper { - return &%sConnectWrapper{%s: %s.New%s(httpClient, baseURL, opts...)} +func New%s%sConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *%s%sConnectWrapper { + return &%s%sConnectWrapper{%s: %s.New%s(httpClient, baseURL, opts...)} } `, interfaceName, packagePath, packagePath+"/"+connectPackageName, interfaceName, + suffix, connectPackageName, interfaceName, interfaceName, + suffix, interfaceName, + suffix, interfaceName, + suffix, interfaceName, connectPackageName, interfaceName) // Generate the interface type definition - wrapperCode += generateInterfaceType(interfaceName, methods, packageName) + wrapperCode += generateInterfaceType(interfaceName, methods, packageName, suffix) // Now generate a wrapper function for each method in the interface for _, method := range methods { - wrapperCode += generateWrapperMethod(interfaceName, method, packageName) + wrapperCode += generateWrapperMethod(interfaceName, method, packageName, suffix) } // Output the generated wrapper code return wrapperCode } -func generateInterfaceType(interfaceName string, methods []string, packageName string) string { +func generateInterfaceType(interfaceName string, methods []string, packageName string, suffix string) string { // Generate the interface type definition interfaceType := fmt.Sprintf(` -type %s interface { -`, interfaceName) +type %s%s interface { +`, interfaceName, suffix) for _, method := range methods { interfaceType += fmt.Sprintf(` %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error) `, method, packageName, method, packageName, method) @@ -216,9 +231,9 @@ type %s interface { } // Generate the wrapper method for a specific method in the interface -func generateWrapperMethod(interfaceName, methodName, packageName string) string { +func generateWrapperMethod(interfaceName, methodName, packageName string, suffix string) string { return fmt.Sprintf(` -func (w *%sConnectWrapper) %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error) { +func (w *%s%sConnectWrapper) %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error) { // Wrap Connect RPC client request res, err := w.%s.%s(ctx, connect.NewRequest(req)) if res == nil { @@ -226,5 +241,5 @@ func (w *%sConnectWrapper) %s(ctx context.Context, req *%s.%sRequest) (*%s.%sRes } return res.Msg, err } -`, interfaceName, methodName, packageName, methodName, packageName, methodName, interfaceName, methodName) +`, interfaceName, suffix, methodName, packageName, methodName, packageName, methodName, interfaceName, methodName) } diff --git a/sdk/sdkconnect/entityresolutionv2.go b/sdk/sdkconnect/entityresolutionv2.go new file mode 100644 index 0000000000..b0173972bc --- /dev/null +++ b/sdk/sdkconnect/entityresolutionv2.go @@ -0,0 +1,40 @@ +// Wrapper for EntityResolutionServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/entityresolution/v2" + "github.com/opentdf/platform/protocol/go/entityresolution/v2/entityresolutionv2connect" +) + +type EntityResolutionServiceClientV2ConnectWrapper struct { + entityresolutionv2connect.EntityResolutionServiceClient +} + +func NewEntityResolutionServiceClientV2ConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *EntityResolutionServiceClientV2ConnectWrapper { + return &EntityResolutionServiceClientV2ConnectWrapper{EntityResolutionServiceClient: entityresolutionv2connect.NewEntityResolutionServiceClient(httpClient, baseURL, opts...)} +} + +type EntityResolutionServiceClientV2 interface { + ResolveEntities(ctx context.Context, req *entityresolutionv2.ResolveEntitiesRequest) (*entityresolutionv2.ResolveEntitiesResponse, error) + CreateEntityChainsFromTokens(ctx context.Context, req *entityresolutionv2.CreateEntityChainsFromTokensRequest) (*entityresolutionv2.CreateEntityChainsFromTokensResponse, error) +} + +func (w *EntityResolutionServiceClientV2ConnectWrapper) ResolveEntities(ctx context.Context, req *entityresolutionv2.ResolveEntitiesRequest) (*entityresolutionv2.ResolveEntitiesResponse, error) { + // Wrap Connect RPC client request + res, err := w.EntityResolutionServiceClient.ResolveEntities(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *EntityResolutionServiceClientV2ConnectWrapper) CreateEntityChainsFromTokens(ctx context.Context, req *entityresolutionv2.CreateEntityChainsFromTokensRequest) (*entityresolutionv2.CreateEntityChainsFromTokensResponse, error) { + // Wrap Connect RPC client request + res, err := w.EntityResolutionServiceClient.CreateEntityChainsFromTokens(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} From 50faf8dc2079618addcfc7191bdd86ac0a4fe76f Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 16 May 2025 13:36:49 -0400 Subject: [PATCH 18/31] move authorization test clients to separate file, implement interface --- service/authorization/authorization_test.go | 104 ------ .../authorization_test_structures.go | 308 ++++++++++++++++++ 2 files changed, 308 insertions(+), 104 deletions(-) create mode 100644 service/authorization/authorization_test_structures.go diff --git a/service/authorization/authorization_test.go b/service/authorization/authorization_test.go index 14d0e18849..082448b0ae 100644 --- a/service/authorization/authorization_test.go +++ b/service/authorization/authorization_test.go @@ -1,9 +1,7 @@ package authorization import ( - "context" "errors" - "fmt" "log/slog" "strings" "testing" @@ -16,9 +14,7 @@ import ( "github.com/opentdf/platform/protocol/go/entityresolution" "github.com/opentdf/platform/protocol/go/policy" attr "github.com/opentdf/platform/protocol/go/policy/attributes" - sm "github.com/opentdf/platform/protocol/go/policy/subjectmapping" otdf "github.com/opentdf/platform/sdk" - "github.com/opentdf/platform/sdk/sdkconnect" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/pkg/db" "github.com/stretchr/testify/assert" @@ -28,106 +24,6 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) -var ( - getAttributesByValueFqnsResponse attr.GetAttributeValuesByFqnsResponse - errGetAttributesByValueFqns error - listAttributeResp attr.ListAttributesResponse - errListAttributes error - listSubjectMappings sm.ListSubjectMappingsResponse - createEntityChainResp entityresolution.CreateEntityChainFromJwtResponse - resolveEntitiesResp entityresolution.ResolveEntitiesResponse - mockNamespace = "www.example.org" - mockAttrName = "foo" - mockAttrValue1 = "value1" - mockAttrValue2 = "value2" - mockAttrValue3 = "value3" - mockFqn1 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue1) - mockFqn2 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue2) - mockFqn3 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue3) -) - -type myAttributesClient struct { - sdkconnect.AttributesServiceClient -} - -func (*myAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest) (*attr.ListAttributesResponse, error) { - return &listAttributeResp, errListAttributes -} - -func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.GetAttributeValuesByFqnsRequest) (*attr.GetAttributeValuesByFqnsResponse, error) { - return &getAttributesByValueFqnsResponse, errGetAttributesByValueFqns -} - -type myERSClient struct { - sdkconnect.EntityResolutionServiceClient -} - -type mySubjectMappingClient struct { - sdkconnect.SubjectMappingServiceClient -} - -type paginatedMockSubjectMappingClient struct { - sdkconnect.SubjectMappingServiceClient -} - -func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest) (*sm.ListSubjectMappingsResponse, error) { - return &listSubjectMappings, nil -} - -func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) { - return &createEntityChainResp, nil -} - -func (*myERSClient) ResolveEntities(_ context.Context, _ *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) { - return &resolveEntitiesResp, nil -} - -var ( - smPaginationOffset = 3 - smListCallCount = 0 -) - -func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest) (*sm.ListSubjectMappingsResponse, error) { - smListCallCount++ - // simulate paginated list and policy LIST behavior - if smPaginationOffset > 0 { - rsp := &sm.ListSubjectMappingsResponse{ - SubjectMappings: nil, - Pagination: &policy.PageResponse{ - NextOffset: int32(smPaginationOffset), - }, - } - smPaginationOffset = 0 - return rsp, nil - } - return &listSubjectMappings, nil -} - -type paginatedMockAttributesClient struct { - sdkconnect.AttributesServiceClient -} - -var ( - attrPaginationOffset = 3 - attrListCallCount = 0 -) - -func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest) (*attr.ListAttributesResponse, error) { - attrListCallCount++ - // simulate paginated list and policy LIST behavior - if attrPaginationOffset > 0 { - rsp := &attr.ListAttributesResponse{ - Attributes: nil, - Pagination: &policy.PageResponse{ - NextOffset: int32(attrPaginationOffset), - }, - } - attrPaginationOffset = 0 - return rsp, nil - } - return &listAttributeResp, nil -} - func TestGetComprehensiveHierarchy(t *testing.T) { as := &AuthorizationService{ logger: logger.CreateTestLogger(), diff --git a/service/authorization/authorization_test_structures.go b/service/authorization/authorization_test_structures.go new file mode 100644 index 0000000000..0f3759e15b --- /dev/null +++ b/service/authorization/authorization_test_structures.go @@ -0,0 +1,308 @@ +package authorization + +import ( + "context" + "fmt" + + "github.com/opentdf/platform/protocol/go/entityresolution" + "github.com/opentdf/platform/protocol/go/policy" + attr "github.com/opentdf/platform/protocol/go/policy/attributes" + sm "github.com/opentdf/platform/protocol/go/policy/subjectmapping" +) + +var ( + getAttributesByValueFqnsResponse attr.GetAttributeValuesByFqnsResponse + errGetAttributesByValueFqns error + listAttributeResp attr.ListAttributesResponse + errListAttributes error + listSubjectMappings sm.ListSubjectMappingsResponse + createEntityChainResp entityresolution.CreateEntityChainFromJwtResponse + resolveEntitiesResp entityresolution.ResolveEntitiesResponse + mockNamespace = "www.example.org" + mockAttrName = "foo" + mockAttrValue1 = "value1" + mockAttrValue2 = "value2" + mockAttrValue3 = "value3" + mockFqn1 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue1) + mockFqn2 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue2) + mockFqn3 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue3) +) + +////// Mock attributes client for testing ///// + +type myAttributesClient struct { +} + +func (*myAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest) (*attr.ListAttributesResponse, error) { + return &listAttributeResp, errListAttributes +} + +func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.GetAttributeValuesByFqnsRequest) (*attr.GetAttributeValuesByFqnsResponse, error) { + return &getAttributesByValueFqnsResponse, errGetAttributesByValueFqns +} + +func (*myAttributesClient) ListAttributeValues(_ context.Context, _ *attr.ListAttributeValuesRequest) (*attr.ListAttributeValuesResponse, error) { + return &attr.ListAttributeValuesResponse{}, nil +} +func (*myAttributesClient) GetAttribute(_ context.Context, _ *attr.GetAttributeRequest) (*attr.GetAttributeResponse, error) { + return &attr.GetAttributeResponse{}, nil +} +func (*myAttributesClient) GetAttributeValue(_ context.Context, _ *attr.GetAttributeValueRequest) (*attr.GetAttributeValueResponse, error) { + return &attr.GetAttributeValueResponse{}, nil +} +func (*myAttributesClient) CreateAttribute(_ context.Context, _ *attr.CreateAttributeRequest) (*attr.CreateAttributeResponse, error) { + return &attr.CreateAttributeResponse{}, nil +} +func (*myAttributesClient) UpdateAttribute(_ context.Context, _ *attr.UpdateAttributeRequest) (*attr.UpdateAttributeResponse, error) { + return &attr.UpdateAttributeResponse{}, nil +} +func (*myAttributesClient) DeactivateAttribute(_ context.Context, _ *attr.DeactivateAttributeRequest) (*attr.DeactivateAttributeResponse, error) { + return &attr.DeactivateAttributeResponse{}, nil +} +func (*myAttributesClient) CreateAttributeValue(_ context.Context, _ *attr.CreateAttributeValueRequest) (*attr.CreateAttributeValueResponse, error) { + return &attr.CreateAttributeValueResponse{}, nil +} +func (*myAttributesClient) UpdateAttributeValue(_ context.Context, _ *attr.UpdateAttributeValueRequest) (*attr.UpdateAttributeValueResponse, error) { + return &attr.UpdateAttributeValueResponse{}, nil +} +func (*myAttributesClient) DeactivateAttributeValue(_ context.Context, _ *attr.DeactivateAttributeValueRequest) (*attr.DeactivateAttributeValueResponse, error) { + return &attr.DeactivateAttributeValueResponse{}, nil +} +func (*myAttributesClient) AssignKeyAccessServerToAttribute(_ context.Context, _ *attr.AssignKeyAccessServerToAttributeRequest) (*attr.AssignKeyAccessServerToAttributeResponse, error) { + return &attr.AssignKeyAccessServerToAttributeResponse{}, nil +} +func (*myAttributesClient) RemoveKeyAccessServerFromAttribute(_ context.Context, _ *attr.RemoveKeyAccessServerFromAttributeRequest) (*attr.RemoveKeyAccessServerFromAttributeResponse, error) { + return &attr.RemoveKeyAccessServerFromAttributeResponse{}, nil +} +func (*myAttributesClient) AssignKeyAccessServerToValue(_ context.Context, _ *attr.AssignKeyAccessServerToValueRequest) (*attr.AssignKeyAccessServerToValueResponse, error) { + return &attr.AssignKeyAccessServerToValueResponse{}, nil +} +func (*myAttributesClient) RemoveKeyAccessServerFromValue(_ context.Context, _ *attr.RemoveKeyAccessServerFromValueRequest) (*attr.RemoveKeyAccessServerFromValueResponse, error) { + return &attr.RemoveKeyAccessServerFromValueResponse{}, nil +} +func (*myAttributesClient) AssignPublicKeyToAttribute(_ context.Context, _ *attr.AssignPublicKeyToAttributeRequest) (*attr.AssignPublicKeyToAttributeResponse, error) { + return &attr.AssignPublicKeyToAttributeResponse{}, nil +} +func (*myAttributesClient) RemovePublicKeyFromAttribute(_ context.Context, _ *attr.RemovePublicKeyFromAttributeRequest) (*attr.RemovePublicKeyFromAttributeResponse, error) { + return &attr.RemovePublicKeyFromAttributeResponse{}, nil +} +func (*myAttributesClient) AssignPublicKeyToValue(_ context.Context, _ *attr.AssignPublicKeyToValueRequest) (*attr.AssignPublicKeyToValueResponse, error) { + return &attr.AssignPublicKeyToValueResponse{}, nil +} +func (*myAttributesClient) RemovePublicKeyFromValue(_ context.Context, _ *attr.RemovePublicKeyFromValueRequest) (*attr.RemovePublicKeyFromValueResponse, error) { + return &attr.RemovePublicKeyFromValueResponse{}, nil +} + +// // Mock ERS Client for testing ///// +type myERSClient struct { +} + +func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) { + return &createEntityChainResp, nil +} + +func (*myERSClient) ResolveEntities(_ context.Context, _ *entityresolution.ResolveEntitiesRequest) (*entityresolution.ResolveEntitiesResponse, error) { + return &resolveEntitiesResp, nil +} + +// // Mock Subject Mapping Client for testing ///// +type mySubjectMappingClient struct { +} + +func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest) (*sm.ListSubjectMappingsResponse, error) { + return &listSubjectMappings, nil +} + +func (*mySubjectMappingClient) MatchSubjectMappings(ctx context.Context, req *sm.MatchSubjectMappingsRequest) (*sm.MatchSubjectMappingsResponse, error) { + return &sm.MatchSubjectMappingsResponse{}, nil +} + +func (*mySubjectMappingClient) GetSubjectMapping(ctx context.Context, req *sm.GetSubjectMappingRequest) (*sm.GetSubjectMappingResponse, error) { + return &sm.GetSubjectMappingResponse{}, nil +} + +func (*mySubjectMappingClient) CreateSubjectMapping(ctx context.Context, req *sm.CreateSubjectMappingRequest) (*sm.CreateSubjectMappingResponse, error) { + return &sm.CreateSubjectMappingResponse{}, nil +} + +func (*mySubjectMappingClient) UpdateSubjectMapping(ctx context.Context, req *sm.UpdateSubjectMappingRequest) (*sm.UpdateSubjectMappingResponse, error) { + return &sm.UpdateSubjectMappingResponse{}, nil +} + +func (*mySubjectMappingClient) DeleteSubjectMapping(ctx context.Context, req *sm.DeleteSubjectMappingRequest) (*sm.DeleteSubjectMappingResponse, error) { + return &sm.DeleteSubjectMappingResponse{}, nil +} + +func (*mySubjectMappingClient) ListSubjectConditionSets(ctx context.Context, req *sm.ListSubjectConditionSetsRequest) (*sm.ListSubjectConditionSetsResponse, error) { + return &sm.ListSubjectConditionSetsResponse{}, nil +} + +func (*mySubjectMappingClient) GetSubjectConditionSet(ctx context.Context, req *sm.GetSubjectConditionSetRequest) (*sm.GetSubjectConditionSetResponse, error) { + return &sm.GetSubjectConditionSetResponse{}, nil +} + +func (*mySubjectMappingClient) CreateSubjectConditionSet(ctx context.Context, req *sm.CreateSubjectConditionSetRequest) (*sm.CreateSubjectConditionSetResponse, error) { + return &sm.CreateSubjectConditionSetResponse{}, nil +} + +func (*mySubjectMappingClient) UpdateSubjectConditionSet(ctx context.Context, req *sm.UpdateSubjectConditionSetRequest) (*sm.UpdateSubjectConditionSetResponse, error) { + return &sm.UpdateSubjectConditionSetResponse{}, nil +} + +func (*mySubjectMappingClient) DeleteSubjectConditionSet(ctx context.Context, req *sm.DeleteSubjectConditionSetRequest) (*sm.DeleteSubjectConditionSetResponse, error) { + return &sm.DeleteSubjectConditionSetResponse{}, nil +} + +func (*mySubjectMappingClient) DeleteAllUnmappedSubjectConditionSets(ctx context.Context, req *sm.DeleteAllUnmappedSubjectConditionSetsRequest) (*sm.DeleteAllUnmappedSubjectConditionSetsResponse, error) { + return &sm.DeleteAllUnmappedSubjectConditionSetsResponse{}, nil +} + +// // Mock paginated Subject Mapping Client for testing ///// +type paginatedMockSubjectMappingClient struct { +} + +var ( + smPaginationOffset = 3 + smListCallCount = 0 +) + +func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest) (*sm.ListSubjectMappingsResponse, error) { + smListCallCount++ + // simulate paginated list and policy LIST behavior + if smPaginationOffset > 0 { + rsp := &sm.ListSubjectMappingsResponse{ + SubjectMappings: nil, + Pagination: &policy.PageResponse{ + NextOffset: int32(smPaginationOffset), + }, + } + smPaginationOffset = 0 + return rsp, nil + } + return &listSubjectMappings, nil +} + +func (*paginatedMockSubjectMappingClient) MatchSubjectMappings(ctx context.Context, req *sm.MatchSubjectMappingsRequest) (*sm.MatchSubjectMappingsResponse, error) { + return &sm.MatchSubjectMappingsResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) GetSubjectMapping(ctx context.Context, req *sm.GetSubjectMappingRequest) (*sm.GetSubjectMappingResponse, error) { + return &sm.GetSubjectMappingResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) CreateSubjectMapping(ctx context.Context, req *sm.CreateSubjectMappingRequest) (*sm.CreateSubjectMappingResponse, error) { + return &sm.CreateSubjectMappingResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) UpdateSubjectMapping(ctx context.Context, req *sm.UpdateSubjectMappingRequest) (*sm.UpdateSubjectMappingResponse, error) { + return &sm.UpdateSubjectMappingResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) DeleteSubjectMapping(ctx context.Context, req *sm.DeleteSubjectMappingRequest) (*sm.DeleteSubjectMappingResponse, error) { + return &sm.DeleteSubjectMappingResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) ListSubjectConditionSets(ctx context.Context, req *sm.ListSubjectConditionSetsRequest) (*sm.ListSubjectConditionSetsResponse, error) { + return &sm.ListSubjectConditionSetsResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) GetSubjectConditionSet(ctx context.Context, req *sm.GetSubjectConditionSetRequest) (*sm.GetSubjectConditionSetResponse, error) { + return &sm.GetSubjectConditionSetResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) CreateSubjectConditionSet(ctx context.Context, req *sm.CreateSubjectConditionSetRequest) (*sm.CreateSubjectConditionSetResponse, error) { + return &sm.CreateSubjectConditionSetResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) UpdateSubjectConditionSet(ctx context.Context, req *sm.UpdateSubjectConditionSetRequest) (*sm.UpdateSubjectConditionSetResponse, error) { + return &sm.UpdateSubjectConditionSetResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) DeleteSubjectConditionSet(ctx context.Context, req *sm.DeleteSubjectConditionSetRequest) (*sm.DeleteSubjectConditionSetResponse, error) { + return &sm.DeleteSubjectConditionSetResponse{}, nil +} + +func (*paginatedMockSubjectMappingClient) DeleteAllUnmappedSubjectConditionSets(ctx context.Context, req *sm.DeleteAllUnmappedSubjectConditionSetsRequest) (*sm.DeleteAllUnmappedSubjectConditionSetsResponse, error) { + return &sm.DeleteAllUnmappedSubjectConditionSetsResponse{}, nil +} + +// // Mock paginated attributs client for testing //// +type paginatedMockAttributesClient struct { +} + +var ( + attrPaginationOffset = 3 + attrListCallCount = 0 +) + +func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest) (*attr.ListAttributesResponse, error) { + attrListCallCount++ + // simulate paginated list and policy LIST behavior + if attrPaginationOffset > 0 { + rsp := &attr.ListAttributesResponse{ + Attributes: nil, + Pagination: &policy.PageResponse{ + NextOffset: int32(attrPaginationOffset), + }, + } + attrPaginationOffset = 0 + return rsp, nil + } + return &listAttributeResp, nil +} + +func (*paginatedMockAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.GetAttributeValuesByFqnsRequest) (*attr.GetAttributeValuesByFqnsResponse, error) { + return &attr.GetAttributeValuesByFqnsResponse{}, nil +} + +func (*paginatedMockAttributesClient) ListAttributeValues(_ context.Context, _ *attr.ListAttributeValuesRequest) (*attr.ListAttributeValuesResponse, error) { + return &attr.ListAttributeValuesResponse{}, nil +} +func (*paginatedMockAttributesClient) GetAttribute(_ context.Context, _ *attr.GetAttributeRequest) (*attr.GetAttributeResponse, error) { + return &attr.GetAttributeResponse{}, nil +} +func (*paginatedMockAttributesClient) GetAttributeValue(_ context.Context, _ *attr.GetAttributeValueRequest) (*attr.GetAttributeValueResponse, error) { + return &attr.GetAttributeValueResponse{}, nil +} +func (*paginatedMockAttributesClient) CreateAttribute(_ context.Context, _ *attr.CreateAttributeRequest) (*attr.CreateAttributeResponse, error) { + return &attr.CreateAttributeResponse{}, nil +} +func (*paginatedMockAttributesClient) UpdateAttribute(_ context.Context, _ *attr.UpdateAttributeRequest) (*attr.UpdateAttributeResponse, error) { + return &attr.UpdateAttributeResponse{}, nil +} +func (*paginatedMockAttributesClient) DeactivateAttribute(_ context.Context, _ *attr.DeactivateAttributeRequest) (*attr.DeactivateAttributeResponse, error) { + return &attr.DeactivateAttributeResponse{}, nil +} +func (*paginatedMockAttributesClient) CreateAttributeValue(_ context.Context, _ *attr.CreateAttributeValueRequest) (*attr.CreateAttributeValueResponse, error) { + return &attr.CreateAttributeValueResponse{}, nil +} +func (*paginatedMockAttributesClient) UpdateAttributeValue(_ context.Context, _ *attr.UpdateAttributeValueRequest) (*attr.UpdateAttributeValueResponse, error) { + return &attr.UpdateAttributeValueResponse{}, nil +} +func (*paginatedMockAttributesClient) DeactivateAttributeValue(_ context.Context, _ *attr.DeactivateAttributeValueRequest) (*attr.DeactivateAttributeValueResponse, error) { + return &attr.DeactivateAttributeValueResponse{}, nil +} +func (*paginatedMockAttributesClient) AssignKeyAccessServerToAttribute(_ context.Context, _ *attr.AssignKeyAccessServerToAttributeRequest) (*attr.AssignKeyAccessServerToAttributeResponse, error) { + return &attr.AssignKeyAccessServerToAttributeResponse{}, nil +} +func (*paginatedMockAttributesClient) RemoveKeyAccessServerFromAttribute(_ context.Context, _ *attr.RemoveKeyAccessServerFromAttributeRequest) (*attr.RemoveKeyAccessServerFromAttributeResponse, error) { + return &attr.RemoveKeyAccessServerFromAttributeResponse{}, nil +} +func (*paginatedMockAttributesClient) AssignKeyAccessServerToValue(_ context.Context, _ *attr.AssignKeyAccessServerToValueRequest) (*attr.AssignKeyAccessServerToValueResponse, error) { + return &attr.AssignKeyAccessServerToValueResponse{}, nil +} +func (*paginatedMockAttributesClient) RemoveKeyAccessServerFromValue(_ context.Context, _ *attr.RemoveKeyAccessServerFromValueRequest) (*attr.RemoveKeyAccessServerFromValueResponse, error) { + return &attr.RemoveKeyAccessServerFromValueResponse{}, nil +} +func (*paginatedMockAttributesClient) AssignPublicKeyToAttribute(_ context.Context, _ *attr.AssignPublicKeyToAttributeRequest) (*attr.AssignPublicKeyToAttributeResponse, error) { + return &attr.AssignPublicKeyToAttributeResponse{}, nil +} +func (*paginatedMockAttributesClient) RemovePublicKeyFromAttribute(_ context.Context, _ *attr.RemovePublicKeyFromAttributeRequest) (*attr.RemovePublicKeyFromAttributeResponse, error) { + return &attr.RemovePublicKeyFromAttributeResponse{}, nil +} +func (*paginatedMockAttributesClient) AssignPublicKeyToValue(_ context.Context, _ *attr.AssignPublicKeyToValueRequest) (*attr.AssignPublicKeyToValueResponse, error) { + return &attr.AssignPublicKeyToValueResponse{}, nil +} +func (*paginatedMockAttributesClient) RemovePublicKeyFromValue(_ context.Context, _ *attr.RemovePublicKeyFromValueRequest) (*attr.RemovePublicKeyFromValueResponse, error) { + return &attr.RemovePublicKeyFromValueResponse{}, nil +} From 96f90dbc3fa794ef534b3c93729ee0bc2ff23ccf Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 16 May 2025 13:48:47 -0400 Subject: [PATCH 19/31] trigger ci From c29dca0a88f71465188afadf65a17c0c5ece1f32 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 16 May 2025 14:11:57 -0400 Subject: [PATCH 20/31] linting --- sdk/nanotdf.go | 2 +- sdk/sdk.go | 3 +-- sdk/sdk_test.go | 1 + sdk/tdf_test.go | 1 - 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sdk/nanotdf.go b/sdk/nanotdf.go index 910dc3f729..e1d5b6c5c0 100644 --- a/sdk/nanotdf.go +++ b/sdk/nanotdf.go @@ -573,7 +573,7 @@ func (c *collectionStore) get(header []byte) ([]byte, bool) { return nil, false } -func (c *collectionStore) close() { +func (c *collectionStore) close() { //nolint:unused // leave for future use c.closeChan <- struct{}{} } diff --git a/sdk/sdk.go b/sdk/sdk.go index 195061ff88..12f02411f6 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -104,11 +104,10 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { } // IF IPC is disabled we build a validated healthy connection to the platform - if !cfg.ipc { + if !cfg.ipc { //nolint:nestif // Most of checks are for errors if IsPlatformEndpointMalformed(platformEndpoint) { return nil, fmt.Errorf("%w [%v]", ErrPlatformEndpointMalformed, platformEndpoint) } - if cfg.shouldValidatePlatformConnectivity { if cfg.coreConn != nil { err = ValidateHealthyPlatformConnection(cfg.coreConn.Endpoint, cfg.coreConn.Client) diff --git a/sdk/sdk_test.go b/sdk/sdk_test.go index 30ac4ce09d..12c0627624 100644 --- a/sdk/sdk_test.go +++ b/sdk/sdk_test.go @@ -325,6 +325,7 @@ func TestIsPlatformEndpointMalformed(t *testing.T) { }) } } + func Test_GetType_NanoTDF(t *testing.T) { nano := "TDFMABJsb2NhbGhvc3Q6ODA4MC9rYXOAAQIA2qvjMRfg7b27lT2kf9SwHRkDIg8ZXtfRoiIvdMUHq/gL5AUMfmv4Di8sKCyLkmUm/WITVj5hDeV/z4JmQ0JL7ZxqSmgZoK6TAHvkKhUly4zMEWMRXH8IktKhFKy1+fD+3qwDopqWAO5Nm2nYQqi75atEFckstulpNKg3N+Ul22OHr/ZuR127oPObBDYNRfktBdzoZbEQcPlr8q1B57q6y5SPZFjEzL9weK+uS5bUJWkF3nsHASo2bZw7IPhTZxoFVmCDjwvj6MbxNa7zG6aClHJ162zKxLLnD9TtIHuZ59R7LgiSieipXeExj+ky9OgIw5DfwyUuxsQLtKpMIAFPmLY9Hy2naUJxke0MT1EUBgastCq+YtFGslV9LJo/A8FtrRqludwtM0O+Z9FlAkZ1oNL7M7uOkLrh7eRrv+C1AAAX6FaBQoOtqnmyu6Jp+VzkxDddEeLRUyI=" nanoDecoded, err := base64.StdEncoding.DecodeString(nano) diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index e45fd348c7..333bd0f9d5 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -1918,7 +1918,6 @@ func (s *TDFSuite) startBackend() { var sdkPlatformURL string for i, ki := range kasesToMake { - mux := http.NewServeMux() s.kases[i] = FakeKas{ From 698696c0aaedaa7a767314797422bb0afed46cb3 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 16 May 2025 14:29:28 -0400 Subject: [PATCH 21/31] linting --- .../authorization_test_structures.go | 82 +++++++++++++------ 1 file changed, 56 insertions(+), 26 deletions(-) diff --git a/service/authorization/authorization_test_structures.go b/service/authorization/authorization_test_structures.go index 0f3759e15b..dbb848a5d9 100644 --- a/service/authorization/authorization_test_structures.go +++ b/service/authorization/authorization_test_structures.go @@ -28,10 +28,8 @@ var ( mockFqn3 = fmt.Sprintf("https://%s/attr/%s/value/%s", mockNamespace, mockAttrName, mockAttrValue3) ) -////// Mock attributes client for testing ///// - -type myAttributesClient struct { -} +// //// Mock attributes client for testing ///// +type myAttributesClient struct{} func (*myAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest) (*attr.ListAttributesResponse, error) { return &listAttributeResp, errListAttributes @@ -44,51 +42,67 @@ func (*myAttributesClient) GetAttributeValuesByFqns(_ context.Context, _ *attr.G func (*myAttributesClient) ListAttributeValues(_ context.Context, _ *attr.ListAttributeValuesRequest) (*attr.ListAttributeValuesResponse, error) { return &attr.ListAttributeValuesResponse{}, nil } + func (*myAttributesClient) GetAttribute(_ context.Context, _ *attr.GetAttributeRequest) (*attr.GetAttributeResponse, error) { return &attr.GetAttributeResponse{}, nil } + func (*myAttributesClient) GetAttributeValue(_ context.Context, _ *attr.GetAttributeValueRequest) (*attr.GetAttributeValueResponse, error) { return &attr.GetAttributeValueResponse{}, nil } + func (*myAttributesClient) CreateAttribute(_ context.Context, _ *attr.CreateAttributeRequest) (*attr.CreateAttributeResponse, error) { return &attr.CreateAttributeResponse{}, nil } + func (*myAttributesClient) UpdateAttribute(_ context.Context, _ *attr.UpdateAttributeRequest) (*attr.UpdateAttributeResponse, error) { return &attr.UpdateAttributeResponse{}, nil } + func (*myAttributesClient) DeactivateAttribute(_ context.Context, _ *attr.DeactivateAttributeRequest) (*attr.DeactivateAttributeResponse, error) { return &attr.DeactivateAttributeResponse{}, nil } + func (*myAttributesClient) CreateAttributeValue(_ context.Context, _ *attr.CreateAttributeValueRequest) (*attr.CreateAttributeValueResponse, error) { return &attr.CreateAttributeValueResponse{}, nil } + func (*myAttributesClient) UpdateAttributeValue(_ context.Context, _ *attr.UpdateAttributeValueRequest) (*attr.UpdateAttributeValueResponse, error) { return &attr.UpdateAttributeValueResponse{}, nil } + func (*myAttributesClient) DeactivateAttributeValue(_ context.Context, _ *attr.DeactivateAttributeValueRequest) (*attr.DeactivateAttributeValueResponse, error) { return &attr.DeactivateAttributeValueResponse{}, nil } + func (*myAttributesClient) AssignKeyAccessServerToAttribute(_ context.Context, _ *attr.AssignKeyAccessServerToAttributeRequest) (*attr.AssignKeyAccessServerToAttributeResponse, error) { return &attr.AssignKeyAccessServerToAttributeResponse{}, nil } + func (*myAttributesClient) RemoveKeyAccessServerFromAttribute(_ context.Context, _ *attr.RemoveKeyAccessServerFromAttributeRequest) (*attr.RemoveKeyAccessServerFromAttributeResponse, error) { return &attr.RemoveKeyAccessServerFromAttributeResponse{}, nil } + func (*myAttributesClient) AssignKeyAccessServerToValue(_ context.Context, _ *attr.AssignKeyAccessServerToValueRequest) (*attr.AssignKeyAccessServerToValueResponse, error) { return &attr.AssignKeyAccessServerToValueResponse{}, nil } + func (*myAttributesClient) RemoveKeyAccessServerFromValue(_ context.Context, _ *attr.RemoveKeyAccessServerFromValueRequest) (*attr.RemoveKeyAccessServerFromValueResponse, error) { return &attr.RemoveKeyAccessServerFromValueResponse{}, nil } + func (*myAttributesClient) AssignPublicKeyToAttribute(_ context.Context, _ *attr.AssignPublicKeyToAttributeRequest) (*attr.AssignPublicKeyToAttributeResponse, error) { return &attr.AssignPublicKeyToAttributeResponse{}, nil } + func (*myAttributesClient) RemovePublicKeyFromAttribute(_ context.Context, _ *attr.RemovePublicKeyFromAttributeRequest) (*attr.RemovePublicKeyFromAttributeResponse, error) { return &attr.RemovePublicKeyFromAttributeResponse{}, nil } + func (*myAttributesClient) AssignPublicKeyToValue(_ context.Context, _ *attr.AssignPublicKeyToValueRequest) (*attr.AssignPublicKeyToValueResponse, error) { return &attr.AssignPublicKeyToValueResponse{}, nil } + func (*myAttributesClient) RemovePublicKeyFromValue(_ context.Context, _ *attr.RemovePublicKeyFromValueRequest) (*attr.RemovePublicKeyFromValueResponse, error) { return &attr.RemovePublicKeyFromValueResponse{}, nil } @@ -113,47 +127,47 @@ func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.List return &listSubjectMappings, nil } -func (*mySubjectMappingClient) MatchSubjectMappings(ctx context.Context, req *sm.MatchSubjectMappingsRequest) (*sm.MatchSubjectMappingsResponse, error) { +func (*mySubjectMappingClient) MatchSubjectMappings(_ context.Context, _ *sm.MatchSubjectMappingsRequest) (*sm.MatchSubjectMappingsResponse, error) { return &sm.MatchSubjectMappingsResponse{}, nil } -func (*mySubjectMappingClient) GetSubjectMapping(ctx context.Context, req *sm.GetSubjectMappingRequest) (*sm.GetSubjectMappingResponse, error) { +func (*mySubjectMappingClient) GetSubjectMapping(_ context.Context, _ *sm.GetSubjectMappingRequest) (*sm.GetSubjectMappingResponse, error) { return &sm.GetSubjectMappingResponse{}, nil } -func (*mySubjectMappingClient) CreateSubjectMapping(ctx context.Context, req *sm.CreateSubjectMappingRequest) (*sm.CreateSubjectMappingResponse, error) { +func (*mySubjectMappingClient) CreateSubjectMapping(_ context.Context, _ *sm.CreateSubjectMappingRequest) (*sm.CreateSubjectMappingResponse, error) { return &sm.CreateSubjectMappingResponse{}, nil } -func (*mySubjectMappingClient) UpdateSubjectMapping(ctx context.Context, req *sm.UpdateSubjectMappingRequest) (*sm.UpdateSubjectMappingResponse, error) { +func (*mySubjectMappingClient) UpdateSubjectMapping(_ context.Context, _ *sm.UpdateSubjectMappingRequest) (*sm.UpdateSubjectMappingResponse, error) { return &sm.UpdateSubjectMappingResponse{}, nil } -func (*mySubjectMappingClient) DeleteSubjectMapping(ctx context.Context, req *sm.DeleteSubjectMappingRequest) (*sm.DeleteSubjectMappingResponse, error) { +func (*mySubjectMappingClient) DeleteSubjectMapping(_ context.Context, _ *sm.DeleteSubjectMappingRequest) (*sm.DeleteSubjectMappingResponse, error) { return &sm.DeleteSubjectMappingResponse{}, nil } -func (*mySubjectMappingClient) ListSubjectConditionSets(ctx context.Context, req *sm.ListSubjectConditionSetsRequest) (*sm.ListSubjectConditionSetsResponse, error) { +func (*mySubjectMappingClient) ListSubjectConditionSets(_ context.Context, _ *sm.ListSubjectConditionSetsRequest) (*sm.ListSubjectConditionSetsResponse, error) { return &sm.ListSubjectConditionSetsResponse{}, nil } -func (*mySubjectMappingClient) GetSubjectConditionSet(ctx context.Context, req *sm.GetSubjectConditionSetRequest) (*sm.GetSubjectConditionSetResponse, error) { +func (*mySubjectMappingClient) GetSubjectConditionSet(_ context.Context, _ *sm.GetSubjectConditionSetRequest) (*sm.GetSubjectConditionSetResponse, error) { return &sm.GetSubjectConditionSetResponse{}, nil } -func (*mySubjectMappingClient) CreateSubjectConditionSet(ctx context.Context, req *sm.CreateSubjectConditionSetRequest) (*sm.CreateSubjectConditionSetResponse, error) { +func (*mySubjectMappingClient) CreateSubjectConditionSet(_ context.Context, _ *sm.CreateSubjectConditionSetRequest) (*sm.CreateSubjectConditionSetResponse, error) { return &sm.CreateSubjectConditionSetResponse{}, nil } -func (*mySubjectMappingClient) UpdateSubjectConditionSet(ctx context.Context, req *sm.UpdateSubjectConditionSetRequest) (*sm.UpdateSubjectConditionSetResponse, error) { +func (*mySubjectMappingClient) UpdateSubjectConditionSet(_ context.Context, _ *sm.UpdateSubjectConditionSetRequest) (*sm.UpdateSubjectConditionSetResponse, error) { return &sm.UpdateSubjectConditionSetResponse{}, nil } -func (*mySubjectMappingClient) DeleteSubjectConditionSet(ctx context.Context, req *sm.DeleteSubjectConditionSetRequest) (*sm.DeleteSubjectConditionSetResponse, error) { +func (*mySubjectMappingClient) DeleteSubjectConditionSet(_ context.Context, _ *sm.DeleteSubjectConditionSetRequest) (*sm.DeleteSubjectConditionSetResponse, error) { return &sm.DeleteSubjectConditionSetResponse{}, nil } -func (*mySubjectMappingClient) DeleteAllUnmappedSubjectConditionSets(ctx context.Context, req *sm.DeleteAllUnmappedSubjectConditionSetsRequest) (*sm.DeleteAllUnmappedSubjectConditionSetsResponse, error) { +func (*mySubjectMappingClient) DeleteAllUnmappedSubjectConditionSets(_ context.Context, _ *sm.DeleteAllUnmappedSubjectConditionSetsRequest) (*sm.DeleteAllUnmappedSubjectConditionSetsResponse, error) { return &sm.DeleteAllUnmappedSubjectConditionSetsResponse{}, nil } @@ -182,47 +196,47 @@ func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, return &listSubjectMappings, nil } -func (*paginatedMockSubjectMappingClient) MatchSubjectMappings(ctx context.Context, req *sm.MatchSubjectMappingsRequest) (*sm.MatchSubjectMappingsResponse, error) { +func (*paginatedMockSubjectMappingClient) MatchSubjectMappings(_ context.Context, _ *sm.MatchSubjectMappingsRequest) (*sm.MatchSubjectMappingsResponse, error) { return &sm.MatchSubjectMappingsResponse{}, nil } -func (*paginatedMockSubjectMappingClient) GetSubjectMapping(ctx context.Context, req *sm.GetSubjectMappingRequest) (*sm.GetSubjectMappingResponse, error) { +func (*paginatedMockSubjectMappingClient) GetSubjectMapping(_ context.Context, _ *sm.GetSubjectMappingRequest) (*sm.GetSubjectMappingResponse, error) { return &sm.GetSubjectMappingResponse{}, nil } -func (*paginatedMockSubjectMappingClient) CreateSubjectMapping(ctx context.Context, req *sm.CreateSubjectMappingRequest) (*sm.CreateSubjectMappingResponse, error) { +func (*paginatedMockSubjectMappingClient) CreateSubjectMapping(_ context.Context, _ *sm.CreateSubjectMappingRequest) (*sm.CreateSubjectMappingResponse, error) { return &sm.CreateSubjectMappingResponse{}, nil } -func (*paginatedMockSubjectMappingClient) UpdateSubjectMapping(ctx context.Context, req *sm.UpdateSubjectMappingRequest) (*sm.UpdateSubjectMappingResponse, error) { +func (*paginatedMockSubjectMappingClient) UpdateSubjectMapping(_ context.Context, _ *sm.UpdateSubjectMappingRequest) (*sm.UpdateSubjectMappingResponse, error) { return &sm.UpdateSubjectMappingResponse{}, nil } -func (*paginatedMockSubjectMappingClient) DeleteSubjectMapping(ctx context.Context, req *sm.DeleteSubjectMappingRequest) (*sm.DeleteSubjectMappingResponse, error) { +func (*paginatedMockSubjectMappingClient) DeleteSubjectMapping(_ context.Context, _ *sm.DeleteSubjectMappingRequest) (*sm.DeleteSubjectMappingResponse, error) { return &sm.DeleteSubjectMappingResponse{}, nil } -func (*paginatedMockSubjectMappingClient) ListSubjectConditionSets(ctx context.Context, req *sm.ListSubjectConditionSetsRequest) (*sm.ListSubjectConditionSetsResponse, error) { +func (*paginatedMockSubjectMappingClient) ListSubjectConditionSets(_ context.Context, _ *sm.ListSubjectConditionSetsRequest) (*sm.ListSubjectConditionSetsResponse, error) { return &sm.ListSubjectConditionSetsResponse{}, nil } -func (*paginatedMockSubjectMappingClient) GetSubjectConditionSet(ctx context.Context, req *sm.GetSubjectConditionSetRequest) (*sm.GetSubjectConditionSetResponse, error) { +func (*paginatedMockSubjectMappingClient) GetSubjectConditionSet(_ context.Context, _ *sm.GetSubjectConditionSetRequest) (*sm.GetSubjectConditionSetResponse, error) { return &sm.GetSubjectConditionSetResponse{}, nil } -func (*paginatedMockSubjectMappingClient) CreateSubjectConditionSet(ctx context.Context, req *sm.CreateSubjectConditionSetRequest) (*sm.CreateSubjectConditionSetResponse, error) { +func (*paginatedMockSubjectMappingClient) CreateSubjectConditionSet(_ context.Context, _ *sm.CreateSubjectConditionSetRequest) (*sm.CreateSubjectConditionSetResponse, error) { return &sm.CreateSubjectConditionSetResponse{}, nil } -func (*paginatedMockSubjectMappingClient) UpdateSubjectConditionSet(ctx context.Context, req *sm.UpdateSubjectConditionSetRequest) (*sm.UpdateSubjectConditionSetResponse, error) { +func (*paginatedMockSubjectMappingClient) UpdateSubjectConditionSet(_ context.Context, _ *sm.UpdateSubjectConditionSetRequest) (*sm.UpdateSubjectConditionSetResponse, error) { return &sm.UpdateSubjectConditionSetResponse{}, nil } -func (*paginatedMockSubjectMappingClient) DeleteSubjectConditionSet(ctx context.Context, req *sm.DeleteSubjectConditionSetRequest) (*sm.DeleteSubjectConditionSetResponse, error) { +func (*paginatedMockSubjectMappingClient) DeleteSubjectConditionSet(_ context.Context, _ *sm.DeleteSubjectConditionSetRequest) (*sm.DeleteSubjectConditionSetResponse, error) { return &sm.DeleteSubjectConditionSetResponse{}, nil } -func (*paginatedMockSubjectMappingClient) DeleteAllUnmappedSubjectConditionSets(ctx context.Context, req *sm.DeleteAllUnmappedSubjectConditionSetsRequest) (*sm.DeleteAllUnmappedSubjectConditionSetsResponse, error) { +func (*paginatedMockSubjectMappingClient) DeleteAllUnmappedSubjectConditionSets(_ context.Context, _ *sm.DeleteAllUnmappedSubjectConditionSetsRequest) (*sm.DeleteAllUnmappedSubjectConditionSetsResponse, error) { return &sm.DeleteAllUnmappedSubjectConditionSetsResponse{}, nil } @@ -258,51 +272,67 @@ func (*paginatedMockAttributesClient) GetAttributeValuesByFqns(_ context.Context func (*paginatedMockAttributesClient) ListAttributeValues(_ context.Context, _ *attr.ListAttributeValuesRequest) (*attr.ListAttributeValuesResponse, error) { return &attr.ListAttributeValuesResponse{}, nil } + func (*paginatedMockAttributesClient) GetAttribute(_ context.Context, _ *attr.GetAttributeRequest) (*attr.GetAttributeResponse, error) { return &attr.GetAttributeResponse{}, nil } + func (*paginatedMockAttributesClient) GetAttributeValue(_ context.Context, _ *attr.GetAttributeValueRequest) (*attr.GetAttributeValueResponse, error) { return &attr.GetAttributeValueResponse{}, nil } + func (*paginatedMockAttributesClient) CreateAttribute(_ context.Context, _ *attr.CreateAttributeRequest) (*attr.CreateAttributeResponse, error) { return &attr.CreateAttributeResponse{}, nil } + func (*paginatedMockAttributesClient) UpdateAttribute(_ context.Context, _ *attr.UpdateAttributeRequest) (*attr.UpdateAttributeResponse, error) { return &attr.UpdateAttributeResponse{}, nil } + func (*paginatedMockAttributesClient) DeactivateAttribute(_ context.Context, _ *attr.DeactivateAttributeRequest) (*attr.DeactivateAttributeResponse, error) { return &attr.DeactivateAttributeResponse{}, nil } + func (*paginatedMockAttributesClient) CreateAttributeValue(_ context.Context, _ *attr.CreateAttributeValueRequest) (*attr.CreateAttributeValueResponse, error) { return &attr.CreateAttributeValueResponse{}, nil } + func (*paginatedMockAttributesClient) UpdateAttributeValue(_ context.Context, _ *attr.UpdateAttributeValueRequest) (*attr.UpdateAttributeValueResponse, error) { return &attr.UpdateAttributeValueResponse{}, nil } + func (*paginatedMockAttributesClient) DeactivateAttributeValue(_ context.Context, _ *attr.DeactivateAttributeValueRequest) (*attr.DeactivateAttributeValueResponse, error) { return &attr.DeactivateAttributeValueResponse{}, nil } + func (*paginatedMockAttributesClient) AssignKeyAccessServerToAttribute(_ context.Context, _ *attr.AssignKeyAccessServerToAttributeRequest) (*attr.AssignKeyAccessServerToAttributeResponse, error) { return &attr.AssignKeyAccessServerToAttributeResponse{}, nil } + func (*paginatedMockAttributesClient) RemoveKeyAccessServerFromAttribute(_ context.Context, _ *attr.RemoveKeyAccessServerFromAttributeRequest) (*attr.RemoveKeyAccessServerFromAttributeResponse, error) { return &attr.RemoveKeyAccessServerFromAttributeResponse{}, nil } + func (*paginatedMockAttributesClient) AssignKeyAccessServerToValue(_ context.Context, _ *attr.AssignKeyAccessServerToValueRequest) (*attr.AssignKeyAccessServerToValueResponse, error) { return &attr.AssignKeyAccessServerToValueResponse{}, nil } + func (*paginatedMockAttributesClient) RemoveKeyAccessServerFromValue(_ context.Context, _ *attr.RemoveKeyAccessServerFromValueRequest) (*attr.RemoveKeyAccessServerFromValueResponse, error) { return &attr.RemoveKeyAccessServerFromValueResponse{}, nil } + func (*paginatedMockAttributesClient) AssignPublicKeyToAttribute(_ context.Context, _ *attr.AssignPublicKeyToAttributeRequest) (*attr.AssignPublicKeyToAttributeResponse, error) { return &attr.AssignPublicKeyToAttributeResponse{}, nil } + func (*paginatedMockAttributesClient) RemovePublicKeyFromAttribute(_ context.Context, _ *attr.RemovePublicKeyFromAttributeRequest) (*attr.RemovePublicKeyFromAttributeResponse, error) { return &attr.RemovePublicKeyFromAttributeResponse{}, nil } + func (*paginatedMockAttributesClient) AssignPublicKeyToValue(_ context.Context, _ *attr.AssignPublicKeyToValueRequest) (*attr.AssignPublicKeyToValueResponse, error) { return &attr.AssignPublicKeyToValueResponse{}, nil } + func (*paginatedMockAttributesClient) RemovePublicKeyFromValue(_ context.Context, _ *attr.RemovePublicKeyFromValueRequest) (*attr.RemovePublicKeyFromValueResponse, error) { return &attr.RemovePublicKeyFromValueResponse{}, nil } From a2f4c5240eeec06145251f86ac5bf482dd5e4b90 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 16 May 2025 14:32:34 -0400 Subject: [PATCH 22/31] linting --- .../authorization/authorization_test_structures.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/service/authorization/authorization_test_structures.go b/service/authorization/authorization_test_structures.go index dbb848a5d9..e9f4d764f5 100644 --- a/service/authorization/authorization_test_structures.go +++ b/service/authorization/authorization_test_structures.go @@ -108,8 +108,7 @@ func (*myAttributesClient) RemovePublicKeyFromValue(_ context.Context, _ *attr.R } // // Mock ERS Client for testing ///// -type myERSClient struct { -} +type myERSClient struct{} func (*myERSClient) CreateEntityChainFromJwt(_ context.Context, _ *entityresolution.CreateEntityChainFromJwtRequest) (*entityresolution.CreateEntityChainFromJwtResponse, error) { return &createEntityChainResp, nil @@ -120,8 +119,7 @@ func (*myERSClient) ResolveEntities(_ context.Context, _ *entityresolution.Resol } // // Mock Subject Mapping Client for testing ///// -type mySubjectMappingClient struct { -} +type mySubjectMappingClient struct{} func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest) (*sm.ListSubjectMappingsResponse, error) { return &listSubjectMappings, nil @@ -172,8 +170,7 @@ func (*mySubjectMappingClient) DeleteAllUnmappedSubjectConditionSets(_ context.C } // // Mock paginated Subject Mapping Client for testing ///// -type paginatedMockSubjectMappingClient struct { -} +type paginatedMockSubjectMappingClient struct{} var ( smPaginationOffset = 3 @@ -241,8 +238,7 @@ func (*paginatedMockSubjectMappingClient) DeleteAllUnmappedSubjectConditionSets( } // // Mock paginated attributs client for testing //// -type paginatedMockAttributesClient struct { -} +type paginatedMockAttributesClient struct{} var ( attrPaginationOffset = 3 From d5c290b56debc98ea3aa356d5855a5bb12f7411e Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Fri, 16 May 2025 14:46:27 -0400 Subject: [PATCH 23/31] undo listener changes --- service/internal/server/server.go | 15 ++++++++------- service/pkg/server/start.go | 22 +++++++++++----------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 21468f4430..46748b529b 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -232,12 +232,6 @@ func NewOpenTDFServer(config Config, logger *logger.Logger) (*OpenTDFServer, err logger: logger, } - listener, err := o.openHTTPServerPort() - if err != nil { - return nil, err - } - o.Listener = listener - if !config.CryptoProvider.IsEmpty() { // Create crypto provider logger.Info("creating crypto provider", slog.String("type", config.CryptoProvider.Type)) @@ -448,7 +442,14 @@ func (s OpenTDFServer) Start() error { s.ConnectRPCInProcess.Mux.Handle(grpcreflect.NewHandlerV1(reflector)) s.ConnectRPCInProcess.Mux.Handle(grpcreflect.NewHandlerV1Alpha(reflector)) - go s.startHTTPServer(s.Listener) + ln, err := s.openHTTPServerPort() + if err != nil { + return err + } + s.Listener = ln + + // Start Http Server + go s.startHTTPServer(ln) return nil } diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 7a814198ea..760b2db9bb 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -86,17 +86,6 @@ func Start(f ...StartOptions) error { logger.Debug("config loaded", slog.Any("config", cfg.LogValue())) - // If the mode is not all, does not include both core and entityresolution, or is not entityresolution on its own, we need to have a valid SDK config - // entityresolution does not connect to other services and can run on its own - // core only connects to entityresolution - if !(slices.Contains(cfg.Mode, "all") || // no config required for all mode - (slices.Contains(cfg.Mode, "core") && slices.Contains(cfg.Mode, "entityresolution")) || // or core and entityresolution modes togethor - (slices.Contains(cfg.Mode, "entityresolution") && len(cfg.Mode) == 1)) && // or entityresolution on its own - cfg.SDKConfig == (config.SDKConfig{}) { - logger.Error("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided") - return errors.New("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided") - } - logger.Info("starting opentdf services") // Set allowed public routes when platform is being extended @@ -198,6 +187,17 @@ func Start(f ...StartOptions) error { oidcconfig *auth.OIDCConfiguration ) + // If the mode is not all, does not include both core and entityresolution, or is not entityresolution on its own, we need to have a valid SDK config + // entityresolution does not connect to other services and can run on its own + // core only connects to entityresolution + if !(slices.Contains(cfg.Mode, "all") || // no config required for all mode + (slices.Contains(cfg.Mode, "core") && slices.Contains(cfg.Mode, "entityresolution")) || // or core and entityresolution modes togethor + (slices.Contains(cfg.Mode, "entityresolution") && len(cfg.Mode) == 1)) && // or entityresolution on its own + cfg.SDKConfig == (config.SDKConfig{}) { + logger.Error("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided") + return errors.New("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided") + } + // If client credentials are provided, use them if cfg.SDKConfig.ClientID != "" && cfg.SDKConfig.ClientSecret != "" { sdkOptions = append(sdkOptions, sdk.WithClientCredentials(cfg.SDKConfig.ClientID, cfg.SDKConfig.ClientSecret, nil)) From 32d14efc80ca82f531b9b14f87ebb543236a77d2 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Mon, 19 May 2025 12:48:51 -0400 Subject: [PATCH 24/31] remove otdfctl ref --- .github/workflows/checks.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 955c3e5c63..fb568c81ef 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -344,7 +344,6 @@ jobs: focus-sdk: go # use commit instead of ref so we can "go get" specific sdk version platform-ref: ${{ github.event.pull_request.head.sha || github.sha }} lts - otdfctl-ref: 107e016c326564234757a55e55086fdf66e83078 # test latest otdfctl CLI 'main' against platform PR branch otdfctl-test: From 16016ced7df9646ca3bcf39307fe06482ae24e59 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Mon, 19 May 2025 13:55:56 -0400 Subject: [PATCH 25/31] add extra client options to platform validation and get cfg --- sdk/sdk.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sdk/sdk.go b/sdk/sdk.go index 12f02411f6..7b2a72f4df 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -110,12 +110,12 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { } if cfg.shouldValidatePlatformConnectivity { if cfg.coreConn != nil { - err = ValidateHealthyPlatformConnection(cfg.coreConn.Endpoint, cfg.coreConn.Client) + err = validateHealthyPlatformConnection(cfg.coreConn.Endpoint, cfg.coreConn.Client, cfg.coreConn.Options) if err != nil { return nil, err } } else { - err = ValidateHealthyPlatformConnection(platformEndpoint, cfg.httpClient) + err = validateHealthyPlatformConnection(platformEndpoint, cfg.httpClient, cfg.extraClientOptions) if err != nil { return nil, err } @@ -134,7 +134,7 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { return nil, errors.Join(ErrPlatformConfigFailed, err) } } else { - pcfg, err = getPlatformConfiguration(&ConnectRPCConnection{Endpoint: platformEndpoint, Client: cfg.httpClient}) + pcfg, err = getPlatformConfiguration(&ConnectRPCConnection{Endpoint: platformEndpoint, Client: cfg.httpClient, Options: cfg.extraClientOptions}) if err != nil { return nil, errors.Join(ErrPlatformConfigFailed, err) } @@ -385,10 +385,11 @@ func IsValidNanoTdf(reader io.ReadSeeker) (bool, error) { } // Test connectability to the platform and validate a healthy status -func ValidateHealthyPlatformConnection(platformEndpoint string, httpClient *http.Client) error { +func validateHealthyPlatformConnection(platformEndpoint string, httpClient *http.Client, options []connect.ClientOption) error { healthClient := connect.NewClient[healthpb.HealthCheckRequest, healthpb.HealthCheckResponse]( httpClient, platformEndpoint+"/grpc.health.v1.Health/Check", + options..., ) res, err := healthClient.CallUnary( context.Background(), From bf886ae9d115b0abf3c113dcaebf8b411684e7e0 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Tue, 20 May 2025 12:28:09 -0400 Subject: [PATCH 26/31] keep close until otdfctl updates --- sdk/sdk.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sdk/sdk.go b/sdk/sdk.go index 7b2a72f4df..e4a7b6848c 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -258,6 +258,11 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { return ts, err } +// TODO: Remove after otdfctl updates +func (s SDK) Close() error { + return nil +} + // Conn returns the underlying http connection func (s SDK) Conn() *ConnectRPCConnection { return s.conn From aae3900704fa29ec408228b415a354241e3528dc Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Tue, 20 May 2025 14:18:39 -0400 Subject: [PATCH 27/31] add back close, add readme for codegen --- examples/cmd/attributes.go | 4 ++++ examples/cmd/authorization.go | 1 + examples/cmd/decrypt.go | 1 + examples/cmd/kas.go | 3 +++ sdk/internal/codegen/README.md | 27 +++++++++++++++++++++++++ sdk/internal/codegen/runner/generate.go | 5 ++++- sdk/nanotdf.go | 2 +- sdk/sdk.go | 4 +++- service/pkg/server/start.go | 2 ++ 9 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 sdk/internal/codegen/README.md diff --git a/examples/cmd/attributes.go b/examples/cmd/attributes.go index aaa397cec0..3715e632a6 100644 --- a/examples/cmd/attributes.go +++ b/examples/cmd/attributes.go @@ -106,6 +106,7 @@ func listAttributes(cmd *cobra.Command) error { slog.Error("could not connect", slog.Any("error", err)) return err } + defer s.Close() ctx := cmd.Context() @@ -224,6 +225,7 @@ func addAttribute(cmd *cobra.Command) error { slog.Error("newSDK", slog.Any("error", err)) return err } + defer s.Close() are := regexp.MustCompile(`^(https?://[\w./]+)/attr/([^/\s]*)$`) m := are.FindStringSubmatch(attr) @@ -258,6 +260,7 @@ func removeAttribute(cmd *cobra.Command) error { slog.Error("could not connect", "err", err) return err } + defer s.Close() are := regexp.MustCompile(`^(https?://[\w./]+)/attr/([^/\s]*)$`) m := are.FindStringSubmatch(attr) @@ -332,6 +335,7 @@ func assignAttribute(cmd *cobra.Command, assign bool) error { slog.Error("could not connect", "err", err) return err } + defer s.Close() are := regexp.MustCompile(`^(https?://[\w./]+)/attr/([^/\s]*)$`) m := are.FindStringSubmatch(attr) diff --git a/examples/cmd/authorization.go b/examples/cmd/authorization.go index 39809cc513..f5949f4fdd 100644 --- a/examples/cmd/authorization.go +++ b/examples/cmd/authorization.go @@ -25,6 +25,7 @@ func authorizationExamples() error { slog.Error("could not connect", slog.Any("error", err)) return err } + defer s.Close() // request decision on "TRANSMIT" Action actions := []*policy.Action{{ diff --git a/examples/cmd/decrypt.go b/examples/cmd/decrypt.go index 0468af5eb2..8105e8c84a 100644 --- a/examples/cmd/decrypt.go +++ b/examples/cmd/decrypt.go @@ -56,6 +56,7 @@ func decrypt(cmd *cobra.Command, args []string) error { } } } + client.Close() return nil } diff --git a/examples/cmd/kas.go b/examples/cmd/kas.go index 4555b9bce6..bff3dcb094 100644 --- a/examples/cmd/kas.go +++ b/examples/cmd/kas.go @@ -71,6 +71,7 @@ func listKases(cmd *cobra.Command) error { slog.Error("could not connect", "err", err) return err } + defer s.Close() r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(cmd.Context(), &kasregistry.ListKeyAccessServersRequest{}) if err != nil { @@ -158,6 +159,7 @@ func updateKas(cmd *cobra.Command) error { slog.Error("could not connect", "err", err) return err } + defer s.Close() var pk *policy.PublicKey switch { @@ -206,6 +208,7 @@ func removeKas(cmd *cobra.Command) error { slog.Error("could not connect", "err", err) return err } + defer s.Close() r, err := s.KeyAccessServerRegistry.ListKeyAccessServers(cmd.Context(), &kasregistry.ListKeyAccessServersRequest{}) if err != nil { diff --git a/sdk/internal/codegen/README.md b/sdk/internal/codegen/README.md new file mode 100644 index 0000000000..c96c5d961a --- /dev/null +++ b/sdk/internal/codegen/README.md @@ -0,0 +1,27 @@ +# SDK Internal Codegen + +## Overview +This folder contains the code generation logic for the SDK's internal components. It automates the creation of ConnectRPC wrapper clients, to ensure consistency and reduce manual effort. These clients have similar interfaces to the GRPC proto generated clients allowing for ease of transition to ConnectRPC client-side. + +--- + +## What It Generates +The code generation in this folder focuses on: +1. ConnectRPC wrapper clients for various platform services +2. Interfaces for each wrapper client + +The clients generated are defined in `clientsToGenerateList` in `runner/generate.go`. + +--- + +## How to Run Code Generation +To generate the internal SDK code: + +```bash +go run ./sdk/internal/codegen +``` + +Or use the provided Makefile command +```bash +make connect-wrapper-generate +``` \ No newline at end of file diff --git a/sdk/internal/codegen/runner/generate.go b/sdk/internal/codegen/runner/generate.go index 803ebf5e7e..d09bf1fea4 100644 --- a/sdk/internal/codegen/runner/generate.go +++ b/sdk/internal/codegen/runner/generate.go @@ -122,6 +122,9 @@ func Generate() error { currentDir, err = getCurrentFileDir() outputPath := filepath.Join(currentDir, "..", "..", "..", "sdkconnect", packageName+".go") err = os.WriteFile(outputPath, []byte(code), 0o644) //nolint:gosec // ignore G306 + if err != nil { + slog.Error("Error writing file", "file", outputPath, "error", err) + } found = true return false // stop traversal } @@ -148,7 +151,7 @@ func Generate() error { func getCurrentFileDir() (string, error) { _, filename, _, ok := runtime.Caller(0) if !ok { - return "", errors.New("could not get caller information") + return "", errors.New("could not get caller file (generate.go) working directory") } return filepath.Dir(filename), nil } diff --git a/sdk/nanotdf.go b/sdk/nanotdf.go index e1d5b6c5c0..910dc3f729 100644 --- a/sdk/nanotdf.go +++ b/sdk/nanotdf.go @@ -573,7 +573,7 @@ func (c *collectionStore) get(header []byte) ([]byte, bool) { return nil, false } -func (c *collectionStore) close() { //nolint:unused // leave for future use +func (c *collectionStore) close() { c.closeChan <- struct{}{} } diff --git a/sdk/sdk.go b/sdk/sdk.go index e4a7b6848c..15c18b73a4 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -258,8 +258,10 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { return ts, err } -// TODO: Remove after otdfctl updates func (s SDK) Close() error { + if s.collectionStore != nil { + s.collectionStore.close() + } return nil } diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 760b2db9bb..02d0b63558 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -297,6 +297,8 @@ func Start(f ...StartOptions) error { } } + defer client.Close() + logger.Info("starting services") gatewayCleanup, err := startServices(ctx, cfg, otdf, client, logger, svcRegistry) if err != nil { From ba0d53c5e1ec28cac1a783c70cdaf82315c1fef0 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Tue, 20 May 2025 15:29:09 -0400 Subject: [PATCH 28/31] address comments --- .gitattributes | 3 ++- Makefile | 2 +- examples/cmd/benchmark.go | 4 ++-- examples/cmd/benchmark_bulk.go | 4 ++-- sdk/audit/metadata_adding_interceptor_test.go | 18 ++++++------------ service/pkg/server/services.go | 8 ++++---- service/pkg/server/services_test.go | 7 +++---- service/pkg/server/start.go | 4 +--- service/pkg/server/start_test.go | 4 +--- 9 files changed, 22 insertions(+), 32 deletions(-) diff --git a/.gitattributes b/.gitattributes index 2669ff174a..db0d35156a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,5 @@ docs/grpc/** linguist-generated=true docs/openapi/** linguist-generated=true service/policy/db/*.sql.go linguist-generated=true -service/policy/db/models.go linguist-generated=true \ No newline at end of file +service/policy/db/models.go linguist-generated=true +sdk/sdkconnect/** linguist-generated=true \ No newline at end of file diff --git a/Makefile b/Makefile index 254b8f3c59..9867407efd 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ # make # To run all lint checks: `LINT_OPTIONS= make lint` -.PHONY: all build clean docker-build fix fmt go-lint license lint proto-generate connect-wrapper-generate proto-lint sdk/sdk test tidy toolcheck +.PHONY: all build clean connect-wrapper-generate docker-build fix fmt go-lint license lint proto-generate proto-lint sdk/sdk test tidy toolcheck MODS=protocol/go lib/ocrypto lib/fixtures lib/flattening lib/identifier sdk service examples HAND_MODS=lib/ocrypto lib/fixtures lib/flattening lib/identifier sdk service examples diff --git a/examples/cmd/benchmark.go b/examples/cmd/benchmark.go index 3bb95eab8e..106d0b2d23 100644 --- a/examples/cmd/benchmark.go +++ b/examples/cmd/benchmark.go @@ -106,7 +106,7 @@ func runBenchmark(cmd *cobra.Command, _ []string) error { return err } nanoTDFConfig.EnableECDSAPolicyBinding() - if insecurePlaintextConn || strings.Contains(platformEndpoint, "http://") { + if insecurePlaintextConn || strings.HasPrefix(platformEndpoint, "http://") { err = nanoTDFConfig.SetKasURL(fmt.Sprintf("http://%s/kas", "localhost:8080")) } else { err = nanoTDFConfig.SetKasURL(fmt.Sprintf("https://%s/kas", "localhost:8080")) @@ -128,7 +128,7 @@ func runBenchmark(cmd *cobra.Command, _ []string) error { // } } else { opts := []sdk.TDFOption{sdk.WithDataAttributes(dataAttributes...), sdk.WithAutoconfigure(false)} - if insecurePlaintextConn || strings.Contains(platformEndpoint, "http://") { + if insecurePlaintextConn || strings.HasPrefix(platformEndpoint, "http://") { opts = append(opts, sdk.WithKasInformation( sdk.KASInfo{ URL: "http://localhost:8080", diff --git a/examples/cmd/benchmark_bulk.go b/examples/cmd/benchmark_bulk.go index bcfd3dac79..b20a0c5456 100644 --- a/examples/cmd/benchmark_bulk.go +++ b/examples/cmd/benchmark_bulk.go @@ -62,7 +62,7 @@ func runBenchmarkBulk(cmd *cobra.Command, _ []string) error { } nanoTDFConfig.EnableECDSAPolicyBinding() // if plaintext or platform endpoint is http, set kas url to http, otherwise https - if insecurePlaintextConn || strings.Contains(platformEndpoint, "http://") { + if insecurePlaintextConn || strings.HasPrefix(platformEndpoint, "http://") { err = nanoTDFConfig.SetKasURL(fmt.Sprintf("http://%s/kas", "localhost:8080")) } else { err = nanoTDFConfig.SetKasURL(fmt.Sprintf("https://%s/kas", "localhost:8080")) @@ -84,7 +84,7 @@ func runBenchmarkBulk(cmd *cobra.Command, _ []string) error { } } else { opts := []sdk.TDFOption{sdk.WithDataAttributes(dataAttributes...), sdk.WithAutoconfigure(false)} - if insecurePlaintextConn || strings.Contains(platformEndpoint, "http://") { + if insecurePlaintextConn || strings.HasPrefix(platformEndpoint, "http://") { opts = append(opts, sdk.WithKasInformation( sdk.KASInfo{ URL: "http://localhost:8080", diff --git a/sdk/audit/metadata_adding_interceptor_test.go b/sdk/audit/metadata_adding_interceptor_test.go index 68756303ab..25c384cc8c 100644 --- a/sdk/audit/metadata_adding_interceptor_test.go +++ b/sdk/audit/metadata_adding_interceptor_test.go @@ -11,6 +11,8 @@ import ( "github.com/google/uuid" "github.com/opentdf/platform/protocol/go/kas" "github.com/opentdf/platform/protocol/go/kas/kasconnect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -57,13 +59,9 @@ func TestAddingAuditMetadataToOutgoingRequest(t *testing.T) { ctx = context.WithValue(ctx, ActorIDContextKey, contextActorID) _, err := clientConnect.PublicKey(ctx, connect.NewRequest(&kas.PublicKeyRequest{})) - if err != nil { - t.Fatalf("error making call: %v", err) - } + require.NoError(t, err) _, err = clientGrpc.PublicKey(ctx, &kas.PublicKeyRequest{}) - if err != nil { - t.Fatalf("error making call: %v", err) - } + require.NoError(t, err) for _, ids := range []struct { actorID string @@ -72,12 +70,8 @@ func TestAddingAuditMetadataToOutgoingRequest(t *testing.T) { {requestID: serverConnect.requestID, actorID: serverConnect.actorID}, {requestID: serverGrpc.requestID, actorID: serverGrpc.actorID}, } { - if ids.requestID != contextRequestID { - t.Fatalf("request ID did not match: %v", serverConnect.requestID) - } - if ids.requestID != contextRequestID { - t.Fatalf("request ID did not match: %v", serverGrpc.requestID) - } + assert.Equal(t, contextRequestID, ids.requestID, "request ID did not match") + assert.Equal(t, contextActorID, ids.actorID, "actor ID did not match") } } diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index ed4e5f904f..32df42aa28 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -164,7 +164,7 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF var err error svcDBClient, err = newServiceDBClient(ctx, cfg.Logger, cfg.DB, tracer, ns, svc.DBMigrations()) if err != nil { - return nil, err + return func() {}, err } } @@ -179,11 +179,11 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF Tracer: tracer, }) if err != nil { - return nil, err + return func() {}, err } if err := svc.RegisterConfigUpdateHook(ctx, cfg.AddOnConfigChangeHook); err != nil { - return nil, fmt.Errorf("failed to register config update hook: %w", err) + return func() {}, fmt.Errorf("failed to register config update hook: %w", err) } // Register Connect RPC Services @@ -203,7 +203,7 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF logger.Info("service did not register a grpc gateway handler", slog.String("namespace", ns)) } else if gatewayCleanup == nil { gatewayCleanup = func() { - slog.Info("executing cleanup") + slog.Debug("executing cleanup") if grpcConn != nil { grpcConn.Close() } diff --git a/service/pkg/server/services_test.go b/service/pkg/server/services_test.go index 0d85cb052c..62b80ac530 100644 --- a/service/pkg/server/services_test.go +++ b/service/pkg/server/services_test.go @@ -272,10 +272,9 @@ func (suite *ServiceTestSuite) TestStartServicesWithVariousCases() { "foobar": {}, }, }, otdf, nil, newLogger, registry) - if cleanup != nil { - // call cleanup function - defer cleanup() - } + + // call cleanup function + defer cleanup() suite.Require().NoError(err) // require.NotNil(t, cF) diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 02d0b63558..7e36b44399 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -305,9 +305,7 @@ func Start(f ...StartOptions) error { logger.Error("issue starting services", slog.String("error", err.Error())) return fmt.Errorf("issue starting services: %w", err) } - if gatewayCleanup != nil { - defer gatewayCleanup() - } + defer gatewayCleanup() // Start watching the configuration for changes with registered config change service hooks if err := cfg.Watch(ctx); err != nil { diff --git a/service/pkg/server/start_test.go b/service/pkg/server/start_test.go index 2bc4fc85a6..8db33168f2 100644 --- a/service/pkg/server/start_test.go +++ b/service/pkg/server/start_test.go @@ -263,9 +263,7 @@ func (suite *StartTestSuite) Test_Start_When_Extra_Service_Registered() { }, }, s, nil, logger, registry) require.NoError(t, err) - if cleanup != nil { - defer cleanup() - } + defer cleanup() require.NoError(t, s.Start()) defer s.Stop() From 222f11aa072c399f135f807959937e62569a422d Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Tue, 20 May 2025 15:36:16 -0400 Subject: [PATCH 29/31] codegen updates from main --- sdk/go.mod | 3 ++ sdk/go.sum | 3 +- sdk/internal/codegen/runner/generate.go | 6 +++ sdk/sdk.go | 4 +- sdk/sdkconnect/authorizationv2.go | 60 +++++++++++++++++++++++++ 5 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 sdk/sdkconnect/authorizationv2.go diff --git a/sdk/go.mod b/sdk/go.mod index db02052991..0addc797cd 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -16,6 +16,7 @@ require ( github.com/opentdf/platform/protocol/go v0.3.3 github.com/stretchr/testify v1.10.0 github.com/testcontainers/testcontainers-go v0.34.0 + github.com/xeipuuv/gojsonschema v1.2.0 golang.org/x/oauth2 v0.30.0 golang.org/x/tools v0.33.0 google.golang.org/grpc v1.72.1 @@ -76,6 +77,8 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 // indirect diff --git a/sdk/go.sum b/sdk/go.sum index e1a78524c5..42d229e54b 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -162,9 +162,8 @@ github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFA github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= -github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= -github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= diff --git a/sdk/internal/codegen/runner/generate.go b/sdk/internal/codegen/runner/generate.go index d09bf1fea4..1371a52486 100644 --- a/sdk/internal/codegen/runner/generate.go +++ b/sdk/internal/codegen/runner/generate.go @@ -33,6 +33,12 @@ var clientsToGenerateList = []clientsToGenerate{ grpcClientInterface: "AuthorizationServiceClient", grpcPackagePath: "github.com/opentdf/platform/protocol/go/authorization", }, + { + grpcClientInterface: "AuthorizationServiceClient", + suffix: "V2", + grpcPackagePath: "github.com/opentdf/platform/protocol/go/authorization/v2", + packageNameOverride: "authorizationv2", + }, { grpcClientInterface: "EntityResolutionServiceClient", grpcPackagePath: "github.com/opentdf/platform/protocol/go/entityresolution", diff --git a/sdk/sdk.go b/sdk/sdk.go index 025fe12994..44f4afde01 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -58,7 +58,7 @@ type SDK struct { Actions sdkconnect.ActionServiceClient Attributes sdkconnect.AttributesServiceClient Authorization sdkconnect.AuthorizationServiceClient - AuthorizationV2 sdkconnect.AuthorizationServiceClient + AuthorizationV2 sdkconnect.AuthorizationServiceClientV2 EntityResoution sdkconnect.EntityResolutionServiceClient EntityResolutionV2 sdkconnect.EntityResolutionServiceClientV2 KeyAccessServerRegistry sdkconnect.KeyAccessServerRegistryServiceClient @@ -194,7 +194,7 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { Unsafe: sdkconnect.NewUnsafeServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), KeyAccessServerRegistry: sdkconnect.NewKeyAccessServerRegistryServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), Authorization: sdkconnect.NewAuthorizationServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), - AuthorizationV2: sdkconnect.NewAuthorizationServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), + AuthorizationV2: sdkconnect.NewAuthorizationServiceClientV2ConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), EntityResoution: sdkconnect.NewEntityResolutionServiceClientConnectWrapper(ersConn.Client, ersConn.Endpoint, ersConn.Options...), EntityResolutionV2: sdkconnect.NewEntityResolutionServiceClientV2ConnectWrapper(ersConn.Client, ersConn.Endpoint, ersConn.Options...), KeyManagement: sdkconnect.NewKeyManagementServiceClientConnectWrapper(platformConn.Client, platformConn.Endpoint, platformConn.Options...), diff --git a/sdk/sdkconnect/authorizationv2.go b/sdk/sdkconnect/authorizationv2.go new file mode 100644 index 0000000000..80e5b1bce3 --- /dev/null +++ b/sdk/sdkconnect/authorizationv2.go @@ -0,0 +1,60 @@ +// Wrapper for AuthorizationServiceClient (generated code) DO NOT EDIT +package sdkconnect + +import ( + "connectrpc.com/connect" + "context" + "github.com/opentdf/platform/protocol/go/authorization/v2" + "github.com/opentdf/platform/protocol/go/authorization/v2/authorizationv2connect" +) + +type AuthorizationServiceClientV2ConnectWrapper struct { + authorizationv2connect.AuthorizationServiceClient +} + +func NewAuthorizationServiceClientV2ConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) *AuthorizationServiceClientV2ConnectWrapper { + return &AuthorizationServiceClientV2ConnectWrapper{AuthorizationServiceClient: authorizationv2connect.NewAuthorizationServiceClient(httpClient, baseURL, opts...)} +} + +type AuthorizationServiceClientV2 interface { + GetDecision(ctx context.Context, req *authorizationv2.GetDecisionRequest) (*authorizationv2.GetDecisionResponse, error) + GetDecisionMultiResource(ctx context.Context, req *authorizationv2.GetDecisionMultiResourceRequest) (*authorizationv2.GetDecisionMultiResourceResponse, error) + GetDecisionBulk(ctx context.Context, req *authorizationv2.GetDecisionBulkRequest) (*authorizationv2.GetDecisionBulkResponse, error) + GetEntitlements(ctx context.Context, req *authorizationv2.GetEntitlementsRequest) (*authorizationv2.GetEntitlementsResponse, error) +} + +func (w *AuthorizationServiceClientV2ConnectWrapper) GetDecision(ctx context.Context, req *authorizationv2.GetDecisionRequest) (*authorizationv2.GetDecisionResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetDecision(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AuthorizationServiceClientV2ConnectWrapper) GetDecisionMultiResource(ctx context.Context, req *authorizationv2.GetDecisionMultiResourceRequest) (*authorizationv2.GetDecisionMultiResourceResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetDecisionMultiResource(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AuthorizationServiceClientV2ConnectWrapper) GetDecisionBulk(ctx context.Context, req *authorizationv2.GetDecisionBulkRequest) (*authorizationv2.GetDecisionBulkResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetDecisionBulk(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} + +func (w *AuthorizationServiceClientV2ConnectWrapper) GetEntitlements(ctx context.Context, req *authorizationv2.GetEntitlementsRequest) (*authorizationv2.GetEntitlementsResponse, error) { + // Wrap Connect RPC client request + res, err := w.AuthorizationServiceClient.GetEntitlements(ctx, connect.NewRequest(req)) + if res == nil { + return nil, err + } + return res.Msg, err +} From 95aca4af410d91192e5c34ea490772a9880beb30 Mon Sep 17 00:00:00 2001 From: Elizabeth Healy Date: Tue, 20 May 2025 15:50:08 -0400 Subject: [PATCH 30/31] fix gateway cleanup for tests --- service/pkg/server/services.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index b0801f2765..e8afac538f 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -234,6 +234,9 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF } } + if gatewayCleanup == nil { + gatewayCleanup = func() {} + } return gatewayCleanup, nil } From c8fa1238264ef6d502e4a6bab60bd4ea8eb8bb65 Mon Sep 17 00:00:00 2001 From: Dave Mihalcik Date: Tue, 20 May 2025 16:15:38 -0400 Subject: [PATCH 31/31] sort .gitattributes --- .gitattributes | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitattributes b/.gitattributes index db0d35156a..2bf7db77c1 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,5 +1,5 @@ docs/grpc/** linguist-generated=true docs/openapi/** linguist-generated=true +sdk/sdkconnect/** linguist-generated=true service/policy/db/*.sql.go linguist-generated=true service/policy/db/models.go linguist-generated=true -sdk/sdkconnect/** linguist-generated=true \ No newline at end of file