diff --git a/go.mod b/go.mod index 4b7c69b9db..2a689b46de 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/golang/mock v1.7.0-rc.1 github.com/google/go-cmp v0.7.0 github.com/google/trillian v1.7.2 + github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jackc/pgx/v5 v5.7.5 @@ -19,6 +20,7 @@ require ( github.com/prometheus/client_golang v1.22.0 github.com/rs/cors v1.11.1 github.com/sergi/go-diff v1.4.0 + github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.6 github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce @@ -71,7 +73,6 @@ require ( github.com/golang/protobuf v1.5.4 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/s2a-go v0.1.9 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.5 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/gorilla/websocket v1.5.1 // indirect @@ -101,7 +102,6 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/prometheus v0.51.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/sirupsen/logrus v1.9.3 // indirect github.com/soheilhy/cmux v0.1.5 // indirect github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect github.com/stretchr/testify v1.10.0 // indirect diff --git a/trillian/ctfe/ct_server/main.go b/trillian/ctfe/ct_server/main.go index 1410faf29d..70b3d9ff73 100644 --- a/trillian/ctfe/ct_server/main.go +++ b/trillian/ctfe/ct_server/main.go @@ -33,6 +33,8 @@ import ( "syscall" "time" + "github.com/google/certificate-transparency-go/trillian/ctfe/logging" + "github.com/google/certificate-transparency-go/trillian/ctfe" "github.com/google/certificate-transparency-go/trillian/ctfe/cache" "github.com/google/certificate-transparency-go/trillian/ctfe/configpb" @@ -136,6 +138,7 @@ func main() { } klog.CopyStandardLogTo("WARNING") + klog.Info("**** CT HTTP Server Starting ****") metricsAt := *metricsEndpoint @@ -448,7 +451,7 @@ func setupAndRegister(ctx context.Context, client trillian.TrillianLogClient, de return nil, err } for path, handler := range inst.Handlers { - mux.Handle(lhp+path, handler) + mux.Handle(lhp+path, logging.Middleware(handler)) } return inst, nil } diff --git a/trillian/ctfe/handlers.go b/trillian/ctfe/handlers.go index 23c5894b84..81148999f9 100644 --- a/trillian/ctfe/handlers.go +++ b/trillian/ctfe/handlers.go @@ -32,6 +32,7 @@ import ( "github.com/google/certificate-transparency-go/asn1" "github.com/google/certificate-transparency-go/tls" + "github.com/google/certificate-transparency-go/trillian/ctfe/logging" "github.com/google/certificate-transparency-go/trillian/util" "github.com/google/certificate-transparency-go/x509" "github.com/google/certificate-transparency-go/x509util" @@ -39,6 +40,7 @@ import ( "github.com/google/trillian/monitoring" "github.com/google/trillian/types" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/prototext" "k8s.io/klog/v2" @@ -198,6 +200,18 @@ func (a AppHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Many/most of the handlers forward the request on to the Log RPC server; impose a deadline // on this onward request. ctx, cancel := context.WithDeadline(logCtx, getRPCDeadlineTime(a.Info)) + txID, ok := logCtx.Value(logging.CtxKeyTxID).(string) + if !ok || txID == "" { + klog.Warning("Missing transaction_id in context") + txID = "unknown" + } + spanID, ok := logCtx.Value(logging.CtxKeySpanID).(string) + if !ok || spanID == "" { + klog.Warning("Missing span_id in context") + spanID = "unknown" + } + md := metadata.Pairs("X-Transaction-ID", txID, "X-Span-ID", spanID) + ctx = metadata.NewOutgoingContext(ctx, md) defer cancel() var err error diff --git a/trillian/ctfe/logging/logger.go b/trillian/ctfe/logging/logger.go new file mode 100644 index 0000000000..703a1029ad --- /dev/null +++ b/trillian/ctfe/logging/logger.go @@ -0,0 +1,155 @@ +package logging + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type contextKey string + +const ( + CtxKeyTxID contextKey = "transaction_id" + CtxKeySpanID contextKey = "span_id" +) + +var log = logrus.New() + +func init() { + log.Formatter = &logrus.JSONFormatter{ + TimestampFormat: "2006-01-02T15:04:05.000Z07:00", + } +} + +func generateUUID() string { + return uuid.New().String() +} + +func WithContext(r *http.Request) context.Context { + txID := r.Header.Get("X-Transaction-ID") + if txID == "" { + txID = generateUUID() + } + + spanID := generateUUID() + ctx := context.WithValue(r.Context(), CtxKeyTxID, txID) + ctx = context.WithValue(ctx, CtxKeySpanID, spanID) + return ctx +} + +func WithGRPCContext(ctx context.Context) context.Context { + // First check if there's already a transaction_id in the context (from HTTP) + txID, ok := ctx.Value(CtxKeyTxID).(string) + if !ok || txID == "" { + // If not, try to get it from gRPC metadata + txID = getFromMetadata(ctx, "X-Transaction-ID") + if txID == "" { + txID = generateUUID() + } + } + + // Check for span_id in context first + spanID, ok := ctx.Value(CtxKeySpanID).(string) + if !ok || spanID == "" { + // Always generate a new span_id for this service + // No longer checking metadata since span_id is not propagated + spanID = generateUUID() + } + + ctx = context.WithValue(ctx, CtxKeyTxID, txID) + ctx = context.WithValue(ctx, CtxKeySpanID, spanID) + return ctx +} + +func getFromMetadata(ctx context.Context, key string) string { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "" + } + values := md.Get(key) + if len(values) > 0 { + return values[0] + } + return "" +} + +// PropagateToGRPC adds the transaction_id from the context to gRPC metadata +// This ensures transaction correlation across service boundaries +// Note: span_id is NOT propagated - each service generates its own span_id +func PropagateToGRPC(ctx context.Context) context.Context { + txID := ctx.Value(CtxKeyTxID) + + if txID == nil { + return ctx + } + + mdMap := map[string]string{ + "X-Transaction-ID": txID.(string), + } + // Deliberately NOT propagating span_id - each service gets its own + + md := metadata.New(mdMap) + return metadata.NewOutgoingContext(ctx, md) +} + +func LogWithContext(ctx context.Context, eventID string, msg string, fields map[string]interface{}) { + lf := logrus.Fields{ + "event_id": eventID, + } + + // Safely extract transaction_id and span_id + if txID := ctx.Value(CtxKeyTxID); txID != nil { + lf["transaction_id"] = txID + } + if spanID := ctx.Value(CtxKeySpanID); spanID != nil { + lf["span_id"] = spanID + } + + for k, v := range fields { + lf[k] = v + } + log.WithFields(lf).Info(msg) +} + +func LogTiming(ctx context.Context, r *http.Request, status int, elapsed time.Duration) { + elapsedInMsStr := fmt.Sprintf("%dms", elapsed.Milliseconds()) + LogWithContext(ctx, "timing", "request completed", map[string]interface{}{ + "path": r.URL.Path, + "method": r.Method, + "status": status, + "elapsed": elapsedInMsStr, + }) +} + +func UnaryServerInterceptor() grpc.UnaryServerInterceptor { + return func( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, + ) (interface{}, error) { + ctx = WithGRPCContext(ctx) + start := time.Now() + resp, err := handler(ctx, req) + elapsed := time.Since(start) + LogWithContext(ctx, "timing", "gRPC call completed", map[string]interface{}{ + "method": info.FullMethod, + "status": statusCodeFromError(err), + "elapsed": elapsed.Milliseconds(), + }) + return resp, err + } +} + +func statusCodeFromError(err error) string { + if err == nil { + return "OK" + } + return err.Error() +} diff --git a/trillian/ctfe/logging/logger_test.go b/trillian/ctfe/logging/logger_test.go new file mode 100644 index 0000000000..32c3a910eb --- /dev/null +++ b/trillian/ctfe/logging/logger_test.go @@ -0,0 +1,291 @@ +package logging + +import ( + "context" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "google.golang.org/grpc/metadata" +) + +func TestGenerateUUID(t *testing.T) { + // Test that generateUUID returns a non-empty string + uuid1 := generateUUID() + uuid2 := generateUUID() + + if uuid1 == "" { + t.Error("generateUUID returned empty string") + } + + if uuid2 == "" { + t.Error("generateUUID returned empty string") + } + + // Test that two UUIDs are different + if uuid1 == uuid2 { + t.Error("generateUUID returned the same UUID twice, should be unique") + } + + // Test that UUID has expected format (basic check for dashes) + if !strings.Contains(uuid1, "-") { + t.Error("generateUUID didn't return expected UUID format") + } +} + +func TestWithContext(t *testing.T) { + // Create a test HTTP request + req := httptest.NewRequest("GET", "/test", nil) + + // Test case 1: Request without existing transaction ID + ctx := WithContext(req) + + // Check that transaction ID was added + txID := ctx.Value(CtxKeyTxID) + if txID == nil { + t.Error("WithContext didn't add transaction ID to context") + } + + // Check that span ID was added + spanID := ctx.Value(CtxKeySpanID) + if spanID == nil { + t.Error("WithContext didn't add span ID to context") + } + + // Check that both IDs are strings + if _, ok := txID.(string); !ok { + t.Error("Transaction ID is not a string") + } + if _, ok := spanID.(string); !ok { + t.Error("Span ID is not a string") + } +} + +func TestWithContextExistingTransactionID(t *testing.T) { + // Create a test HTTP request with existing transaction ID + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Transaction-ID", "existing-tx-id") + + ctx := WithContext(req) + + // Check that the existing transaction ID was preserved + txID := ctx.Value(CtxKeyTxID) + if txID != "existing-tx-id" { + t.Errorf("Expected transaction ID 'existing-tx-id', got %v", txID) + } + + // Check that a new span ID was still generated + spanID := ctx.Value(CtxKeySpanID) + if spanID == nil { + t.Error("WithContext didn't add span ID to context") + } +} + +func TestWithGRPCContext(t *testing.T) { + // Test case 1: Empty context + ctx := context.Background() + newCtx := WithGRPCContext(ctx) + + txID := newCtx.Value(CtxKeyTxID) + spanID := newCtx.Value(CtxKeySpanID) + + if txID == nil || spanID == nil { + t.Error("WithGRPCContext didn't add IDs to empty context") + } + + // Test case 2: Context with existing values + existingCtx := context.WithValue(context.Background(), CtxKeyTxID, "existing-tx") + existingCtx = context.WithValue(existingCtx, CtxKeySpanID, "existing-span") + + newCtx2 := WithGRPCContext(existingCtx) + + if newCtx2.Value(CtxKeyTxID) != "existing-tx" { + t.Error("WithGRPCContext didn't preserve existing transaction ID") + } + if newCtx2.Value(CtxKeySpanID) != "existing-span" { + t.Error("WithGRPCContext didn't preserve existing span ID") + } +} + +func TestWithGRPCContextFromMetadata(t *testing.T) { + // Create context with gRPC metadata + md := metadata.Pairs("X-Transaction-ID", "metadata-tx-id", "X-Span-ID", "metadata-span-id") + ctx := metadata.NewIncomingContext(context.Background(), md) + + newCtx := WithGRPCContext(ctx) + + txID := newCtx.Value(CtxKeyTxID) + spanID := newCtx.Value(CtxKeySpanID) + + if txID != "metadata-tx-id" { + t.Errorf("Expected transaction ID from metadata 'metadata-tx-id', got %v", txID) + } + // Span ID should be newly generated, not from metadata (each service gets its own span) + if spanID == "metadata-span-id" { + t.Errorf("Expected new span ID to be generated, but got metadata span ID: %v", spanID) + } + if spanID == "" { + t.Error("Expected new span ID to be generated, but got empty string") + } +} + +func TestLogWithContext(t *testing.T) { + // Create a test hook and attach it to our global logger + hook := test.NewLocal(log) + + // Create context with IDs + ctx := context.WithValue(context.Background(), CtxKeyTxID, "test-tx-id") + ctx = context.WithValue(ctx, CtxKeySpanID, "test-span-id") + + // Test logging + LogWithContext(ctx, "test-event", "test message", map[string]interface{}{ + "custom_field": "custom_value", + }) + + // Check that a log entry was created + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + if entry == nil { + t.Error("LastEntry() returned nil") + return + } + + // Check log level + if entry.Level != logrus.InfoLevel { + t.Errorf("Expected INFO level, got %v", entry.Level) + } + + // Check message + if entry.Message != "test message" { + t.Errorf("Expected message 'test message', got %v", entry.Message) + } + + // Check fields + if entry.Data["event_id"] != "test-event" { + t.Errorf("Expected event_id 'test-event', got %v", entry.Data["event_id"]) + } + if entry.Data["transaction_id"] != "test-tx-id" { + t.Errorf("Expected transaction_id 'test-tx-id', got %v", entry.Data["transaction_id"]) + } + if entry.Data["span_id"] != "test-span-id" { + t.Errorf("Expected span_id 'test-span-id', got %v", entry.Data["span_id"]) + } + if entry.Data["custom_field"] != "custom_value" { + t.Errorf("Expected custom_field 'custom_value', got %v", entry.Data["custom_field"]) + } +} + +func TestLogTiming(t *testing.T) { + // Create a test hook and attach it to our global logger + hook := test.NewLocal(log) + + // Create context with IDs + ctx := context.WithValue(context.Background(), CtxKeyTxID, "timing-tx-id") + ctx = context.WithValue(ctx, CtxKeySpanID, "timing-span-id") + + // Create a test request + req := httptest.NewRequest("GET", "/api/test", nil) + + // Test timing log + elapsed := 250 * time.Millisecond + LogTiming(ctx, req, 200, elapsed) + + // Check that a log entry was created + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + if entry == nil { + t.Error("LastEntry() returned nil") + return + } + + // Check message + if entry.Message != "request completed" { + t.Errorf("Expected message 'request completed', got %v", entry.Message) + } + + // Check timing fields + if entry.Data["event_id"] != "timing" { + t.Error("Expected event_id to be 'timing'") + } + if entry.Data["path"] != "/api/test" { + t.Errorf("Expected path '/api/test', got %v", entry.Data["path"]) + } + if entry.Data["method"] != "GET" { + t.Errorf("Expected method 'GET', got %v", entry.Data["method"]) + } + if entry.Data["status"] != 200 { + t.Errorf("Expected status 200, got %v", entry.Data["status"]) + } + if entry.Data["elapsed"] != "250ms" { + t.Errorf("Expected elapsed '250ms', got %v", entry.Data["elapsed"]) + } +} + +func TestStatusCodeFromError(t *testing.T) { + // Test with no error + status := statusCodeFromError(nil) + if status != "OK" { + t.Errorf("Expected 'OK' for nil error, got %v", status) + } + + // Test with error + err := &testError{msg: "test error"} + status = statusCodeFromError(err) + if status != "test error" { + t.Errorf("Expected 'test error', got %v", status) + } +} + +// Helper type for testing errors +type testError struct { + msg string +} + +func (e *testError) Error() string { + return e.msg +} + +func TestPropagateToGRPC(t *testing.T) { + // Test with context containing IDs + ctx := context.WithValue(context.Background(), CtxKeyTxID, "propagate-tx-id") + ctx = context.WithValue(ctx, CtxKeySpanID, "propagate-span-id") + + newCtx := PropagateToGRPC(ctx) + + // Extract metadata from the new context + md, ok := metadata.FromOutgoingContext(newCtx) + if !ok { + t.Error("PropagateToGRPC didn't add metadata to context") + } + + // Check that transaction ID was added to metadata + txIDs := md.Get("X-Transaction-ID") + if len(txIDs) != 1 || txIDs[0] != "propagate-tx-id" { + t.Errorf("Expected X-Transaction-ID 'propagate-tx-id', got %v", txIDs) + } + + // Check that span ID was NOT added to metadata (new design) + spanIDs := md.Get("X-Span-ID") + if len(spanIDs) != 0 { + t.Errorf("Expected no X-Span-ID in metadata, but got %v", spanIDs) + } + + // Test with empty context (should return same context) + emptyCtx := context.Background() + sameCtx := PropagateToGRPC(emptyCtx) + + if sameCtx != emptyCtx { + t.Error("PropagateToGRPC should return same context when no IDs present") + } +} diff --git a/trillian/ctfe/logging/middleware.go b/trillian/ctfe/logging/middleware.go new file mode 100644 index 0000000000..b463f563a3 --- /dev/null +++ b/trillian/ctfe/logging/middleware.go @@ -0,0 +1,33 @@ +package logging + +import ( + "net/http" + "time" +) + +type statusRecorder struct { + http.ResponseWriter + statusCode int +} + +func (r *statusRecorder) WriteHeader(code int) { + r.statusCode = code + r.ResponseWriter.WriteHeader(code) +} + +// this middleware measures how long each HTTP request takes to process. +// It does this by recording the start time before calling the next handler, +// then the end time after the handler finishes. +// It logs the elapsed time, HTTP status code, and request details using LogTiming. +// This helps track request latency for monitoring and debugging. + +func Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + ctx := WithContext(r) + rw := &statusRecorder{ResponseWriter: w, statusCode: 200} + next.ServeHTTP(rw, r.WithContext(ctx)) + elapsed := time.Since(start) + LogTiming(ctx, r, rw.statusCode, elapsed) + }) +} diff --git a/trillian/ctfe/logging/middleware_test.go b/trillian/ctfe/logging/middleware_test.go new file mode 100644 index 0000000000..7b48ade9d0 --- /dev/null +++ b/trillian/ctfe/logging/middleware_test.go @@ -0,0 +1,328 @@ +package logging + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/sirupsen/logrus/hooks/test" +) + +// TestStatusRecorder tests the statusRecorder type +func TestStatusRecorder(t *testing.T) { + // Create a test ResponseWriter + w := httptest.NewRecorder() + + // Create our statusRecorder wrapper + recorder := &statusRecorder{ + ResponseWriter: w, + statusCode: 200, // default value + } + + // Test that it starts with default status code + if recorder.statusCode != 200 { + t.Errorf("Expected default status code 200, got %d", recorder.statusCode) + } + + // Test WriteHeader method + recorder.WriteHeader(404) + + // Check that our wrapper recorded the status code + if recorder.statusCode != 404 { + t.Errorf("Expected status code 404, got %d", recorder.statusCode) + } + + // Check that the underlying ResponseWriter also got the status code + if w.Code != 404 { + t.Errorf("Expected underlying ResponseWriter code 404, got %d", w.Code) + } +} + +// TestMiddleware tests the main middleware function +func TestMiddleware(t *testing.T) { + // Create a test hook to capture log output + hook := test.NewLocal(log) + + // Create a test handler that we'll wrap with our middleware + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check that context was added to the request + txID := r.Context().Value(CtxKeyTxID) + if txID == nil { + t.Error("Request context doesn't contain transaction ID") + } + + spanID := r.Context().Value(CtxKeySpanID) + if spanID == nil { + t.Error("Request context doesn't contain span ID") + } + + // Simulate some work + time.Sleep(10 * time.Millisecond) + + // Write a response + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello World")) + }) + + // Wrap our test handler with the middleware + wrappedHandler := Middleware(testHandler) + + // Create a test request + req := httptest.NewRequest("GET", "/test-endpoint", nil) + + // Create a ResponseRecorder to capture the response + w := httptest.NewRecorder() + + // Execute the request + wrappedHandler.ServeHTTP(w, req) + + // Check that the response was written correctly + if w.Code != http.StatusOK { + t.Errorf("Expected status OK, got %d", w.Code) + } + + if w.Body.String() != "Hello World" { + t.Errorf("Expected body 'Hello World', got %s", w.Body.String()) + } + + // Check that a log entry was created + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + if entry == nil { + t.Error("LastEntry() returned nil") + return + } + + // Check log message + if entry.Message != "request completed" { + t.Errorf("Expected message 'request completed', got %s", entry.Message) + } + + // Check log fields + if entry.Data["event_id"] != "timing" { + t.Errorf("Expected event_id 'timing', got %v", entry.Data["event_id"]) + } + + if entry.Data["method"] != "GET" { + t.Errorf("Expected method 'GET', got %v", entry.Data["method"]) + } + + if entry.Data["path"] != "/test-endpoint" { + t.Errorf("Expected path '/test-endpoint', got %v", entry.Data["path"]) + } + + if entry.Data["status"] != 200 { + t.Errorf("Expected status 200, got %v", entry.Data["status"]) + } + + // Check that elapsed time was recorded (should be > 0) + elapsed, ok := entry.Data["elapsed"].(string) + if !ok { + t.Error("Elapsed time not recorded as string") + } + + if !strings.HasSuffix(elapsed, "ms") { + t.Errorf("Expected elapsed time to end with 'ms', got %s", elapsed) + } +} + +// TestMiddlewareWithExistingTransactionID tests middleware when request already has transaction ID +func TestMiddlewareWithExistingTransactionID(t *testing.T) { + hook := test.NewLocal(log) + + // Create a test handler + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check that the existing transaction ID was preserved + txID := r.Context().Value(CtxKeyTxID) + if txID != "existing-tx-123" { + t.Errorf("Expected existing transaction ID 'existing-tx-123', got %v", txID) + } + + w.WriteHeader(http.StatusOK) + }) + + // Wrap with middleware + wrappedHandler := Middleware(testHandler) + + // Create request with existing transaction ID + req := httptest.NewRequest("POST", "/api/submit", strings.NewReader("test data")) + req.Header.Set("X-Transaction-ID", "existing-tx-123") + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + // Check log entry + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + if entry.Data["transaction_id"] != "existing-tx-123" { + t.Errorf("Expected transaction_id 'existing-tx-123', got %v", entry.Data["transaction_id"]) + } +} + +// TestMiddlewareWithDifferentStatusCodes tests middleware with various HTTP status codes +func TestMiddlewareWithDifferentStatusCodes(t *testing.T) { + testCases := []struct { + name string + statusCode int + expectedStatus int + }{ + {"Success", http.StatusOK, 200}, + {"Created", http.StatusCreated, 201}, + {"BadRequest", http.StatusBadRequest, 400}, + {"NotFound", http.StatusNotFound, 404}, + {"InternalServerError", http.StatusInternalServerError, 500}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + hook := test.NewLocal(log) + + // Create handler that returns specific status code + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.statusCode) + }) + + wrappedHandler := Middleware(testHandler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + // Check response status + if w.Code != tc.expectedStatus { + t.Errorf("Expected status %d, got %d", tc.expectedStatus, w.Code) + } + + // Check log entry + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + if entry.Data["status"] != tc.expectedStatus { + t.Errorf("Expected logged status %d, got %v", tc.expectedStatus, entry.Data["status"]) + } + }) + } +} + +// TestMiddlewareWithPanic tests that middleware handles panics gracefully +func TestMiddlewareWithPanic(t *testing.T) { + // Create handler that panics + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + }) + + wrappedHandler := Middleware(testHandler) + + req := httptest.NewRequest("GET", "/panic", nil) + w := httptest.NewRecorder() + + // This should panic, so we need to recover + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic, but none occurred") + } + }() + + wrappedHandler.ServeHTTP(w, req) +} + +// TestMiddlewareTimingAccuracy tests that timing is reasonably accurate +func TestMiddlewareTimingAccuracy(t *testing.T) { + hook := test.NewLocal(log) + + // Create handler that sleeps for a known duration + sleepDuration := 50 * time.Millisecond + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(sleepDuration) + w.WriteHeader(http.StatusOK) + }) + + wrappedHandler := Middleware(testHandler) + + req := httptest.NewRequest("GET", "/slow", nil) + w := httptest.NewRecorder() + + start := time.Now() + wrappedHandler.ServeHTTP(w, req) + actualElapsed := time.Since(start) + + // Check that timing was logged + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + return + } + + entry := hook.LastEntry() + elapsedStr, ok := entry.Data["elapsed"].(string) + if !ok { + t.Error("Elapsed time not recorded as string") + return + } + + // Parse the elapsed time (remove "ms" suffix) + elapsedStr = strings.TrimSuffix(elapsedStr, "ms") + + // Check that the timing is reasonably accurate using proper numeric comparison + // Allow for some variance (±20ms) due to system scheduling and load + tolerance := 20 * time.Millisecond + if actualElapsed < (sleepDuration-tolerance) || actualElapsed > (sleepDuration+tolerance) { + t.Errorf("Expected elapsed time around %v (±%v), got %v. Logged: %sms", + sleepDuration, tolerance, actualElapsed, elapsedStr) + } +} + +// TestMiddlewareChaining tests that multiple middlewares can be chained +func TestMiddlewareChaining(t *testing.T) { + hook := test.NewLocal(log) + + // Create a simple handler + finalHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("final")) + }) + + // Create another middleware for testing chaining + authMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add a custom header to verify this middleware ran + w.Header().Set("X-Auth-Middleware", "ran") + next.ServeHTTP(w, r) + }) + } + + // Chain middlewares: authMiddleware -> Middleware -> finalHandler + chainedHandler := authMiddleware(Middleware(finalHandler)) + + req := httptest.NewRequest("GET", "/chained", nil) + w := httptest.NewRecorder() + + chainedHandler.ServeHTTP(w, req) + + // Check that both middlewares ran + if w.Header().Get("X-Auth-Middleware") != "ran" { + t.Error("Auth middleware didn't run") + } + + if w.Body.String() != "final" { + t.Errorf("Expected body 'final', got %s", w.Body.String()) + } + + // Check that our logging middleware created a log entry + if len(hook.Entries) != 1 { + t.Errorf("Expected 1 log entry, got %d", len(hook.Entries)) + } +}