diff --git a/docs/Configuring.md b/docs/Configuring.md index 65ab172cc0..3167fa14c9 100644 --- a/docs/Configuring.md +++ b/docs/Configuring.md @@ -352,6 +352,9 @@ server: ## Dot notation is used to access the groups claim group_claim: "realm_access.roles" + + # Dot notation is used to access the claim the represents the idP client ID + client_id_claim: # azp ## Deprecated: Use standard casbin policy groupings (g, , ) ## Maps the external role to the OpenTDF role diff --git a/opentdf-dev.yaml b/opentdf-dev.yaml index a0a565af62..3361bd6ec4 100644 --- a/opentdf-dev.yaml +++ b/opentdf-dev.yaml @@ -72,6 +72,8 @@ server: username_claim: # preferred_username # That claim to access groups (i.e. realm_access.roles) groups_claim: # realm_access.roles + # Claim the represents the idP client ID + client_id_claim: # azp ## Extends the builtin policy extension: | g, opentdf-admin, role:admin diff --git a/opentdf-ers-mode.yaml b/opentdf-ers-mode.yaml index 1b0e5f3f7e..a396b963a8 100644 --- a/opentdf-ers-mode.yaml +++ b/opentdf-ers-mode.yaml @@ -28,6 +28,8 @@ server: default: #"role:standard" ## Dot notation is used to access nested claims (i.e. realm_access.roles) claim: # realm_access.roles + # Claim the represents the idP client ID + client_id_claim: # azp ## Maps the external role to the opentdf role ## Note: left side is used in the policy, right side is the external role map: diff --git a/opentdf-example.yaml b/opentdf-example.yaml index 3c012632ba..c110295121 100644 --- a/opentdf-example.yaml +++ b/opentdf-example.yaml @@ -53,6 +53,8 @@ server: username_claim: # preferred_username # That claim to access groups (i.e. realm_access.roles) groups_claim: # realm_access.roles + # Claim the represents the idP client ID + client_id_claim: # azp ## Extends the builtin policy extension: | g, opentdf-admin, role:admin diff --git a/opentdf-kas-mode.yaml b/opentdf-kas-mode.yaml index e7532d4e63..cbfaee1f06 100644 --- a/opentdf-kas-mode.yaml +++ b/opentdf-kas-mode.yaml @@ -45,6 +45,8 @@ server: default: #"role:standard" ## Dot notation is used to access nested claims (i.e. realm_access.roles) claim: # realm_access.roles + # Claim the represents the idP client ID + client_id_claim: # azp ## Maps the external role to the opentdf role ## Note: left side is used in the policy, right side is the external role map: diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index 5a26aa27e5..cc2e4151eb 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -22,7 +22,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" - "google.golang.org/grpc/metadata" sdkAudit "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/service/logger" @@ -62,6 +61,11 @@ var ( jwa.PS384: true, jwa.PS512: true, } + + // Exported error variables for client ID processing + ErrClientIDClaimNotConfigured = errors.New("no client ID claim configured") + ErrClientIDClaimNotFound = errors.New("client ID claim not found") + ErrClientIDClaimNotString = errors.New("client ID claim is not a string") ) const ( @@ -164,7 +168,7 @@ func NewAuthenticator(ctx context.Context, cfg Config, logger *logger.Logger, we // Try an register oidc issuer to wellknown service but don't return an error if it fails if err := wellknownRegistration("platform_issuer", cfg.Issuer); err != nil { - logger.Warn("failed to register platform issuer", slog.String("error", err.Error())) + logger.Warn("failed to register platform issuer", slog.Any("error", err)) } var oidcConfigMap map[string]any @@ -180,7 +184,7 @@ func NewAuthenticator(ctx context.Context, cfg Config, logger *logger.Logger, we } if err := wellknownRegistration("idp", oidcConfigMap); err != nil { - logger.Warn("failed to register platform idp information", slog.String("error", err.Error())) + logger.Warn("failed to register platform idp information", slog.Any("error", err)) } return a, nil @@ -212,6 +216,7 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { } dp := r.Header.Values("Dpop") + log := a.logger // Verify the token header := r.Header["Authorization"] @@ -228,12 +233,12 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { origin = "http://" + strings.TrimSuffix(origin, ":80") } } - accessTok, ctxWithJWK, err := a.checkToken(r.Context(), header, receiverInfo{ + accessTok, ctx, err := a.checkToken(r.Context(), header, receiverInfo{ u: []string{normalizeURL(origin, r.URL)}, m: []string{r.Method}, }, dp) if err != nil { - slog.WarnContext(r.Context(), + log.WarnContext(ctx, "unauthenticated", slog.Any("error", err), slog.Any("dpop", dp), @@ -242,12 +247,19 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { return } - md, ok := metadata.FromIncomingContext(ctxWithJWK) - if !ok { - md = metadata.New(nil) + clientID, err := a.getClientIDFromToken(ctx, accessTok) + if err != nil { + log.WarnContext( + ctx, + "could not determine client ID from token", + slog.Any("err", err), + ) + } else { + log = log. + With("client_id", clientID). + With("configured_client_id_claim_name", a.oidcConfiguration.Policy.ClientIDClaim) + ctx = ctxAuth.ContextWithAuthnMetadata(ctx, clientID) } - md.Append("access_token", ctxAuth.GetRawAccessTokenFromContext(ctxWithJWK, nil)) - ctxWithJWK = metadata.NewIncomingContext(ctxWithJWK, md) // Check if the token is allowed to access the resource var action string @@ -263,7 +275,8 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { } if allow, err := a.enforcer.Enforce(accessTok, r.URL.Path, action); err != nil { if err.Error() == "permission denied" { - a.logger.WarnContext(r.Context(), + log.WarnContext( + ctx, "permission denied", slog.String("azp", accessTok.Subject()), slog.Any("error", err), @@ -274,12 +287,16 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { http.Error(w, "internal server error", http.StatusInternalServerError) return } else if !allow { - a.logger.WarnContext(r.Context(), "permission denied", slog.String("azp", accessTok.Subject())) + log.WarnContext( + ctx, + "permission denied", + slog.String("azp", accessTok.Subject()), + ) http.Error(w, "permission denied", http.StatusForbidden) return } - r = r.WithContext(ctxWithJWK) + r = r.WithContext(ctx) handler.ServeHTTP(w, r) }) } @@ -296,6 +313,8 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor return next(ctx, req) } + log := a.logger + ri := receiverInfo{ u: []string{req.Spec().Procedure}, m: []string{http.MethodPost}, @@ -319,7 +338,7 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor resource := p[1] + "/" + p[2] action := getAction(p[2]) - token, newCtx, err := a.checkToken( + token, ctxWithJWK, err := a.checkToken( ctx, header, ri, @@ -329,10 +348,26 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated")) } + clientID, err := a.getClientIDFromToken(ctxWithJWK, token) + if err != nil { + log.WarnContext( + ctxWithJWK, + "could not determine client ID from token", + slog.Any("err", err), + ) + } else { + log = log. + With("client_id", clientID). + With("configured_client_id_claim_name", a.oidcConfiguration.Policy.ClientIDClaim) + ctxWithJWK = ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID) + } + // Check if the token is allowed to access the resource if allowed, err := a.enforcer.Enforce(token, resource, action); err != nil { if err.Error() == "permission denied" { - a.logger.Warn("permission denied", + log.WarnContext( + ctxWithJWK, + "permission denied", slog.String("azp", token.Subject()), slog.Any("error", err), ) @@ -340,11 +375,11 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor } return nil, err } else if !allowed { - a.logger.Warn("permission denied", slog.String("azp", token.Subject())) + log.WarnContext(ctxWithJWK, "permission denied", slog.String("azp", token.Subject())) return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied")) } - return next(newCtx, req) + return next(ctxWithJWK, req) }) } return connect.UnaryInterceptorFunc(interceptor) @@ -399,7 +434,7 @@ func (a *Authentication) checkToken(ctx context.Context, authHeader []string, dp case strings.HasPrefix(authHeader[0], "Bearer "): tokenRaw = strings.TrimPrefix(authHeader[0], "Bearer ") default: - a.logger.Warn("failed to validate authentication header: not of type bearer or dpop", slog.String("header", authHeader[0])) + a.logger.WarnContext(ctx, "failed to validate authentication header: not of type bearer or dpop", slog.String("header", authHeader[0])) return nil, nil, errors.New("not of type bearer or dpop") } @@ -431,12 +466,12 @@ func (a *Authentication) checkToken(ctx context.Context, authHeader []string, dp ctx = ctxAuth.ContextWithAuthNInfo(ctx, nil, accessToken, tokenRaw) return accessToken, ctx, nil } - key, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader) + dpopKey, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader) if err != nil { a.logger.Warn("failed to validate dpop", slog.Any("err", err)) return nil, nil, err } - ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, accessToken, tokenRaw) + ctx = ctxAuth.ContextWithAuthNInfo(ctx, dpopKey, accessToken, tokenRaw) return accessToken, ctx, nil } @@ -668,7 +703,7 @@ func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header u = append(u, a.lookupGatewayPaths(ctx, path, header)...) // Validate the token and create a JWT token - _, nextCtx, err := a.checkToken(ctx, authHeader, receiverInfo{ + token, ctxWithJWK, err := a.checkToken(ctx, authHeader, receiverInfo{ u: u, m: []string{http.MethodPost}, }, header["Dpop"]) @@ -677,8 +712,33 @@ func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header } // Return the next context with the token - return nextCtx, nil + clientID, err := a.getClientIDFromToken(ctxWithJWK, token) + if err != nil { + return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated")) + } + return ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID), nil } } return ctx, nil } + +// getClientIDFromToken returns the client ID from the token if found (dot notation) +func (a *Authentication) getClientIDFromToken(ctx context.Context, tok jwt.Token) (string, error) { + clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim + if clientIDClaim == "" { + return "", ErrClientIDClaimNotConfigured + } + claimsMap, err := tok.AsMap(ctx) + if err != nil { + return "", fmt.Errorf("failed to parse token as a map and find claim at [%s]: %w", clientIDClaim, err) + } + found := dotNotation(claimsMap, clientIDClaim) + if found == nil { + return "", fmt.Errorf("%w at [%s]", ErrClientIDClaimNotFound, clientIDClaim) + } + clientID, isString := found.(string) + if !isString { + return "", fmt.Errorf("%w at [%s]", ErrClientIDClaimNotString, clientIDClaim) + } + return clientID, nil +} diff --git a/service/internal/auth/authn_test.go b/service/internal/auth/authn_test.go index 686bc8ad7d..0883703025 100644 --- a/service/internal/auth/authn_test.go +++ b/service/internal/auth/authn_test.go @@ -39,6 +39,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -56,6 +57,7 @@ type FakeAccessTokenSource struct { } type FakeAccessServiceServer struct { + clientID string accessToken []string dpopKey jwk.Key kas.UnimplementedAccessServiceServer @@ -72,6 +74,7 @@ func (f *FakeAccessServiceServer) LegacyPublicKey(_ context.Context, _ *connect. func (f *FakeAccessServiceServer) Rewrap(ctx context.Context, req *connect.Request[kas.RewrapRequest]) (*connect.Response[kas.RewrapResponse], error) { f.accessToken = req.Header()["Authorization"] f.dpopKey = ctxAuth.GetJWKFromContext(ctx, logger.CreateTestLogger()) + f.clientID, _ = ctxAuth.GetClientIDFromContext(ctx) return &connect.Response[kas.RewrapResponse]{Msg: &kas.RewrapResponse{}}, nil } @@ -148,7 +151,9 @@ func (s *AuthSuite) SetupTest() { } })) - policyCfg := PolicyConfig{} + policyCfg := PolicyConfig{ + ClientIDClaim: "cid", + } err = defaults.Set(&policyCfg) s.Require().NoError(err) @@ -175,9 +180,7 @@ func (s *AuthSuite) SetupTest() { "/static-doublestar4/x/**", }, }, - &logger.Logger{ - Logger: slog.New(slog.Default().Handler()), - }, + logger.CreateTestLogger(), func(_ string, _ any) error { return nil }, ) @@ -214,6 +217,9 @@ func TestNormalizeUrl(t *testing.T) { func (s *AuthSuite) Test_IPCUnaryServerInterceptor() { // Mock the checkToken method to return a valid token and context mockToken := jwt.New() + err := mockToken.Set("cid", "mockClientID") + s.Require().NoError(err) + type contextKey string mockCtx := context.WithValue(context.Background(), contextKey("mockKey"), "mockValue") s.auth._testCheckTokenFunc = func(_ context.Context, authHeader []string, _ receiverInfo, _ []string) (jwt.Token, context.Context, error) { @@ -234,6 +240,9 @@ func (s *AuthSuite) Test_IPCUnaryServerInterceptor() { s.Require().NoError(err) s.Require().NotNil(nextCtx) s.Equal("mockValue", nextCtx.Value(contextKey("mockKey"))) + clientID, err := ctxAuth.GetClientIDFromContext(nextCtx) + s.Require().NoError(err) + s.Equal("mockClientID", clientID) // Test with a route not requiring reauthorization nextCtx, err = s.auth.ipcReauthCheck(context.Background(), "/kas.AccessService/PublicKey", nil) @@ -254,6 +263,61 @@ func (s *AuthSuite) Test_IPCUnaryServerInterceptor() { s.Contains(err.Error(), "unauthenticated") } +func (s *AuthSuite) Test_ConnectUnaryServerInterceptor_ClientIDPropagated() { + tok := jwt.New() + s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour))) + s.Require().NoError(tok.Set("iss", s.server.URL)) + s.Require().NoError(tok.Set("aud", "test")) + // default client ID claim in policy config is 'azp' + s.Require().NoError(tok.Set("azp", "test-client-id")) + s.Require().NoError(tok.Set("realm_access", map[string][]string{"roles": {"opentdf-standard"}})) + + policyCfg := new(PolicyConfig) + err := defaults.Set(policyCfg) + s.Require().NoError(err) + + authnConfig := AuthNConfig{ + Issuer: s.server.URL, + Audience: "test", + Policy: *policyCfg, + } + config := Config{ + AuthNConfig: authnConfig, + } + auth, err := NewAuthenticator(s.T().Context(), config, logger.CreateTestLogger(), func(_ string, _ any) error { return nil }) + s.Require().NoError(err) + + // Sign the token + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + s.Require().NoError(err) + + // Create a minimal connect server setup to properly test the interceptor + // This is necessary because connect requests need proper procedure routing + interceptor := connect.WithInterceptors(auth.ConnectUnaryServerInterceptor()) + + fakeServer := &FakeAccessServiceServer{} + mux := http.NewServeMux() + path, handler := kasconnect.NewAccessServiceHandler(fakeServer, interceptor) + mux.Handle(path, handler) + + server := memhttp.New(mux) + defer server.Close() + + // Create a connect client that sends a Bearer token + conn, _ := grpc.NewClient("passthrough://bufconn", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + return server.Listener.DialContext(ctx, "tcp", "http://localhost:8080") + }), grpc.WithTransportCredentials(insecure.NewCredentials())) + + client := kas.NewAccessServiceClient(conn) + + // Make the request + _, err = client.Rewrap(metadata.AppendToOutgoingContext(s.T().Context(), "authorization", "Bearer "+string(signedTok)), &kas.RewrapRequest{}) + s.Require().NoError(err) + + // Assert that the client ID was properly extracted and set in the context + s.Equal("test-client-id", fakeServer.clientID) +} + func (s *AuthSuite) Test_CheckToken_When_JWT_Expired_Expect_Error() { tok := jwt.New() s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC))) @@ -482,7 +546,7 @@ func (s *AuthSuite) TestDPoPEndToEnd_GRPC() { s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour))) s.Require().NoError(tok.Set("iss", s.server.URL)) s.Require().NoError(tok.Set("aud", "test")) - s.Require().NoError(tok.Set("cid", "client2")) + s.Require().NoError(tok.Set("cid", "client-123")) s.Require().NoError(tok.Set("realm_access", map[string][]string{"roles": {"opentdf-standard"}})) thumbprint, err := dpopKey.Thumbprint(crypto.SHA256) s.Require().NoError(err) @@ -517,6 +581,10 @@ func (s *AuthSuite) TestDPoPEndToEnd_GRPC() { _, err = client.Rewrap(context.Background(), &kas.RewrapRequest{}) s.Require().NoError(err) + + // interceptor propagated clientID from the token at the configured claim + s.Equal("client-123", fakeServer.clientID) + s.NotNil(fakeServer.dpopKey) dpopJWKFromRequest, ok := fakeServer.dpopKey.(jwk.RSAPublicKey) s.True(ok) @@ -552,12 +620,15 @@ func (s *AuthSuite) TestDPoPEndToEnd_HTTP() { jwkChan := make(chan jwk.Key, 1) timeout := make(chan string, 1) + clientIDChan := make(chan string, 1) go func() { time.Sleep(5 * time.Second) timeout <- "" }() server := httptest.NewServer(s.auth.MuxHandler(http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { jwkChan <- ctxAuth.GetJWKFromContext(req.Context(), logger.CreateTestLogger()) + cid, _ := ctxAuth.GetClientIDFromContext(req.Context()) + clientIDChan <- cid }))) defer server.Close() @@ -585,6 +656,15 @@ func (s *AuthSuite) TestDPoPEndToEnd_HTTP() { case <-timeout: s.Require().FailNow("timed out waiting for call to complete") } + var clientID string + select { + case cid := <-clientIDChan: + clientID = cid + case <-timeout: + s.Require().FailNow("timed out waiting for call to complete") + } + + s.Equal("client2", clientID) s.NotNil(dpopKeyFromRequest) dpopJWKFromRequest, ok := dpopKeyFromRequest.(jwk.RSAPublicKey) @@ -793,3 +873,123 @@ func (s *AuthSuite) Test_LookupGatewayPaths() { }) } } + +func Test_GetClientIDFromToken(t *testing.T) { + tests := []struct { + name string + claims map[string]interface{} + clientIDClaim string + expectedClientID string + expectedErr error + expectError bool + }{ + { + name: "Happy Path - simple claim", + claims: map[string]interface{}{ + "cid": "test-client-id", + }, + clientIDClaim: "cid", + expectedClientID: "test-client-id", + expectError: false, + }, + { + name: "Happy Path - different claim name", + claims: map[string]interface{}{ + "client": "test-client-id", + }, + clientIDClaim: "client", + expectedClientID: "test-client-id", + expectError: false, + }, + { + name: "Happy Path - dot notation", + claims: map[string]interface{}{ + "client": map[string]interface{}{ + "info": map[string]interface{}{ + "id": "test-client-id", + }, + }, + }, + clientIDClaim: "client.info.id", + expectedClientID: "test-client-id", + expectError: false, + }, + { + name: "Error - no client ID claim configured", + claims: map[string]interface{}{"cid": "test"}, + clientIDClaim: "", // empty claim name + expectedClientID: "", + expectedErr: ErrClientIDClaimNotConfigured, + expectError: true, + }, + { + name: "Error - claim not found", + claims: map[string]interface{}{ + "other-claim": "some-value", + }, + clientIDClaim: "cid", + expectedClientID: "", + expectedErr: ErrClientIDClaimNotFound, + expectError: true, + }, + { + name: "Error - claim is not a string (int)", + claims: map[string]interface{}{ + "cid": 12345, + }, + clientIDClaim: "cid", + expectedClientID: "", + expectedErr: ErrClientIDClaimNotString, + expectError: true, + }, + { + name: "Error - claim is not a string (bool)", + claims: map[string]interface{}{ + "cid": true, + }, + clientIDClaim: "cid", + expectedClientID: "", + expectedErr: ErrClientIDClaimNotString, + expectError: true, + }, + { + name: "Error - claim is not a string (object)", + claims: map[string]interface{}{ + "cid": map[string]interface{}{"nested": "value"}, + }, + clientIDClaim: "cid", + expectedClientID: "", + expectedErr: ErrClientIDClaimNotString, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &Authentication{ + oidcConfiguration: AuthNConfig{ + Policy: PolicyConfig{ + ClientIDClaim: tt.clientIDClaim, + }, + }, + } + + tok := jwt.New() + for k, v := range tt.claims { + err := tok.Set(k, v) + require.NoError(t, err) + } + + clientID, err := auth.getClientIDFromToken(t.Context(), tok) + + assert.Equal(t, tt.expectedClientID, clientID) + + if tt.expectError { + require.Error(t, err) + assert.ErrorIs(t, err, tt.expectedErr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/service/internal/auth/config.go b/service/internal/auth/config.go index 7fe8fd9a76..5e48877cf1 100644 --- a/service/internal/auth/config.go +++ b/service/internal/auth/config.go @@ -34,6 +34,8 @@ type PolicyConfig struct { UserNameClaim string `mapstructure:"username_claim" json:"username_claim" default:"preferred_username"` // Claim to use for group/role information GroupsClaim string `mapstructure:"groups_claim" json:"groups_claim" default:"realm_access.roles"` + // Claim to use to reference idP clientID + ClientIDClaim string `mapstructure:"client_id_claim" json:"client_id_claim" default:"azp"` // Deprecated: Use GroupClain instead RoleClaim string `mapstructure:"claim" json:"claim" default:"realm_access.roles"` // Deprecated: Use Casbin grouping statements g, , diff --git a/service/pkg/auth/context_auth.go b/service/pkg/auth/context_auth.go index 20f5b62d8d..f4d857a777 100644 --- a/service/pkg/auth/context_auth.go +++ b/service/pkg/auth/context_auth.go @@ -2,13 +2,25 @@ package auth import ( "context" + "errors" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/service/logger" + "google.golang.org/grpc/metadata" ) -var authnContextKey = authContextKey{} +var ( + authnContextKey = authContextKey{} + ErrNoMetadataFound = errors.New("no metadata found within context") + ErrMissingClientID = errors.New("context metadata missing authn idP clientID that should have been set by interceptor") + ErrConflictClientID = errors.New("context metadata has more than one authn idP clientID and should only ever have one") +) + +const ( + accessTokenKey = "access_token" + clientIDKey = "client_id" +) type authContextKey struct{} @@ -60,3 +72,46 @@ func GetRawAccessTokenFromContext(ctx context.Context, l *logger.Logger) string } return "" } + +// ContextWithAuthnMetadata adds the access token and client ID to context metadata +// +// Adding the authn into to gRPC metadata propagates it across services rather than strictly +// in-process within Go alone +func ContextWithAuthnMetadata(ctx context.Context, clientID string) context.Context { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + md = metadata.New(nil) + } else { + // Do not modify original metadata from parent context + md = md.Copy() + } + + if rawToken := GetRawAccessTokenFromContext(ctx, nil); rawToken != "" { + md.Set(accessTokenKey, rawToken) + } + + // Add client ID to metadata for downstream services + if clientID != "" { + md.Set(clientIDKey, clientID) + } + + return metadata.NewIncomingContext(ctx, md) +} + +// GetClientIDFromContext retrieves the client ID from the metadata in the context +func GetClientIDFromContext(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", ErrNoMetadataFound + } + + clientIDs := md.Get(clientIDKey) + if len(clientIDs) == 0 { + return "", ErrMissingClientID + } + if len(clientIDs) > 1 { + return "", ErrConflictClientID + } + + return clientIDs[0], nil +} diff --git a/service/pkg/auth/context_auth_test.go b/service/pkg/auth/context_auth_test.go index c015fe8270..75b7795e9e 100644 --- a/service/pkg/auth/context_auth_test.go +++ b/service/pkg/auth/context_auth_test.go @@ -8,6 +8,8 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/service/logger" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" ) func TestContextWithAuthNInfo(t *testing.T) { @@ -70,3 +72,90 @@ func TestGetContextDetailsInvalidType(t *testing.T) { retrievedJWK := GetJWKFromContext(ctx, logger.CreateTestLogger()) assert.Nil(t, retrievedJWK, "JWK should be nil when context value is invalid") } + +func TestContextWithAuthnMetadata(t *testing.T) { + mockClientID := "test-client-id" + + t.Run("should add access token and client id to metadata", func(t *testing.T) { + ctx := ContextWithAuthNInfo(t.Context(), nil, nil, "raw-token-string") + enrichedCtx := ContextWithAuthnMetadata(ctx, mockClientID) + + md, ok := metadata.FromIncomingContext(enrichedCtx) + require.True(t, ok) + + accessToken := md.Get("access_token") + require.Len(t, accessToken, 1) + assert.Equal(t, "raw-token-string", accessToken[0]) + + clientIDs := md.Get(clientIDKey) + require.Len(t, clientIDs, 1) + assert.Equal(t, mockClientID, clientIDs[0]) + }) + + t.Run("should not set client id if empty", func(t *testing.T) { + ctx := ContextWithAuthNInfo(t.Context(), nil, nil, "raw-token-string") + enrichedCtx := ContextWithAuthnMetadata(ctx, "") + + md, ok := metadata.FromIncomingContext(enrichedCtx) + require.True(t, ok) + + clientIDs := md.Get(clientIDKey) + assert.Empty(t, clientIDs) + }) + + t.Run("should preserve existing metadata", func(t *testing.T) { + originalMD := metadata.New(map[string]string{"original-key": "original-value"}) + ctx := metadata.NewIncomingContext(t.Context(), originalMD) + ctx = ContextWithAuthNInfo(ctx, nil, nil, "raw-token-string") + + enrichedCtx := ContextWithAuthnMetadata(ctx, mockClientID) + + md, ok := metadata.FromIncomingContext(enrichedCtx) + require.True(t, ok) + + originalValue := md.Get("original-key") + require.Len(t, originalValue, 1) + assert.Equal(t, "original-value", originalValue[0]) + + clientIDs := md.Get(clientIDKey) + require.Len(t, clientIDs, 1) + assert.Equal(t, mockClientID, clientIDs[0]) + }) +} + +func TestGetClientIDFromContext(t *testing.T) { + mockClientID := "test-client-id" + + t.Run("good - should retrieve client id from context", func(t *testing.T) { + md := metadata.New(map[string]string{clientIDKey: mockClientID}) + ctx := metadata.NewIncomingContext(t.Context(), md) + + clientID, err := GetClientIDFromContext(ctx) + require.NoError(t, err) + assert.Equal(t, mockClientID, clientID) + }) + + t.Run("bad - should return error if client_id key is not present", func(t *testing.T) { + md := metadata.New(map[string]string{"other-key": "other-value"}) + ctx := metadata.NewIncomingContext(t.Context(), md) + + _, err := GetClientIDFromContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, ErrMissingClientID) + }) + + t.Run("bad - should return error if no metadata in context", func(t *testing.T) { + _, err := GetClientIDFromContext(t.Context()) + require.Error(t, err) + require.ErrorIs(t, err, ErrNoMetadataFound) + }) + + t.Run("bad - should return error if more than one metadata client_id key in context", func(t *testing.T) { + md := metadata.Pairs(clientIDKey, "id-1", clientIDKey, "id-2") + ctx := metadata.NewIncomingContext(t.Context(), md) + + _, err := GetClientIDFromContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, ErrConflictClientID) + }) +}