|
| 1 | +package claims |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "log/slog" |
| 7 | + "strconv" |
| 8 | + |
| 9 | + "connectrpc.com/connect" |
| 10 | + "github.com/lestrrat-go/jwx/v2/jwt" |
| 11 | + "github.com/opentdf/platform/protocol/go/entity" |
| 12 | + entityresolutionV2 "github.com/opentdf/platform/protocol/go/entityresolution/v2" |
| 13 | + auth "github.com/opentdf/platform/service/authorization" |
| 14 | + "github.com/opentdf/platform/service/logger" |
| 15 | + "github.com/opentdf/platform/service/pkg/config" |
| 16 | + "github.com/opentdf/platform/service/pkg/serviceregistry" |
| 17 | + "go.opentelemetry.io/otel/trace" |
| 18 | + "google.golang.org/protobuf/encoding/protojson" |
| 19 | + "google.golang.org/protobuf/types/known/anypb" |
| 20 | + "google.golang.org/protobuf/types/known/structpb" |
| 21 | +) |
| 22 | + |
| 23 | +type EntityResolutionServiceV2 struct { |
| 24 | + entityresolutionV2.UnimplementedEntityResolutionServiceServer |
| 25 | + logger *logger.Logger |
| 26 | + trace.Tracer |
| 27 | +} |
| 28 | + |
| 29 | +func RegisterClaimsERS(_ config.ServiceConfig, logger *logger.Logger) (EntityResolutionServiceV2, serviceregistry.HandlerServer) { |
| 30 | + claimsSVC := EntityResolutionServiceV2{logger: logger} |
| 31 | + return claimsSVC, nil |
| 32 | +} |
| 33 | + |
| 34 | +func (s EntityResolutionServiceV2) ResolveEntities(ctx context.Context, req *connect.Request[entityresolutionV2.ResolveEntitiesRequest]) (*connect.Response[entityresolutionV2.ResolveEntitiesResponse], error) { |
| 35 | + resp, err := EntityResolution(ctx, req.Msg, s.logger) |
| 36 | + return connect.NewResponse(&resp), err |
| 37 | +} |
| 38 | + |
| 39 | +func (s EntityResolutionServiceV2) CreateEntityChainsFromTokens(ctx context.Context, req *connect.Request[entityresolutionV2.CreateEntityChainsFromTokensRequest]) (*connect.Response[entityresolutionV2.CreateEntityChainsFromTokensResponse], error) { |
| 40 | + ctx, span := s.Tracer.Start(ctx, "CreateEntityChainsFromTokens") |
| 41 | + defer span.End() |
| 42 | + |
| 43 | + resp, err := CreateEntityChainsFromTokens(ctx, req.Msg, s.logger) |
| 44 | + return connect.NewResponse(&resp), err |
| 45 | +} |
| 46 | + |
| 47 | +func CreateEntityChainsFromTokens( |
| 48 | + _ context.Context, |
| 49 | + req *entityresolutionV2.CreateEntityChainsFromTokensRequest, |
| 50 | + _ *logger.Logger, |
| 51 | +) (entityresolutionV2.CreateEntityChainsFromTokensResponse, error) { |
| 52 | + entityChains := []*entity.EntityChain{} |
| 53 | + // for each token in the tokens form an entity chain |
| 54 | + for _, tok := range req.GetTokens() { |
| 55 | + entities, err := getEntitiesFromToken(tok.GetJwt()) |
| 56 | + if err != nil { |
| 57 | + return entityresolutionV2.CreateEntityChainsFromTokensResponse{}, err |
| 58 | + } |
| 59 | + entityChains = append(entityChains, &entity.EntityChain{EphemeralId: tok.GetEphemeralId(), Entities: entities}) |
| 60 | + } |
| 61 | + |
| 62 | + return entityresolutionV2.CreateEntityChainsFromTokensResponse{EntityChains: entityChains}, nil |
| 63 | +} |
| 64 | + |
| 65 | +func EntityResolution(_ context.Context, |
| 66 | + req *entityresolutionV2.ResolveEntitiesRequest, logger *logger.Logger, |
| 67 | +) (entityresolutionV2.ResolveEntitiesResponse, error) { |
| 68 | + payload := req.GetEntities() |
| 69 | + var resolvedEntities []*entityresolutionV2.EntityRepresentation |
| 70 | + |
| 71 | + for idx, ident := range payload { |
| 72 | + entityStruct := &structpb.Struct{} |
| 73 | + switch ident.GetEntityType().(type) { |
| 74 | + case *entity.Entity_Claims: |
| 75 | + claims := ident.GetClaims() |
| 76 | + if claims != nil { |
| 77 | + err := claims.UnmarshalTo(entityStruct) |
| 78 | + if err != nil { |
| 79 | + return entityresolutionV2.ResolveEntitiesResponse{}, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error unpacking anypb.Any to structpb.Struct: %w", err)) |
| 80 | + } |
| 81 | + } |
| 82 | + default: |
| 83 | + retrievedStruct, err := entityToStructPb(ident) |
| 84 | + if err != nil { |
| 85 | + logger.Error("unable to make entity struct", slog.String("error", err.Error())) |
| 86 | + return entityresolutionV2.ResolveEntitiesResponse{}, connect.NewError(connect.CodeInternal, fmt.Errorf("unable to make entity struct: %w", err)) |
| 87 | + } |
| 88 | + entityStruct = retrievedStruct |
| 89 | + } |
| 90 | + // make sure the id field is populated |
| 91 | + originialID := ident.GetEphemeralId() |
| 92 | + if originialID == "" { |
| 93 | + originialID = auth.EntityIDPrefix + strconv.Itoa(idx) |
| 94 | + } |
| 95 | + resolvedEntities = append( |
| 96 | + resolvedEntities, |
| 97 | + &entityresolutionV2.EntityRepresentation{ |
| 98 | + OriginalId: originialID, |
| 99 | + AdditionalProps: []*structpb.Struct{entityStruct}, |
| 100 | + }, |
| 101 | + ) |
| 102 | + } |
| 103 | + return entityresolutionV2.ResolveEntitiesResponse{EntityRepresentations: resolvedEntities}, nil |
| 104 | +} |
| 105 | + |
| 106 | +func getEntitiesFromToken(jwtString string) ([]*entity.Entity, error) { |
| 107 | + token, err := jwt.ParseString(jwtString, jwt.WithVerify(false), jwt.WithValidate(false)) |
| 108 | + if err != nil { |
| 109 | + return nil, fmt.Errorf("error parsing jwt: %w", err) |
| 110 | + } |
| 111 | + |
| 112 | + claims := token.PrivateClaims() |
| 113 | + entities := []*entity.Entity{} |
| 114 | + |
| 115 | + // Convert map[string]interface{} to *structpb.Struct |
| 116 | + structClaims, err := structpb.NewStruct(claims) |
| 117 | + if err != nil { |
| 118 | + return nil, fmt.Errorf("error converting to structpb.Struct: %w", err) |
| 119 | + } |
| 120 | + |
| 121 | + // Wrap the struct in an *anypb.Any message |
| 122 | + anyClaims, err := anypb.New(structClaims) |
| 123 | + if err != nil { |
| 124 | + return nil, fmt.Errorf("error wrapping in anypb.Any: %w", err) |
| 125 | + } |
| 126 | + |
| 127 | + entities = append(entities, &entity.Entity{ |
| 128 | + EntityType: &entity.Entity_Claims{Claims: anyClaims}, |
| 129 | + EphemeralId: "jwtentity-claims", |
| 130 | + Category: entity.Entity_CATEGORY_SUBJECT, |
| 131 | + }) |
| 132 | + return entities, nil |
| 133 | +} |
| 134 | + |
| 135 | +func entityToStructPb(ident *entity.Entity) (*structpb.Struct, error) { |
| 136 | + entityBytes, err := protojson.Marshal(ident) |
| 137 | + if err != nil { |
| 138 | + return nil, err |
| 139 | + } |
| 140 | + var entityStruct structpb.Struct |
| 141 | + err = entityStruct.UnmarshalJSON(entityBytes) |
| 142 | + if err != nil { |
| 143 | + return nil, err |
| 144 | + } |
| 145 | + return &entityStruct, nil |
| 146 | +} |
0 commit comments