From 98e6994bbf7df822bb1cebc24bd834d76d8a93b1 Mon Sep 17 00:00:00 2001 From: Himaschal Pursan Date: Thu, 17 Jul 2025 14:49:59 +0200 Subject: [PATCH 01/10] RA-7777: Incorporated logging of timing for a request from the http request to response. This is the logic required on the ctfe to add transaction_id and span_id to the requests and propogate them through to the gRpc calls to help trace a request from end to end (partly implemented in the trillian backend for grpc) --- go.mod | 4 +- trillian/ctfe/ct_server/main.go | 4 +- trillian/ctfe/handlers.go | 14 +++ trillian/ctfe/logging/logger.go | 159 ++++++++++++++++++++++++++++ trillian/ctfe/logging/middleware.go | 33 ++++++ 5 files changed, 211 insertions(+), 3 deletions(-) create mode 100644 trillian/ctfe/logging/logger.go create mode 100644 trillian/ctfe/logging/middleware.go diff --git a/go.mod b/go.mod index 4b7c69b9db..2a689b46de 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/golang/mock v1.7.0-rc.1 github.com/google/go-cmp v0.7.0 github.com/google/trillian v1.7.2 + github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jackc/pgx/v5 v5.7.5 @@ -19,6 +20,7 @@ require ( github.com/prometheus/client_golang v1.22.0 github.com/rs/cors v1.11.1 github.com/sergi/go-diff v1.4.0 + github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.6 github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce @@ -71,7 +73,6 @@ require ( github.com/golang/protobuf v1.5.4 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/s2a-go v0.1.9 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.5 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/gorilla/websocket v1.5.1 // indirect @@ -101,7 +102,6 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/prometheus v0.51.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/sirupsen/logrus v1.9.3 // indirect github.com/soheilhy/cmux v0.1.5 // indirect github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect github.com/stretchr/testify v1.10.0 // indirect diff --git a/trillian/ctfe/ct_server/main.go b/trillian/ctfe/ct_server/main.go index 1410faf29d..472acecaa3 100644 --- a/trillian/ctfe/ct_server/main.go +++ b/trillian/ctfe/ct_server/main.go @@ -24,6 +24,7 @@ import ( "crypto/tls" "flag" "fmt" + "github.com/google/certificate-transparency-go/trillian/ctfe/logging" "net/http" "os" "os/signal" @@ -136,6 +137,7 @@ func main() { } klog.CopyStandardLogTo("WARNING") + klog.Info("Using custom local build of CTFE") klog.Info("**** CT HTTP Server Starting ****") metricsAt := *metricsEndpoint @@ -448,7 +450,7 @@ func setupAndRegister(ctx context.Context, client trillian.TrillianLogClient, de return nil, err } for path, handler := range inst.Handlers { - mux.Handle(lhp+path, handler) + mux.Handle(lhp+path, logging.Middleware(handler)) } return inst, nil } diff --git a/trillian/ctfe/handlers.go b/trillian/ctfe/handlers.go index 23c5894b84..81148999f9 100644 --- a/trillian/ctfe/handlers.go +++ b/trillian/ctfe/handlers.go @@ -32,6 +32,7 @@ import ( "github.com/google/certificate-transparency-go/asn1" "github.com/google/certificate-transparency-go/tls" + "github.com/google/certificate-transparency-go/trillian/ctfe/logging" "github.com/google/certificate-transparency-go/trillian/util" "github.com/google/certificate-transparency-go/x509" "github.com/google/certificate-transparency-go/x509util" @@ -39,6 +40,7 @@ import ( "github.com/google/trillian/monitoring" "github.com/google/trillian/types" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/prototext" "k8s.io/klog/v2" @@ -198,6 +200,18 @@ func (a AppHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Many/most of the handlers forward the request on to the Log RPC server; impose a deadline // on this onward request. ctx, cancel := context.WithDeadline(logCtx, getRPCDeadlineTime(a.Info)) + txID, ok := logCtx.Value(logging.CtxKeyTxID).(string) + if !ok || txID == "" { + klog.Warning("Missing transaction_id in context") + txID = "unknown" + } + spanID, ok := logCtx.Value(logging.CtxKeySpanID).(string) + if !ok || spanID == "" { + klog.Warning("Missing span_id in context") + spanID = "unknown" + } + md := metadata.Pairs("X-Transaction-ID", txID, "X-Span-ID", spanID) + ctx = metadata.NewOutgoingContext(ctx, md) defer cancel() var err error diff --git a/trillian/ctfe/logging/logger.go b/trillian/ctfe/logging/logger.go new file mode 100644 index 0000000000..afcb0d39e5 --- /dev/null +++ b/trillian/ctfe/logging/logger.go @@ -0,0 +1,159 @@ +package logging + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type contextKey string + +const ( + CtxKeyTxID contextKey = "transaction_id" + CtxKeySpanID contextKey = "span_id" +) + +var log = logrus.New() + +func init() { + log.Formatter = &logrus.TextFormatter{ + FullTimestamp: true, + } +} + +func generateUUID() string { + return uuid.New().String() +} + +func WithContext(r *http.Request) context.Context { + txID := r.Header.Get("X-Transaction-ID") + if txID == "" { + txID = generateUUID() + } + + spanID := generateUUID() + ctx := context.WithValue(r.Context(), CtxKeyTxID, txID) + ctx = context.WithValue(ctx, CtxKeySpanID, spanID) + return ctx +} + +func WithGRPCContext(ctx context.Context) context.Context { + // First check if there's already a transaction_id in the context (from HTTP) + txID, ok := ctx.Value(CtxKeyTxID).(string) + if !ok || txID == "" { + // If not, try to get it from gRPC metadata + txID = getFromMetadata(ctx, "X-Transaction-ID") + if txID == "" { + txID = generateUUID() + } + } + + // Check for span_id in context first, then metadata + spanID, ok := ctx.Value(CtxKeySpanID).(string) + if !ok || spanID == "" { + spanID = getFromMetadata(ctx, "X-Span-ID") + if spanID == "" { + spanID = generateUUID() + } + } + + ctx = context.WithValue(ctx, CtxKeyTxID, txID) + ctx = context.WithValue(ctx, CtxKeySpanID, spanID) + return ctx +} + +func getFromMetadata(ctx context.Context, key string) string { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "" + } + values := md.Get(key) + if len(values) > 0 { + return values[0] + } + return "" +} + +// PropagateToGRPC adds the transaction_id and span_id from the context to gRPC metadata +// This should be called when making outgoing gRPC calls from HTTP handlers +func PropagateToGRPC(ctx context.Context) context.Context { + txID := ctx.Value(CtxKeyTxID) + spanID := ctx.Value(CtxKeySpanID) + + if txID == nil && spanID == nil { + return ctx + } + + mdMap := make(map[string]string) + if txID != nil { + mdMap["X-Transaction-ID"] = txID.(string) + } + if spanID != nil { + mdMap["X-Span-ID"] = spanID.(string) + } + + md := metadata.New(mdMap) + return metadata.NewOutgoingContext(ctx, md) +} + +func LogWithContext(ctx context.Context, eventID string, msg string, fields map[string]interface{}) { + lf := logrus.Fields{ + "event_id": eventID, + } + + // Safely extract transaction_id and span_id + if txID := ctx.Value(CtxKeyTxID); txID != nil { + lf["transaction_id"] = txID + } + if spanID := ctx.Value(CtxKeySpanID); spanID != nil { + lf["span_id"] = spanID + } + + for k, v := range fields { + lf[k] = v + } + log.WithFields(lf).Info(msg) +} + +func LogTiming(ctx context.Context, r *http.Request, status int, elapsed time.Duration) { + elapsedInMsStr := fmt.Sprintf("%dms", elapsed.Milliseconds()) + LogWithContext(ctx, "timing", "request completed", map[string]interface{}{ + "path": r.URL.Path, + "method": r.Method, + "status": status, + "elapsed": elapsedInMsStr, + }) +} + +func UnaryServerInterceptor() grpc.UnaryServerInterceptor { + return func( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, + ) (interface{}, error) { + ctx = WithGRPCContext(ctx) + start := time.Now() + resp, err := handler(ctx, req) + elapsed := time.Since(start) + LogWithContext(ctx, "timing", "gRPC call completed", map[string]interface{}{ + "method": info.FullMethod, + "status": statusCodeFromError(err), + "elapsed": elapsed.Milliseconds(), + }) + return resp, err + } +} + +func statusCodeFromError(err error) string { + if err == nil { + return "OK" + } + return err.Error() +} diff --git a/trillian/ctfe/logging/middleware.go b/trillian/ctfe/logging/middleware.go new file mode 100644 index 0000000000..b463f563a3 --- /dev/null +++ b/trillian/ctfe/logging/middleware.go @@ -0,0 +1,33 @@ +package logging + +import ( + "net/http" + "time" +) + +type statusRecorder struct { + http.ResponseWriter + statusCode int +} + +func (r *statusRecorder) WriteHeader(code int) { + r.statusCode = code + r.ResponseWriter.WriteHeader(code) +} + +// this middleware measures how long each HTTP request takes to process. +// It does this by recording the start time before calling the next handler, +// then the end time after the handler finishes. +// It logs the elapsed time, HTTP status code, and request details using LogTiming. +// This helps track request latency for monitoring and debugging. + +func Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + ctx := WithContext(r) + rw := &statusRecorder{ResponseWriter: w, statusCode: 200} + next.ServeHTTP(rw, r.WithContext(ctx)) + elapsed := time.Since(start) + LogTiming(ctx, r, rw.statusCode, elapsed) + }) +} From 3072f9659321e61dd1c09318409a6e9eacb8a85e Mon Sep 17 00:00:00 2001 From: Himaschal Pursan Date: Fri, 18 Jul 2025 09:17:00 +0200 Subject: [PATCH 02/10] RA-7777: modified logging format to json --- trillian/ctfe/logging/logger.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trillian/ctfe/logging/logger.go b/trillian/ctfe/logging/logger.go index afcb0d39e5..41ef9f540c 100644 --- a/trillian/ctfe/logging/logger.go +++ b/trillian/ctfe/logging/logger.go @@ -22,8 +22,8 @@ const ( var log = logrus.New() func init() { - log.Formatter = &logrus.TextFormatter{ - FullTimestamp: true, + log.Formatter = &logrus.JSONFormatter{ + TimestampFormat: "2006-01-02T15:04:05.000Z07:00", } } From 7f58232d6b801f5806b39a0451c23136471663dd Mon Sep 17 00:00:00 2001 From: Himaschal Pursan Date: Fri, 18 Jul 2025 10:50:16 +0200 Subject: [PATCH 03/10] RA-7777: Added some basic tests for the logger --- trillian/ctfe/logging/logger_test.go | 287 +++++++++++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 trillian/ctfe/logging/logger_test.go diff --git a/trillian/ctfe/logging/logger_test.go b/trillian/ctfe/logging/logger_test.go new file mode 100644 index 0000000000..c375b4fdb0 --- /dev/null +++ b/trillian/ctfe/logging/logger_test.go @@ -0,0 +1,287 @@ +package logging + +import ( + "context" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "google.golang.org/grpc/metadata" +) + +func TestGenerateUUID(t *testing.T) { + // Test that generateUUID returns a non-empty string + uuid1 := generateUUID() + uuid2 := generateUUID() + + if uuid1 == "" { + t.Error("generateUUID returned empty string") + } + + if uuid2 == "" { + t.Error("generateUUID returned empty string") + } + + // Test that two UUIDs are different + if uuid1 == uuid2 { + t.Error("generateUUID returned the same UUID twice, should be unique") + } + + // Test that UUID has expected format (basic check for dashes) + if !strings.Contains(uuid1, "-") { + t.Error("generateUUID didn't return expected UUID format") + } +} + +func TestWithContext(t *testing.T) { + // Create a test HTTP request + req := httptest.NewRequest("GET", "/test", nil) + + // Test case 1: Request without existing transaction ID + ctx := WithContext(req) + + // Check that transaction ID was added + txID := ctx.Value(CtxKeyTxID) + if txID == nil { + t.Error("WithContext didn't add transaction ID to context") + } + + // Check that span ID was added + spanID := ctx.Value(CtxKeySpanID) + if spanID == nil { + t.Error("WithContext didn't add span ID to context") + } + + // Check that both IDs are strings + if _, ok := txID.(string); !ok { + t.Error("Transaction ID is not a string") + } + if _, ok := spanID.(string); !ok { + t.Error("Span ID is not a string") + } +} + +func TestWithContextExistingTransactionID(t *testing.T) { + // Create a test HTTP request with existing transaction ID + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Transaction-ID", "existing-tx-id") + + ctx := WithContext(req) + + // Check that the existing transaction ID was preserved + txID := ctx.Value(CtxKeyTxID) + if txID != "existing-tx-id" { + t.Errorf("Expected transaction ID 'existing-tx-id', got %v", txID) + } + + // Check that a new span ID was still generated + spanID := ctx.Value(CtxKeySpanID) + if spanID == nil { + t.Error("WithContext didn't add span ID to context") + } +} + +func TestWithGRPCContext(t *testing.T) { + // Test case 1: Empty context + ctx := context.Background() + newCtx := WithGRPCContext(ctx) + + txID := newCtx.Value(CtxKeyTxID) + spanID := newCtx.Value(CtxKeySpanID) + + if txID == nil || spanID == nil { + t.Error("WithGRPCContext didn't add IDs to empty context") + } + + // Test case 2: Context with existing values + existingCtx := context.WithValue(context.Background(), CtxKeyTxID, "existing-tx") + existingCtx = context.WithValue(existingCtx, CtxKeySpanID, "existing-span") + + newCtx2 := WithGRPCContext(existingCtx) + + if newCtx2.Value(CtxKeyTxID) != "existing-tx" { + t.Error("WithGRPCContext didn't preserve existing transaction ID") + } + if newCtx2.Value(CtxKeySpanID) != "existing-span" { + t.Error("WithGRPCContext didn't preserve existing span ID") + } +} + +func TestWithGRPCContextFromMetadata(t *testing.T) { + // Create context with gRPC metadata + md := metadata.Pairs("X-Transaction-ID", "metadata-tx-id", "X-Span-ID", "metadata-span-id") + ctx := metadata.NewIncomingContext(context.Background(), md) + + newCtx := WithGRPCContext(ctx) + + txID := newCtx.Value(CtxKeyTxID) + spanID := newCtx.Value(CtxKeySpanID) + + if txID != "metadata-tx-id" { + t.Errorf("Expected transaction ID from metadata 'metadata-tx-id', got %v", txID) + } + if spanID != "metadata-span-id" { + t.Errorf("Expected span ID from metadata 'metadata-span-id', got %v", spanID) + } +} + +func TestLogWithContext(t *testing.T) { + // Create a test hook and attach it to our global logger + hook := test.NewLocal(log) + + // Create context with IDs + ctx := context.WithValue(context.Background(), CtxKeyTxID, "test-tx-id") + ctx = context.WithValue(ctx, CtxKeySpanID, "test-span-id") + + // Test logging + LogWithContext(ctx, "test-event", "test message", map[string]interface{}{ + "custom_field": "custom_value", + }) + + // Check that a log entry was created + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + if entry == nil { + t.Error("LastEntry() returned nil") + return + } + + // Check log level + if entry.Level != logrus.InfoLevel { + t.Errorf("Expected INFO level, got %v", entry.Level) + } + + // Check message + if entry.Message != "test message" { + t.Errorf("Expected message 'test message', got %v", entry.Message) + } + + // Check fields + if entry.Data["event_id"] != "test-event" { + t.Errorf("Expected event_id 'test-event', got %v", entry.Data["event_id"]) + } + if entry.Data["transaction_id"] != "test-tx-id" { + t.Errorf("Expected transaction_id 'test-tx-id', got %v", entry.Data["transaction_id"]) + } + if entry.Data["span_id"] != "test-span-id" { + t.Errorf("Expected span_id 'test-span-id', got %v", entry.Data["span_id"]) + } + if entry.Data["custom_field"] != "custom_value" { + t.Errorf("Expected custom_field 'custom_value', got %v", entry.Data["custom_field"]) + } +} + +func TestLogTiming(t *testing.T) { + // Create a test hook and attach it to our global logger + hook := test.NewLocal(log) + + // Create context with IDs + ctx := context.WithValue(context.Background(), CtxKeyTxID, "timing-tx-id") + ctx = context.WithValue(ctx, CtxKeySpanID, "timing-span-id") + + // Create a test request + req := httptest.NewRequest("GET", "/api/test", nil) + + // Test timing log + elapsed := 250 * time.Millisecond + LogTiming(ctx, req, 200, elapsed) + + // Check that a log entry was created + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + if entry == nil { + t.Error("LastEntry() returned nil") + return + } + + // Check message + if entry.Message != "request completed" { + t.Errorf("Expected message 'request completed', got %v", entry.Message) + } + + // Check timing fields + if entry.Data["event_id"] != "timing" { + t.Error("Expected event_id to be 'timing'") + } + if entry.Data["path"] != "/api/test" { + t.Errorf("Expected path '/api/test', got %v", entry.Data["path"]) + } + if entry.Data["method"] != "GET" { + t.Errorf("Expected method 'GET', got %v", entry.Data["method"]) + } + if entry.Data["status"] != 200 { + t.Errorf("Expected status 200, got %v", entry.Data["status"]) + } + if entry.Data["elapsed"] != "250ms" { + t.Errorf("Expected elapsed '250ms', got %v", entry.Data["elapsed"]) + } +} + +func TestStatusCodeFromError(t *testing.T) { + // Test with no error + status := statusCodeFromError(nil) + if status != "OK" { + t.Errorf("Expected 'OK' for nil error, got %v", status) + } + + // Test with error + err := &testError{msg: "test error"} + status = statusCodeFromError(err) + if status != "test error" { + t.Errorf("Expected 'test error', got %v", status) + } +} + +// Helper type for testing errors +type testError struct { + msg string +} + +func (e *testError) Error() string { + return e.msg +} + +func TestPropagateToGRPC(t *testing.T) { + // Test with context containing IDs + ctx := context.WithValue(context.Background(), CtxKeyTxID, "propagate-tx-id") + ctx = context.WithValue(ctx, CtxKeySpanID, "propagate-span-id") + + newCtx := PropagateToGRPC(ctx) + + // Extract metadata from the new context + md, ok := metadata.FromOutgoingContext(newCtx) + if !ok { + t.Error("PropagateToGRPC didn't add metadata to context") + } + + // Check that transaction ID was added to metadata + txIDs := md.Get("X-Transaction-ID") + if len(txIDs) != 1 || txIDs[0] != "propagate-tx-id" { + t.Errorf("Expected X-Transaction-ID 'propagate-tx-id', got %v", txIDs) + } + + // Check that span ID was added to metadata + spanIDs := md.Get("X-Span-ID") + if len(spanIDs) != 1 || spanIDs[0] != "propagate-span-id" { + t.Errorf("Expected X-Span-ID 'propagate-span-id', got %v", spanIDs) + } + + // Test with empty context (should return same context) + emptyCtx := context.Background() + sameCtx := PropagateToGRPC(emptyCtx) + + if sameCtx != emptyCtx { + t.Error("PropagateToGRPC should return same context when no IDs present") + } +} From f05fe4dac7a88bd3ca1380528b5ed8d50abea4e0 Mon Sep 17 00:00:00 2001 From: Himaschal Pursan Date: Fri, 18 Jul 2025 15:17:23 +0200 Subject: [PATCH 04/10] RA-7777: Added some tests for the middleware --- trillian/ctfe/logging/middleware_test.go | 328 +++++++++++++++++++++++ 1 file changed, 328 insertions(+) create mode 100644 trillian/ctfe/logging/middleware_test.go diff --git a/trillian/ctfe/logging/middleware_test.go b/trillian/ctfe/logging/middleware_test.go new file mode 100644 index 0000000000..7b48ade9d0 --- /dev/null +++ b/trillian/ctfe/logging/middleware_test.go @@ -0,0 +1,328 @@ +package logging + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/sirupsen/logrus/hooks/test" +) + +// TestStatusRecorder tests the statusRecorder type +func TestStatusRecorder(t *testing.T) { + // Create a test ResponseWriter + w := httptest.NewRecorder() + + // Create our statusRecorder wrapper + recorder := &statusRecorder{ + ResponseWriter: w, + statusCode: 200, // default value + } + + // Test that it starts with default status code + if recorder.statusCode != 200 { + t.Errorf("Expected default status code 200, got %d", recorder.statusCode) + } + + // Test WriteHeader method + recorder.WriteHeader(404) + + // Check that our wrapper recorded the status code + if recorder.statusCode != 404 { + t.Errorf("Expected status code 404, got %d", recorder.statusCode) + } + + // Check that the underlying ResponseWriter also got the status code + if w.Code != 404 { + t.Errorf("Expected underlying ResponseWriter code 404, got %d", w.Code) + } +} + +// TestMiddleware tests the main middleware function +func TestMiddleware(t *testing.T) { + // Create a test hook to capture log output + hook := test.NewLocal(log) + + // Create a test handler that we'll wrap with our middleware + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check that context was added to the request + txID := r.Context().Value(CtxKeyTxID) + if txID == nil { + t.Error("Request context doesn't contain transaction ID") + } + + spanID := r.Context().Value(CtxKeySpanID) + if spanID == nil { + t.Error("Request context doesn't contain span ID") + } + + // Simulate some work + time.Sleep(10 * time.Millisecond) + + // Write a response + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello World")) + }) + + // Wrap our test handler with the middleware + wrappedHandler := Middleware(testHandler) + + // Create a test request + req := httptest.NewRequest("GET", "/test-endpoint", nil) + + // Create a ResponseRecorder to capture the response + w := httptest.NewRecorder() + + // Execute the request + wrappedHandler.ServeHTTP(w, req) + + // Check that the response was written correctly + if w.Code != http.StatusOK { + t.Errorf("Expected status OK, got %d", w.Code) + } + + if w.Body.String() != "Hello World" { + t.Errorf("Expected body 'Hello World', got %s", w.Body.String()) + } + + // Check that a log entry was created + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + if entry == nil { + t.Error("LastEntry() returned nil") + return + } + + // Check log message + if entry.Message != "request completed" { + t.Errorf("Expected message 'request completed', got %s", entry.Message) + } + + // Check log fields + if entry.Data["event_id"] != "timing" { + t.Errorf("Expected event_id 'timing', got %v", entry.Data["event_id"]) + } + + if entry.Data["method"] != "GET" { + t.Errorf("Expected method 'GET', got %v", entry.Data["method"]) + } + + if entry.Data["path"] != "/test-endpoint" { + t.Errorf("Expected path '/test-endpoint', got %v", entry.Data["path"]) + } + + if entry.Data["status"] != 200 { + t.Errorf("Expected status 200, got %v", entry.Data["status"]) + } + + // Check that elapsed time was recorded (should be > 0) + elapsed, ok := entry.Data["elapsed"].(string) + if !ok { + t.Error("Elapsed time not recorded as string") + } + + if !strings.HasSuffix(elapsed, "ms") { + t.Errorf("Expected elapsed time to end with 'ms', got %s", elapsed) + } +} + +// TestMiddlewareWithExistingTransactionID tests middleware when request already has transaction ID +func TestMiddlewareWithExistingTransactionID(t *testing.T) { + hook := test.NewLocal(log) + + // Create a test handler + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check that the existing transaction ID was preserved + txID := r.Context().Value(CtxKeyTxID) + if txID != "existing-tx-123" { + t.Errorf("Expected existing transaction ID 'existing-tx-123', got %v", txID) + } + + w.WriteHeader(http.StatusOK) + }) + + // Wrap with middleware + wrappedHandler := Middleware(testHandler) + + // Create request with existing transaction ID + req := httptest.NewRequest("POST", "/api/submit", strings.NewReader("test data")) + req.Header.Set("X-Transaction-ID", "existing-tx-123") + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + // Check log entry + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + if entry.Data["transaction_id"] != "existing-tx-123" { + t.Errorf("Expected transaction_id 'existing-tx-123', got %v", entry.Data["transaction_id"]) + } +} + +// TestMiddlewareWithDifferentStatusCodes tests middleware with various HTTP status codes +func TestMiddlewareWithDifferentStatusCodes(t *testing.T) { + testCases := []struct { + name string + statusCode int + expectedStatus int + }{ + {"Success", http.StatusOK, 200}, + {"Created", http.StatusCreated, 201}, + {"BadRequest", http.StatusBadRequest, 400}, + {"NotFound", http.StatusNotFound, 404}, + {"InternalServerError", http.StatusInternalServerError, 500}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + hook := test.NewLocal(log) + + // Create handler that returns specific status code + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.statusCode) + }) + + wrappedHandler := Middleware(testHandler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + // Check response status + if w.Code != tc.expectedStatus { + t.Errorf("Expected status %d, got %d", tc.expectedStatus, w.Code) + } + + // Check log entry + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + if entry.Data["status"] != tc.expectedStatus { + t.Errorf("Expected logged status %d, got %v", tc.expectedStatus, entry.Data["status"]) + } + }) + } +} + +// TestMiddlewareWithPanic tests that middleware handles panics gracefully +func TestMiddlewareWithPanic(t *testing.T) { + // Create handler that panics + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + }) + + wrappedHandler := Middleware(testHandler) + + req := httptest.NewRequest("GET", "/panic", nil) + w := httptest.NewRecorder() + + // This should panic, so we need to recover + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic, but none occurred") + } + }() + + wrappedHandler.ServeHTTP(w, req) +} + +// TestMiddlewareTimingAccuracy tests that timing is reasonably accurate +func TestMiddlewareTimingAccuracy(t *testing.T) { + hook := test.NewLocal(log) + + // Create handler that sleeps for a known duration + sleepDuration := 50 * time.Millisecond + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(sleepDuration) + w.WriteHeader(http.StatusOK) + }) + + wrappedHandler := Middleware(testHandler) + + req := httptest.NewRequest("GET", "/slow", nil) + w := httptest.NewRecorder() + + start := time.Now() + wrappedHandler.ServeHTTP(w, req) + actualElapsed := time.Since(start) + + // Check that timing was logged + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + elapsedStr, ok := entry.Data["elapsed"].(string) + if !ok { + t.Error("Elapsed time not recorded as string") + return + } + + // Parse the elapsed time (remove "ms" suffix) + elapsedStr = strings.TrimSuffix(elapsedStr, "ms") + + // Check that the timing is reasonably accurate using proper numeric comparison + // Allow for some variance (±20ms) due to system scheduling and load + tolerance := 20 * time.Millisecond + if actualElapsed < (sleepDuration-tolerance) || actualElapsed > (sleepDuration+tolerance) { + t.Errorf("Expected elapsed time around %v (±%v), got %v. Logged: %sms", + sleepDuration, tolerance, actualElapsed, elapsedStr) + } +} + +// TestMiddlewareChaining tests that multiple middlewares can be chained +func TestMiddlewareChaining(t *testing.T) { + hook := test.NewLocal(log) + + // Create a simple handler + finalHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("final")) + }) + + // Create another middleware for testing chaining + authMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add a custom header to verify this middleware ran + w.Header().Set("X-Auth-Middleware", "ran") + next.ServeHTTP(w, r) + }) + } + + // Chain middlewares: authMiddleware -> Middleware -> finalHandler + chainedHandler := authMiddleware(Middleware(finalHandler)) + + req := httptest.NewRequest("GET", "/chained", nil) + w := httptest.NewRecorder() + + chainedHandler.ServeHTTP(w, req) + + // Check that both middlewares ran + if w.Header().Get("X-Auth-Middleware") != "ran" { + t.Error("Auth middleware didn't run") + } + + if w.Body.String() != "final" { + t.Errorf("Expected body 'final', got %s", w.Body.String()) + } + + // Check that our logging middleware created a log entry + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + } +} From 45fc8bc5be25c5993b8b03fb8f1452a348845213 Mon Sep 17 00:00:00 2001 From: Himaschal Pursan Date: Tue, 22 Jul 2025 15:08:35 +0200 Subject: [PATCH 05/10] Remove custom message. It was used for testing purposes only --- trillian/ctfe/ct_server/main.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/trillian/ctfe/ct_server/main.go b/trillian/ctfe/ct_server/main.go index 472acecaa3..70b3d9ff73 100644 --- a/trillian/ctfe/ct_server/main.go +++ b/trillian/ctfe/ct_server/main.go @@ -24,7 +24,6 @@ import ( "crypto/tls" "flag" "fmt" - "github.com/google/certificate-transparency-go/trillian/ctfe/logging" "net/http" "os" "os/signal" @@ -34,6 +33,8 @@ import ( "syscall" "time" + "github.com/google/certificate-transparency-go/trillian/ctfe/logging" + "github.com/google/certificate-transparency-go/trillian/ctfe" "github.com/google/certificate-transparency-go/trillian/ctfe/cache" "github.com/google/certificate-transparency-go/trillian/ctfe/configpb" @@ -137,7 +138,7 @@ func main() { } klog.CopyStandardLogTo("WARNING") - klog.Info("Using custom local build of CTFE") + klog.Info("**** CT HTTP Server Starting ****") metricsAt := *metricsEndpoint From 195347eabca9427bfc4c65fed6636f35bef87c4e Mon Sep 17 00:00:00 2001 From: Himaschal Pursan Date: Tue, 22 Jul 2025 20:29:37 +0200 Subject: [PATCH 06/10] Span id should not be propagated --- trillian/ctfe/logging/logger.go | 26 +++++++++++--------------- trillian/ctfe/logging/logger_test.go | 14 +++++++++----- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/trillian/ctfe/logging/logger.go b/trillian/ctfe/logging/logger.go index 41ef9f540c..703a1029ad 100644 --- a/trillian/ctfe/logging/logger.go +++ b/trillian/ctfe/logging/logger.go @@ -54,13 +54,12 @@ func WithGRPCContext(ctx context.Context) context.Context { } } - // Check for span_id in context first, then metadata + // Check for span_id in context first spanID, ok := ctx.Value(CtxKeySpanID).(string) if !ok || spanID == "" { - spanID = getFromMetadata(ctx, "X-Span-ID") - if spanID == "" { - spanID = generateUUID() - } + // Always generate a new span_id for this service + // No longer checking metadata since span_id is not propagated + spanID = generateUUID() } ctx = context.WithValue(ctx, CtxKeyTxID, txID) @@ -80,23 +79,20 @@ func getFromMetadata(ctx context.Context, key string) string { return "" } -// PropagateToGRPC adds the transaction_id and span_id from the context to gRPC metadata -// This should be called when making outgoing gRPC calls from HTTP handlers +// PropagateToGRPC adds the transaction_id from the context to gRPC metadata +// This ensures transaction correlation across service boundaries +// Note: span_id is NOT propagated - each service generates its own span_id func PropagateToGRPC(ctx context.Context) context.Context { txID := ctx.Value(CtxKeyTxID) - spanID := ctx.Value(CtxKeySpanID) - if txID == nil && spanID == nil { + if txID == nil { return ctx } - mdMap := make(map[string]string) - if txID != nil { - mdMap["X-Transaction-ID"] = txID.(string) - } - if spanID != nil { - mdMap["X-Span-ID"] = spanID.(string) + mdMap := map[string]string{ + "X-Transaction-ID": txID.(string), } + // Deliberately NOT propagating span_id - each service gets its own md := metadata.New(mdMap) return metadata.NewOutgoingContext(ctx, md) diff --git a/trillian/ctfe/logging/logger_test.go b/trillian/ctfe/logging/logger_test.go index c375b4fdb0..32c3a910eb 100644 --- a/trillian/ctfe/logging/logger_test.go +++ b/trillian/ctfe/logging/logger_test.go @@ -123,8 +123,12 @@ func TestWithGRPCContextFromMetadata(t *testing.T) { if txID != "metadata-tx-id" { t.Errorf("Expected transaction ID from metadata 'metadata-tx-id', got %v", txID) } - if spanID != "metadata-span-id" { - t.Errorf("Expected span ID from metadata 'metadata-span-id', got %v", spanID) + // Span ID should be newly generated, not from metadata (each service gets its own span) + if spanID == "metadata-span-id" { + t.Errorf("Expected new span ID to be generated, but got metadata span ID: %v", spanID) + } + if spanID == "" { + t.Error("Expected new span ID to be generated, but got empty string") } } @@ -271,10 +275,10 @@ func TestPropagateToGRPC(t *testing.T) { t.Errorf("Expected X-Transaction-ID 'propagate-tx-id', got %v", txIDs) } - // Check that span ID was added to metadata + // Check that span ID was NOT added to metadata (new design) spanIDs := md.Get("X-Span-ID") - if len(spanIDs) != 1 || spanIDs[0] != "propagate-span-id" { - t.Errorf("Expected X-Span-ID 'propagate-span-id', got %v", spanIDs) + if len(spanIDs) != 0 { + t.Errorf("Expected no X-Span-ID in metadata, but got %v", spanIDs) } // Test with empty context (should return same context) From 72f2549dd08f8ab1d19ff2da4bd21ebc08b1b243 Mon Sep 17 00:00:00 2001 From: Himaschal Pursan Date: Fri, 25 Jul 2025 12:03:03 +0200 Subject: [PATCH 07/10] RA-7777: Fix a codeql issue. Sanitize the log message to prevent possible log injection. --- trillian/ctfe/logging/logger.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/trillian/ctfe/logging/logger.go b/trillian/ctfe/logging/logger.go index 703a1029ad..d2f1e3c5d6 100644 --- a/trillian/ctfe/logging/logger.go +++ b/trillian/ctfe/logging/logger.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strings" "time" "github.com/google/uuid" @@ -31,6 +32,13 @@ func generateUUID() string { return uuid.New().String() } +// sanitizeLogMessage removes or replaces characters that could be used for log injection +func sanitizeLogMessage(msg string) string { + msg = strings.ReplaceAll(msg, "\n", "\\n") + msg = strings.ReplaceAll(msg, "\r", "\\r") + return msg +} + func WithContext(r *http.Request) context.Context { txID := r.Header.Get("X-Transaction-ID") if txID == "" { @@ -114,7 +122,9 @@ func LogWithContext(ctx context.Context, eventID string, msg string, fields map[ for k, v := range fields { lf[k] = v } - log.WithFields(lf).Info(msg) + // Sanitize the message to prevent log injection + sanitizedMsg := sanitizeLogMessage(msg) + log.WithFields(lf).Info(sanitizedMsg) } func LogTiming(ctx context.Context, r *http.Request, status int, elapsed time.Duration) { From 4e7add6737b4294de0b90cf8b5057e5daabb6925 Mon Sep 17 00:00:00 2001 From: Himaschal Pursan Date: Fri, 25 Jul 2025 15:05:35 +0200 Subject: [PATCH 08/10] RA-7777: Sanitize the transaction_id to some degree of confidence --- trillian/ctfe/logging/logger.go | 31 +++-- trillian/ctfe/logging/logger_test.go | 167 +++++++++++++++++++++++++++ 2 files changed, 187 insertions(+), 11 deletions(-) diff --git a/trillian/ctfe/logging/logger.go b/trillian/ctfe/logging/logger.go index d2f1e3c5d6..49736a7459 100644 --- a/trillian/ctfe/logging/logger.go +++ b/trillian/ctfe/logging/logger.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/http" - "strings" "time" "github.com/google/uuid" @@ -32,16 +31,25 @@ func generateUUID() string { return uuid.New().String() } -// sanitizeLogMessage removes or replaces characters that could be used for log injection -func sanitizeLogMessage(msg string) string { - msg = strings.ReplaceAll(msg, "\n", "\\n") - msg = strings.ReplaceAll(msg, "\r", "\\r") - return msg +// isReasonableTransactionID performs basic validation for propagation safety +// Rejects empty, too long, or obviously malicious values +func isReasonableTransactionID(txID string) bool { + if txID == "" || len(txID) > 128 { + return false + } + // Reject values with control characters that could cause issues + for _, r := range txID { + if r < 32 || r == 127 { + return false + } + } + return true } func WithContext(r *http.Request) context.Context { txID := r.Header.Get("X-Transaction-ID") - if txID == "" { + if txID == "" || !isReasonableTransactionID(txID) { + // If no header or user-provided txID looks suspicious, generate a new one txID = generateUUID() } @@ -57,7 +65,8 @@ func WithGRPCContext(ctx context.Context) context.Context { if !ok || txID == "" { // If not, try to get it from gRPC metadata txID = getFromMetadata(ctx, "X-Transaction-ID") - if txID == "" { + if txID == "" || !isReasonableTransactionID(txID) { + // If no metadata or gRPC metadata txID looks suspicious, generate a new one txID = generateUUID() } } @@ -90,6 +99,7 @@ func getFromMetadata(ctx context.Context, key string) string { // PropagateToGRPC adds the transaction_id from the context to gRPC metadata // This ensures transaction correlation across service boundaries // Note: span_id is NOT propagated - each service generates its own span_id +// Note: transaction_id values are validated at input to prevent propagation of malicious values func PropagateToGRPC(ctx context.Context) context.Context { txID := ctx.Value(CtxKeyTxID) @@ -122,9 +132,8 @@ func LogWithContext(ctx context.Context, eventID string, msg string, fields map[ for k, v := range fields { lf[k] = v } - // Sanitize the message to prevent log injection - sanitizedMsg := sanitizeLogMessage(msg) - log.WithFields(lf).Info(sanitizedMsg) + // JSONFormatter automatically handles escaping, preventing log injection + log.WithFields(lf).Info(msg) } func LogTiming(ctx context.Context, r *http.Request, status int, elapsed time.Duration) { diff --git a/trillian/ctfe/logging/logger_test.go b/trillian/ctfe/logging/logger_test.go index 32c3a910eb..50ae56dc7e 100644 --- a/trillian/ctfe/logging/logger_test.go +++ b/trillian/ctfe/logging/logger_test.go @@ -289,3 +289,170 @@ func TestPropagateToGRPC(t *testing.T) { t.Error("PropagateToGRPC should return same context when no IDs present") } } + +func TestIsReasonableTransactionID(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + // Valid cases + {"Standard UUID", "550e8400-e29b-41d4-a716-446655440000", true}, + {"Alphanumeric", "abc123XYZ", true}, + {"With hyphens", "test-transaction-123", true}, + {"With underscores", "test_transaction_123", true}, + {"Mixed symbols", "tx.123@domain.com", true}, + {"Unicode", "测试-transaction-値", true}, + {"Numbers only", "123456789", true}, + {"Letters only", "abcdefghijk", true}, + {"With spaces", "uuid with spaces", true}, + {"Special chars", "uuid!@#$%^&*()+={}[]|\\:;\"'<>,.?/", true}, + {"Exactly 128 chars", strings.Repeat("a", 128), true}, + + // Invalid cases - control characters + {"With newline", "uuid\nmalicious", false}, + {"With carriage return", "uuid\rmalicious", false}, + {"With tab", "uuid\tmalicious", false}, + {"With null byte", "uuid\x00malicious", false}, + {"With escape", "uuid\x1bmalicious", false}, + {"With bell", "uuid\x07malicious", false}, + {"With DEL", "uuid\x7fmalicious", false}, + {"With backspace", "uuid\x08malicious", false}, + {"With form feed", "uuid\x0cmalicious", false}, + + // Edge cases + {"Empty string", "", false}, + {"Too long", strings.Repeat("a", 129), false}, + {"Only control chars", "\n\r\t", false}, + {"Mixed control and valid", "valid\x00invalid", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isReasonableTransactionID(tt.input) + if result != tt.expected { + t.Errorf("isReasonableTransactionID(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestWithContextSanitization(t *testing.T) { + tests := []struct { + name string + headerValue string + expectGenerated bool + }{ + {"Valid UUID", "550e8400-e29b-41d4-a716-446655440000", false}, + {"Valid custom ID", "user-session-123", false}, + {"Malicious with newline", "uuid\nmalicious", true}, + {"Malicious with null", "uuid\x00malicious", true}, + {"Too long", strings.Repeat("a", 129), true}, + {"Empty header", "", true}, + {"Only control chars", "\n\r\t\x00", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + if tt.headerValue != "" { + req.Header.Set("X-Transaction-ID", tt.headerValue) + } + + ctx := WithContext(req) + txID := ctx.Value(CtxKeyTxID).(string) + + if tt.expectGenerated { + // Should be a generated UUID, not the original value + if txID == tt.headerValue { + t.Errorf("Expected generated UUID but got original value: %q", txID) + } + // Check it looks like a UUID (36 chars with hyphens) + if len(txID) != 36 || !strings.Contains(txID, "-") { + t.Errorf("Expected generated UUID format but got: %q", txID) + } + } else { + // Should be the original value + if txID != tt.headerValue { + t.Errorf("Expected original value %q but got %q", tt.headerValue, txID) + } + } + }) + } +} + +func TestWithGRPCContextSanitization(t *testing.T) { + tests := []struct { + name string + metadataValue string + expectGenerated bool + }{ + {"Valid UUID", "550e8400-e29b-41d4-a716-446655440000", false}, + {"Valid custom ID", "grpc-session-456", false}, + {"Malicious with carriage return", "uuid\rmalicious", true}, + {"Malicious with escape", "uuid\x1bmalicious", true}, + {"Too long", strings.Repeat("b", 129), true}, + {"Control characters", "\x01\x02\x03", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create context with gRPC metadata + md := metadata.New(map[string]string{ + "X-Transaction-ID": tt.metadataValue, + }) + ctx := metadata.NewIncomingContext(context.Background(), md) + + resultCtx := WithGRPCContext(ctx) + txID := resultCtx.Value(CtxKeyTxID).(string) + + if tt.expectGenerated { + // Should be a generated UUID, not the original value + if txID == tt.metadataValue { + t.Errorf("Expected generated UUID but got original value: %q", txID) + } + // Check it looks like a UUID + if len(txID) != 36 || !strings.Contains(txID, "-") { + t.Errorf("Expected generated UUID format but got: %q", txID) + } + } else { + // Should be the original value + if txID != tt.metadataValue { + t.Errorf("Expected original value %q but got %q", tt.metadataValue, txID) + } + } + }) + } +} + +func TestSanitizationInLogging(t *testing.T) { + // Setup test logger with hook to capture logs + hook := test.NewLocal(log) + + // Test that malicious transaction IDs don't break JSON logs + ctx := context.Background() + ctx = context.WithValue(ctx, CtxKeyTxID, "malicious\nvalue") + ctx = context.WithValue(ctx, CtxKeySpanID, "span\rvalue") + + LogWithContext(ctx, "test", "test message", map[string]interface{}{ + "field": "value\x00with\x1bnull", + }) + + // Check that log was created (JSONFormatter should handle escaping) + if len(hook.Entries) == 0 { + t.Error("Expected log entry but none was created") + return + } + + entry := hook.Entries[0] + if entry.Message != "test message" { + t.Errorf("Expected message 'test message' but got %q", entry.Message) + } + + // Verify fields are present (JSONFormatter handles escaping) + if entry.Data["transaction_id"] != "malicious\nvalue" { + t.Errorf("Transaction ID not preserved correctly: %v", entry.Data["transaction_id"]) + } + + hook.Reset() +} From 4ccdfd395e38940c97515179498696943636622d Mon Sep 17 00:00:00 2001 From: Himaschal Pursan Date: Sat, 26 Jul 2025 20:05:06 +0200 Subject: [PATCH 09/10] Revert "RA-7777: Sanitize the transaction_id to some degree of confidence" This reverts commit 4e7add6737b4294de0b90cf8b5057e5daabb6925. --- trillian/ctfe/logging/logger.go | 31 ++--- trillian/ctfe/logging/logger_test.go | 167 --------------------------- 2 files changed, 11 insertions(+), 187 deletions(-) diff --git a/trillian/ctfe/logging/logger.go b/trillian/ctfe/logging/logger.go index 49736a7459..d2f1e3c5d6 100644 --- a/trillian/ctfe/logging/logger.go +++ b/trillian/ctfe/logging/logger.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strings" "time" "github.com/google/uuid" @@ -31,25 +32,16 @@ func generateUUID() string { return uuid.New().String() } -// isReasonableTransactionID performs basic validation for propagation safety -// Rejects empty, too long, or obviously malicious values -func isReasonableTransactionID(txID string) bool { - if txID == "" || len(txID) > 128 { - return false - } - // Reject values with control characters that could cause issues - for _, r := range txID { - if r < 32 || r == 127 { - return false - } - } - return true +// sanitizeLogMessage removes or replaces characters that could be used for log injection +func sanitizeLogMessage(msg string) string { + msg = strings.ReplaceAll(msg, "\n", "\\n") + msg = strings.ReplaceAll(msg, "\r", "\\r") + return msg } func WithContext(r *http.Request) context.Context { txID := r.Header.Get("X-Transaction-ID") - if txID == "" || !isReasonableTransactionID(txID) { - // If no header or user-provided txID looks suspicious, generate a new one + if txID == "" { txID = generateUUID() } @@ -65,8 +57,7 @@ func WithGRPCContext(ctx context.Context) context.Context { if !ok || txID == "" { // If not, try to get it from gRPC metadata txID = getFromMetadata(ctx, "X-Transaction-ID") - if txID == "" || !isReasonableTransactionID(txID) { - // If no metadata or gRPC metadata txID looks suspicious, generate a new one + if txID == "" { txID = generateUUID() } } @@ -99,7 +90,6 @@ func getFromMetadata(ctx context.Context, key string) string { // PropagateToGRPC adds the transaction_id from the context to gRPC metadata // This ensures transaction correlation across service boundaries // Note: span_id is NOT propagated - each service generates its own span_id -// Note: transaction_id values are validated at input to prevent propagation of malicious values func PropagateToGRPC(ctx context.Context) context.Context { txID := ctx.Value(CtxKeyTxID) @@ -132,8 +122,9 @@ func LogWithContext(ctx context.Context, eventID string, msg string, fields map[ for k, v := range fields { lf[k] = v } - // JSONFormatter automatically handles escaping, preventing log injection - log.WithFields(lf).Info(msg) + // Sanitize the message to prevent log injection + sanitizedMsg := sanitizeLogMessage(msg) + log.WithFields(lf).Info(sanitizedMsg) } func LogTiming(ctx context.Context, r *http.Request, status int, elapsed time.Duration) { diff --git a/trillian/ctfe/logging/logger_test.go b/trillian/ctfe/logging/logger_test.go index 50ae56dc7e..32c3a910eb 100644 --- a/trillian/ctfe/logging/logger_test.go +++ b/trillian/ctfe/logging/logger_test.go @@ -289,170 +289,3 @@ func TestPropagateToGRPC(t *testing.T) { t.Error("PropagateToGRPC should return same context when no IDs present") } } - -func TestIsReasonableTransactionID(t *testing.T) { - tests := []struct { - name string - input string - expected bool - }{ - // Valid cases - {"Standard UUID", "550e8400-e29b-41d4-a716-446655440000", true}, - {"Alphanumeric", "abc123XYZ", true}, - {"With hyphens", "test-transaction-123", true}, - {"With underscores", "test_transaction_123", true}, - {"Mixed symbols", "tx.123@domain.com", true}, - {"Unicode", "测试-transaction-値", true}, - {"Numbers only", "123456789", true}, - {"Letters only", "abcdefghijk", true}, - {"With spaces", "uuid with spaces", true}, - {"Special chars", "uuid!@#$%^&*()+={}[]|\\:;\"'<>,.?/", true}, - {"Exactly 128 chars", strings.Repeat("a", 128), true}, - - // Invalid cases - control characters - {"With newline", "uuid\nmalicious", false}, - {"With carriage return", "uuid\rmalicious", false}, - {"With tab", "uuid\tmalicious", false}, - {"With null byte", "uuid\x00malicious", false}, - {"With escape", "uuid\x1bmalicious", false}, - {"With bell", "uuid\x07malicious", false}, - {"With DEL", "uuid\x7fmalicious", false}, - {"With backspace", "uuid\x08malicious", false}, - {"With form feed", "uuid\x0cmalicious", false}, - - // Edge cases - {"Empty string", "", false}, - {"Too long", strings.Repeat("a", 129), false}, - {"Only control chars", "\n\r\t", false}, - {"Mixed control and valid", "valid\x00invalid", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := isReasonableTransactionID(tt.input) - if result != tt.expected { - t.Errorf("isReasonableTransactionID(%q) = %v, want %v", tt.input, result, tt.expected) - } - }) - } -} - -func TestWithContextSanitization(t *testing.T) { - tests := []struct { - name string - headerValue string - expectGenerated bool - }{ - {"Valid UUID", "550e8400-e29b-41d4-a716-446655440000", false}, - {"Valid custom ID", "user-session-123", false}, - {"Malicious with newline", "uuid\nmalicious", true}, - {"Malicious with null", "uuid\x00malicious", true}, - {"Too long", strings.Repeat("a", 129), true}, - {"Empty header", "", true}, - {"Only control chars", "\n\r\t\x00", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/test", nil) - if tt.headerValue != "" { - req.Header.Set("X-Transaction-ID", tt.headerValue) - } - - ctx := WithContext(req) - txID := ctx.Value(CtxKeyTxID).(string) - - if tt.expectGenerated { - // Should be a generated UUID, not the original value - if txID == tt.headerValue { - t.Errorf("Expected generated UUID but got original value: %q", txID) - } - // Check it looks like a UUID (36 chars with hyphens) - if len(txID) != 36 || !strings.Contains(txID, "-") { - t.Errorf("Expected generated UUID format but got: %q", txID) - } - } else { - // Should be the original value - if txID != tt.headerValue { - t.Errorf("Expected original value %q but got %q", tt.headerValue, txID) - } - } - }) - } -} - -func TestWithGRPCContextSanitization(t *testing.T) { - tests := []struct { - name string - metadataValue string - expectGenerated bool - }{ - {"Valid UUID", "550e8400-e29b-41d4-a716-446655440000", false}, - {"Valid custom ID", "grpc-session-456", false}, - {"Malicious with carriage return", "uuid\rmalicious", true}, - {"Malicious with escape", "uuid\x1bmalicious", true}, - {"Too long", strings.Repeat("b", 129), true}, - {"Control characters", "\x01\x02\x03", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create context with gRPC metadata - md := metadata.New(map[string]string{ - "X-Transaction-ID": tt.metadataValue, - }) - ctx := metadata.NewIncomingContext(context.Background(), md) - - resultCtx := WithGRPCContext(ctx) - txID := resultCtx.Value(CtxKeyTxID).(string) - - if tt.expectGenerated { - // Should be a generated UUID, not the original value - if txID == tt.metadataValue { - t.Errorf("Expected generated UUID but got original value: %q", txID) - } - // Check it looks like a UUID - if len(txID) != 36 || !strings.Contains(txID, "-") { - t.Errorf("Expected generated UUID format but got: %q", txID) - } - } else { - // Should be the original value - if txID != tt.metadataValue { - t.Errorf("Expected original value %q but got %q", tt.metadataValue, txID) - } - } - }) - } -} - -func TestSanitizationInLogging(t *testing.T) { - // Setup test logger with hook to capture logs - hook := test.NewLocal(log) - - // Test that malicious transaction IDs don't break JSON logs - ctx := context.Background() - ctx = context.WithValue(ctx, CtxKeyTxID, "malicious\nvalue") - ctx = context.WithValue(ctx, CtxKeySpanID, "span\rvalue") - - LogWithContext(ctx, "test", "test message", map[string]interface{}{ - "field": "value\x00with\x1bnull", - }) - - // Check that log was created (JSONFormatter should handle escaping) - if len(hook.Entries) == 0 { - t.Error("Expected log entry but none was created") - return - } - - entry := hook.Entries[0] - if entry.Message != "test message" { - t.Errorf("Expected message 'test message' but got %q", entry.Message) - } - - // Verify fields are present (JSONFormatter handles escaping) - if entry.Data["transaction_id"] != "malicious\nvalue" { - t.Errorf("Transaction ID not preserved correctly: %v", entry.Data["transaction_id"]) - } - - hook.Reset() -} From 7775e667090bfcae8171a03fe4e681c5335ee2c4 Mon Sep 17 00:00:00 2001 From: Himaschal Pursan Date: Sat, 26 Jul 2025 20:05:15 +0200 Subject: [PATCH 10/10] Revert "RA-7777: Fix a codeql issue. Sanitize the log message to prevent possible log injection." This reverts commit 72f2549dd08f8ab1d19ff2da4bd21ebc08b1b243. --- trillian/ctfe/logging/logger.go | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/trillian/ctfe/logging/logger.go b/trillian/ctfe/logging/logger.go index d2f1e3c5d6..703a1029ad 100644 --- a/trillian/ctfe/logging/logger.go +++ b/trillian/ctfe/logging/logger.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/http" - "strings" "time" "github.com/google/uuid" @@ -32,13 +31,6 @@ func generateUUID() string { return uuid.New().String() } -// sanitizeLogMessage removes or replaces characters that could be used for log injection -func sanitizeLogMessage(msg string) string { - msg = strings.ReplaceAll(msg, "\n", "\\n") - msg = strings.ReplaceAll(msg, "\r", "\\r") - return msg -} - func WithContext(r *http.Request) context.Context { txID := r.Header.Get("X-Transaction-ID") if txID == "" { @@ -122,9 +114,7 @@ func LogWithContext(ctx context.Context, eventID string, msg string, fields map[ for k, v := range fields { lf[k] = v } - // Sanitize the message to prevent log injection - sanitizedMsg := sanitizeLogMessage(msg) - log.WithFields(lf).Info(sanitizedMsg) + log.WithFields(lf).Info(msg) } func LogTiming(ctx context.Context, r *http.Request, status int, elapsed time.Duration) {