Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/Configuring.md
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ server:

## Dot notation is used to access the groups claim
group_claim: "realm_access.roles"

# Dot notation is used to access the claim the represents the idP client ID
client_id_claim: # azp

## Deprecated: Use standard casbin policy groupings (g, <user/group>, <role>)
## Maps the external role to the OpenTDF role
Expand Down
2 changes: 2 additions & 0 deletions opentdf-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ server:
username_claim: # preferred_username
# That claim to access groups (i.e. realm_access.roles)
groups_claim: # realm_access.roles
# Claim the represents the idP client ID
client_id_claim: # azp
## Extends the builtin policy
extension: |
g, opentdf-admin, role:admin
Expand Down
2 changes: 2 additions & 0 deletions opentdf-ers-mode.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ server:
default: #"role:standard"
## Dot notation is used to access nested claims (i.e. realm_access.roles)
claim: # realm_access.roles
# Claim the represents the idP client ID
client_id_claim: # azp
## Maps the external role to the opentdf role
## Note: left side is used in the policy, right side is the external role
map:
Expand Down
2 changes: 2 additions & 0 deletions opentdf-example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ server:
username_claim: # preferred_username
# That claim to access groups (i.e. realm_access.roles)
groups_claim: # realm_access.roles
# Claim the represents the idP client ID
client_id_claim: # azp
## Extends the builtin policy
extension: |
g, opentdf-admin, role:admin
Expand Down
2 changes: 2 additions & 0 deletions opentdf-kas-mode.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ server:
default: #"role:standard"
## Dot notation is used to access nested claims (i.e. realm_access.roles)
claim: # realm_access.roles
# Claim the represents the idP client ID
client_id_claim: # azp
## Maps the external role to the opentdf role
## Note: left side is used in the policy, right side is the external role
map:
Expand Down
104 changes: 82 additions & 22 deletions service/internal/auth/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"google.golang.org/grpc/metadata"

sdkAudit "github.com/opentdf/platform/sdk/audit"
"github.com/opentdf/platform/service/logger"
Expand Down Expand Up @@ -62,6 +61,11 @@ var (
jwa.PS384: true,
jwa.PS512: true,
}

// Exported error variables for client ID processing
ErrClientIDClaimNotConfigured = errors.New("no client ID claim configured")
ErrClientIDClaimNotFound = errors.New("client ID claim not found")
ErrClientIDClaimNotString = errors.New("client ID claim is not a string")
)

const (
Expand Down Expand Up @@ -164,7 +168,7 @@ func NewAuthenticator(ctx context.Context, cfg Config, logger *logger.Logger, we

// Try an register oidc issuer to wellknown service but don't return an error if it fails
if err := wellknownRegistration("platform_issuer", cfg.Issuer); err != nil {
logger.Warn("failed to register platform issuer", slog.String("error", err.Error()))
logger.Warn("failed to register platform issuer", slog.Any("error", err))
}

var oidcConfigMap map[string]any
Expand All @@ -180,7 +184,7 @@ func NewAuthenticator(ctx context.Context, cfg Config, logger *logger.Logger, we
}

if err := wellknownRegistration("idp", oidcConfigMap); err != nil {
logger.Warn("failed to register platform idp information", slog.String("error", err.Error()))
logger.Warn("failed to register platform idp information", slog.Any("error", err))
}

return a, nil
Expand Down Expand Up @@ -212,6 +216,7 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
}

dp := r.Header.Values("Dpop")
log := a.logger

// Verify the token
header := r.Header["Authorization"]
Expand All @@ -228,12 +233,12 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
origin = "http://" + strings.TrimSuffix(origin, ":80")
}
}
accessTok, ctxWithJWK, err := a.checkToken(r.Context(), header, receiverInfo{
accessTok, ctx, err := a.checkToken(r.Context(), header, receiverInfo{
u: []string{normalizeURL(origin, r.URL)},
m: []string{r.Method},
}, dp)
if err != nil {
slog.WarnContext(r.Context(),
log.WarnContext(ctx,
"unauthenticated",
slog.Any("error", err),
slog.Any("dpop", dp),
Expand All @@ -242,12 +247,19 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
return
}

md, ok := metadata.FromIncomingContext(ctxWithJWK)
if !ok {
md = metadata.New(nil)
clientID, err := a.getClientIDFromToken(ctx, accessTok)
if err != nil {
log.WarnContext(
ctx,
"could not determine client ID from token",
slog.Any("err", err),
)
} else {
log = log.
With("client_id", clientID).
With("configured_client_id_claim_name", a.oidcConfiguration.Policy.ClientIDClaim)
ctx = ctxAuth.ContextWithAuthnMetadata(ctx, clientID)
}
md.Append("access_token", ctxAuth.GetRawAccessTokenFromContext(ctxWithJWK, nil))
ctxWithJWK = metadata.NewIncomingContext(ctxWithJWK, md)

// Check if the token is allowed to access the resource
var action string
Expand All @@ -263,7 +275,8 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
}
if allow, err := a.enforcer.Enforce(accessTok, r.URL.Path, action); err != nil {
if err.Error() == "permission denied" {
a.logger.WarnContext(r.Context(),
log.WarnContext(
ctx,
"permission denied",
slog.String("azp", accessTok.Subject()),
slog.Any("error", err),
Expand All @@ -274,12 +287,16 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
http.Error(w, "internal server error", http.StatusInternalServerError)
return
} else if !allow {
a.logger.WarnContext(r.Context(), "permission denied", slog.String("azp", accessTok.Subject()))
log.WarnContext(
ctx,
"permission denied",
slog.String("azp", accessTok.Subject()),
)
http.Error(w, "permission denied", http.StatusForbidden)
return
}

r = r.WithContext(ctxWithJWK)
r = r.WithContext(ctx)
handler.ServeHTTP(w, r)
})
}
Expand All @@ -296,6 +313,8 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor
return next(ctx, req)
}

log := a.logger

ri := receiverInfo{
u: []string{req.Spec().Procedure},
m: []string{http.MethodPost},
Expand All @@ -319,7 +338,7 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor
resource := p[1] + "/" + p[2]
action := getAction(p[2])

token, newCtx, err := a.checkToken(
token, ctxWithJWK, err := a.checkToken(
ctx,
header,
ri,
Expand All @@ -329,22 +348,38 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated"))
}

clientID, err := a.getClientIDFromToken(ctxWithJWK, token)
if err != nil {
log.WarnContext(
ctxWithJWK,
"could not determine client ID from token",
slog.Any("err", err),
)
} else {
log = log.
With("client_id", clientID).
With("configured_client_id_claim_name", a.oidcConfiguration.Policy.ClientIDClaim)
ctxWithJWK = ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID)
}

// Check if the token is allowed to access the resource
if allowed, err := a.enforcer.Enforce(token, resource, action); err != nil {
if err.Error() == "permission denied" {
a.logger.Warn("permission denied",
log.WarnContext(
ctxWithJWK,
"permission denied",
slog.String("azp", token.Subject()),
slog.Any("error", err),
)
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
}
return nil, err
} else if !allowed {
a.logger.Warn("permission denied", slog.String("azp", token.Subject()))
log.WarnContext(ctxWithJWK, "permission denied", slog.String("azp", token.Subject()))
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
}

return next(newCtx, req)
return next(ctxWithJWK, req)
})
}
return connect.UnaryInterceptorFunc(interceptor)
Expand Down Expand Up @@ -399,7 +434,7 @@ func (a *Authentication) checkToken(ctx context.Context, authHeader []string, dp
case strings.HasPrefix(authHeader[0], "Bearer "):
tokenRaw = strings.TrimPrefix(authHeader[0], "Bearer ")
default:
a.logger.Warn("failed to validate authentication header: not of type bearer or dpop", slog.String("header", authHeader[0]))
a.logger.WarnContext(ctx, "failed to validate authentication header: not of type bearer or dpop", slog.String("header", authHeader[0]))
return nil, nil, errors.New("not of type bearer or dpop")
}

Expand Down Expand Up @@ -431,12 +466,12 @@ func (a *Authentication) checkToken(ctx context.Context, authHeader []string, dp
ctx = ctxAuth.ContextWithAuthNInfo(ctx, nil, accessToken, tokenRaw)
return accessToken, ctx, nil
}
key, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader)
dpopKey, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader)
if err != nil {
a.logger.Warn("failed to validate dpop", slog.Any("err", err))
return nil, nil, err
}
ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, accessToken, tokenRaw)
ctx = ctxAuth.ContextWithAuthNInfo(ctx, dpopKey, accessToken, tokenRaw)
return accessToken, ctx, nil
}

Expand Down Expand Up @@ -668,7 +703,7 @@ func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header
u = append(u, a.lookupGatewayPaths(ctx, path, header)...)

// Validate the token and create a JWT token
_, nextCtx, err := a.checkToken(ctx, authHeader, receiverInfo{
token, ctxWithJWK, err := a.checkToken(ctx, authHeader, receiverInfo{
u: u,
m: []string{http.MethodPost},
}, header["Dpop"])
Expand All @@ -677,8 +712,33 @@ func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header
}

// Return the next context with the token
return nextCtx, nil
clientID, err := a.getClientIDFromToken(ctxWithJWK, token)
if err != nil {
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated"))
}
return ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID), nil
}
}
return ctx, nil
}

// getClientIDFromToken returns the client ID from the token if found (dot notation)
func (a *Authentication) getClientIDFromToken(ctx context.Context, tok jwt.Token) (string, error) {
clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim
if clientIDClaim == "" {
return "", ErrClientIDClaimNotConfigured
}
claimsMap, err := tok.AsMap(ctx)
if err != nil {
return "", fmt.Errorf("failed to parse token as a map and find claim at [%s]: %w", clientIDClaim, err)
}
found := dotNotation(claimsMap, clientIDClaim)
if found == nil {
return "", fmt.Errorf("%w at [%s]", ErrClientIDClaimNotFound, clientIDClaim)
}
clientID, isString := found.(string)
if !isString {
return "", fmt.Errorf("%w at [%s]", ErrClientIDClaimNotString, clientIDClaim)
}
return clientID, nil
}
Loading
Loading