Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions service/authorization/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ func (as *AuthorizationService) getDecisions(ctx context.Context, dr *authorizat
}
return response, nil
}
return nil, db.StatusifyError(err, db.ErrTextGetRetrievalFailed, slog.String("fqns", strings.Join(allPertinentFQNS.GetAttributeValueFqns(), ", ")))
return nil, db.StatusifyError(ctx, as.logger, err, db.ErrTextGetRetrievalFailed, slog.String("fqns", strings.Join(allPertinentFQNS.GetAttributeValueFqns(), ", ")))
}

var allAttrDefs []*policy.Attribute
Expand Down Expand Up @@ -579,7 +579,7 @@ func (as *AuthorizationService) getDecisions(ctx context.Context, dr *authorizat
ecEntitlements, err := as.GetEntitlements(ctx, &req)
if err != nil {
// TODO: should all decisions in a request fail if one entity entitlement lookup fails?
return nil, db.StatusifyError(err, db.ErrTextGetRetrievalFailed, slog.String("extra", "getEntitlements request failed"))
return nil, db.StatusifyError(ctx, as.logger, err, db.ErrTextGetRetrievalFailed, slog.String("extra", "getEntitlements request failed"))
}
ecChainEntitlementsResponse = append(ecChainEntitlementsResponse, ecEntitlements)
}
Expand Down Expand Up @@ -662,7 +662,7 @@ func (as *AuthorizationService) getDecisions(ctx context.Context, dr *authorizat
)
if err != nil {
// TODO: should all decisions in a request fail if one entity entitlement lookup fails?
return nil, db.StatusifyError(errors.New("could not determine access"), "could not determine access", slog.String("error", err.Error()))
return nil, db.StatusifyError(ctx, as.logger, errors.New("could not determine access"), "could not determine access", slog.String("error", err.Error()))
}
// check the decisions
decision = authorization.DecisionResponse_DECISION_PERMIT
Expand Down
46 changes: 46 additions & 0 deletions service/logger/contextHandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package logger

import (
"context"
"log/slog"

"github.com/google/uuid"
sdkAudit "github.com/opentdf/platform/sdk/audit"
"github.com/opentdf/platform/service/logger/audit"
)

// ContextHandler is a custom slog.Handler that adds context attributes to log records from values set to the
// context by the RPC interceptor. It is used to enrich log records with request-specific metadata such as
// request ID, user agent, request IP, and actor ID.
type ContextHandler struct {
handler slog.Handler
}

// Handle overrides the default Handle method to add context values set by RPC interceptor.
func (h *ContextHandler) Handle(ctx context.Context, r slog.Record) error {
contextData := audit.GetAuditDataFromContext(ctx)

// Only add context attributes if RequestID is present, indicating this is part of a request
if contextData.RequestID != uuid.Nil {
r.AddAttrs(
slog.String(string(sdkAudit.RequestIDContextKey), contextData.RequestID.String()),
slog.String(string(sdkAudit.UserAgentContextKey), contextData.UserAgent),
slog.String(string(sdkAudit.RequestIPContextKey), contextData.RequestIP),
slog.String(string(sdkAudit.ActorIDContextKey), contextData.ActorID),
)
}

return h.handler.Handle(ctx, r)
}

func (h *ContextHandler) Enabled(ctx context.Context, level slog.Level) bool {
return h.handler.Enabled(ctx, level)
}

func (h *ContextHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return &ContextHandler{handler: h.handler.WithAttrs(attrs)}
}

func (h *ContextHandler) WithGroup(name string) slog.Handler {
return &ContextHandler{handler: h.handler.WithGroup(name)}
}
9 changes: 5 additions & 4 deletions service/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,24 @@ func NewLogger(config Config) (*Logger, error) {
return nil, err
}

var handler slog.Handler
switch config.Type {
case "json":
j := slog.NewJSONHandler(w, &slog.HandlerOptions{
handler = slog.NewJSONHandler(w, &slog.HandlerOptions{
Level: level,
ReplaceAttr: logger.replaceAttrChain,
})
sLogger = slog.New(j)
case "text":
t := slog.NewTextHandler(w, &slog.HandlerOptions{
handler = slog.NewTextHandler(w, &slog.HandlerOptions{
Level: level,
ReplaceAttr: logger.replaceAttrChain,
})
sLogger = slog.New(t)
default:
return nil, fmt.Errorf("invalid logger type: %s", config.Type)
}

sLogger = slog.New(&ContextHandler{handler})

// Audit logger will always log at the AUDIT level and be JSON formatted
auditLoggerHandler := slog.NewJSONHandler(w, &slog.HandlerOptions{
Level: audit.LevelAudit,
Expand Down
34 changes: 18 additions & 16 deletions service/pkg/db/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"context"
"errors"
"fmt"
"log/slog"
Expand All @@ -10,6 +11,7 @@ import (
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/opentdf/platform/service/logger"
)

var (
Expand Down Expand Up @@ -120,60 +122,60 @@ const (
ErrorTextNamespaceMismatch = "namespace mismatch"
)

func StatusifyError(err error, fallbackErr string, log ...any) error {
l := append([]any{"error", err}, log...)
func StatusifyError(ctx context.Context, l *logger.Logger, err error, fallbackErr string, logs ...any) error {
l = l.With("error", err.Error())
if errors.Is(err, ErrUniqueConstraintViolation) {
slog.Error(ErrTextConflict, l...)
l.ErrorContext(ctx, ErrTextConflict, logs...)
return connect.NewError(connect.CodeAlreadyExists, errors.New(ErrTextConflict))
}
if errors.Is(err, ErrNotFound) {
slog.Error(ErrTextNotFound, l...)
l.ErrorContext(ctx, ErrTextNotFound, logs...)
return connect.NewError(connect.CodeNotFound, errors.New(ErrTextNotFound))
}
if errors.Is(err, ErrForeignKeyViolation) {
slog.Error(ErrTextRelationInvalid, l...)
l.ErrorContext(ctx, ErrTextRelationInvalid, logs...)
return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrTextRelationInvalid))
}
if errors.Is(err, ErrEnumValueInvalid) {
slog.Error(ErrTextEnumValueInvalid, l...)
l.ErrorContext(ctx, ErrTextEnumValueInvalid, logs...)
return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrTextEnumValueInvalid))
}
if errors.Is(err, ErrUUIDInvalid) {
slog.Error(ErrTextUUIDInvalid, l...)
l.ErrorContext(ctx, ErrTextUUIDInvalid, logs...)
return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrTextUUIDInvalid))
}
if errors.Is(err, ErrRestrictViolation) {
slog.Error(ErrTextRestrictViolation, l...)
l.ErrorContext(ctx, ErrTextRestrictViolation, logs...)
return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrTextRestrictViolation))
}
if errors.Is(err, ErrListLimitTooLarge) {
slog.Error(ErrTextListLimitTooLarge, l...)
l.ErrorContext(ctx, ErrTextListLimitTooLarge, logs...)
return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrTextListLimitTooLarge))
}
if errors.Is(err, ErrSelectIdentifierInvalid) {
slog.Error(ErrTextInvalidIdentifier, l...)
l.ErrorContext(ctx, ErrTextInvalidIdentifier, logs...)
return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrTextInvalidIdentifier))
}
if errors.Is(err, ErrUnknownSelectIdentifier) {
slog.Error(ErrorTextUnknownIdentifier, l...)
l.ErrorContext(ctx, ErrorTextUnknownIdentifier, logs...)
return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrorTextUnknownIdentifier))
}
if errors.Is(err, ErrCannotUpdateToUnspecified) {
slog.Error(ErrorTextUpdateToUnspecified, l...)
l.ErrorContext(ctx, ErrorTextUpdateToUnspecified, logs...)
return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrorTextUpdateToUnspecified))
}
if errors.Is(err, ErrKeyRotationFailed) {
slog.Error(ErrTextKeyRotationFailed, l...)
l.ErrorContext(ctx, ErrTextKeyRotationFailed, logs...)
return connect.NewError(connect.CodeInternal, errors.New(ErrTextKeyRotationFailed))
}
if errors.Is(err, ErrExpectedBase64EncodedValue) {
slog.Error(ErrorTextExpectedBase64EncodedValue, l...)
l.ErrorContext(ctx, ErrorTextExpectedBase64EncodedValue, logs...)
return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrorTextExpectedBase64EncodedValue))
}
if errors.Is(err, ErrMarshalValueFailed) {
slog.Error(ErrorTextMarshalFailed, l...)
l.ErrorContext(ctx, ErrorTextMarshalFailed, logs...)
return connect.NewError(connect.CodeInvalidArgument, errors.New(ErrorTextMarshalFailed))
}
slog.Error(err.Error(), l...)
l.ErrorContext(ctx, err.Error(), logs...)
return connect.NewError(connect.CodeInternal, errors.New(fallbackErr))
}
2 changes: 1 addition & 1 deletion service/pkg/server/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func startServices(ctx context.Context, cfg *config.Config, otdf *server.OpenTDF
continue
}

var svcLogger *logging.Logger = logger.With("namespace", ns)
svcLogger := logger.With("namespace", ns)
extractedLogLevel, err := extractServiceLoggerConfig(cfg.Services[ns])

// If ns has log_level in config, create new logger with that level
Expand Down
10 changes: 5 additions & 5 deletions service/policy/actions/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (a *ActionService) GetAction(ctx context.Context, req *connect.Request[acti

action, err := a.dbClient.GetAction(ctx, req.Msg)
if err != nil {
return nil, db.StatusifyError(err, db.ErrTextGetRetrievalFailed, slog.Any("identifier", req.Msg.GetIdentifier()))
return nil, db.StatusifyError(ctx, a.logger, err, db.ErrTextGetRetrievalFailed, slog.Any("identifier", req.Msg.GetIdentifier()))
}
rsp.Action = action

Expand All @@ -104,7 +104,7 @@ func (a *ActionService) ListActions(ctx context.Context, req *connect.Request[ac
a.logger.DebugContext(ctx, "listing actions")
rsp, err := a.dbClient.ListActions(ctx, req.Msg)
if err != nil {
return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed)
return nil, db.StatusifyError(ctx, a.logger, err, db.ErrTextListRetrievalFailed)
}
a.logger.DebugContext(ctx, "listed actions")
return connect.NewResponse(rsp), nil
Expand Down Expand Up @@ -133,7 +133,7 @@ func (a *ActionService) CreateAction(ctx context.Context, req *connect.Request[a
})
if err != nil {
a.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextCreationFailed, slog.String("action", req.Msg.String()))
return nil, db.StatusifyError(ctx, a.logger, err, db.ErrTextCreationFailed, slog.String("action", req.Msg.String()))
}
return connect.NewResponse(rsp), nil
}
Expand Down Expand Up @@ -173,7 +173,7 @@ func (a *ActionService) UpdateAction(ctx context.Context, req *connect.Request[a
})
if err != nil {
a.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextUpdateFailed, slog.String("action", req.Msg.String()))
return nil, db.StatusifyError(ctx, a.logger, err, db.ErrTextUpdateFailed, slog.String("action", req.Msg.String()))
}

return connect.NewResponse(rsp), nil
Expand All @@ -193,7 +193,7 @@ func (a *ActionService) DeleteAction(ctx context.Context, req *connect.Request[a
deleted, err := a.dbClient.DeleteAction(ctx, req.Msg)
if err != nil {
a.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextDeletionFailed, slog.String("action", req.Msg.String()))
return nil, db.StatusifyError(ctx, a.logger, err, db.ErrTextDeletionFailed, slog.String("action", req.Msg.String()))
}

a.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)
Expand Down
Loading
Loading