diff --git a/contrib/database/sql/internal/dsn.go b/contrib/database/sql/internal/dsn.go index 7c15ae5c2a..babe32ae90 100644 --- a/contrib/database/sql/internal/dsn.go +++ b/contrib/database/sql/internal/dsn.go @@ -263,7 +263,7 @@ func isValidHostnameStart(s string) bool { } // Should contain hostname-like patterns return strings.Contains(s, ".") || strings.Contains(s, ":") || - strings.Contains(s, "/") || s == strings.TrimSpace(s) + strings.Contains(s, "/") || s == strings.TrimSpace(s) } // sanitizeMySQLPasswords sanitizes passwords in MySQL DSN format (user:pass@tcp...). diff --git a/ddtrace/tracer/civisibility_payload.go b/ddtrace/tracer/civisibility_payload.go index 4343f318b9..b04ee29a5e 100644 --- a/ddtrace/tracer/civisibility_payload.go +++ b/ddtrace/tracer/civisibility_payload.go @@ -7,22 +7,22 @@ package tracer import ( "bytes" - "sync/atomic" "time" + "github.com/tinylib/msgp/msgp" + "github.com/DataDog/dd-trace-go/v2/internal/civisibility/constants" "github.com/DataDog/dd-trace-go/v2/internal/civisibility/utils" "github.com/DataDog/dd-trace-go/v2/internal/civisibility/utils/telemetry" "github.com/DataDog/dd-trace-go/v2/internal/globalconfig" "github.com/DataDog/dd-trace-go/v2/internal/log" "github.com/DataDog/dd-trace-go/v2/internal/version" - "github.com/tinylib/msgp/msgp" ) // ciVisibilityPayload represents a payload specifically designed for CI Visibility events. -// It embeds the generic payload structure and adds methods to handle CI Visibility specific data. +// It uses the generic payload interface and adds methods to handle CI Visibility specific data. type ciVisibilityPayload struct { - *payloadV04 + payload payload serializationTime time.Duration } @@ -36,18 +36,17 @@ type ciVisibilityPayload struct { // Returns: // // An error if encoding the event fails. -func (p *ciVisibilityPayload) push(event *ciVisibilityEvent) error { - p.buf.Grow(event.Msgsize()) +func (p *ciVisibilityPayload) push(event *ciVisibilityEvent) (size int, err error) { + p.payload.grow(event.Msgsize()) startTime := time.Now() defer func() { p.serializationTime += time.Since(startTime) }() - if err := msgp.Encode(&p.buf, event); err != nil { - return err + if err := msgp.Encode(p.payload, event); err != nil { + return 0, err } - atomic.AddUint32(&p.count, 1) - p.updateHeader() - return nil + p.payload.recordItem() // This already calls updateHeader() internally. + return p.size(), nil } // newCiVisibilityPayload creates a new instance of civisibilitypayload. @@ -57,7 +56,7 @@ func (p *ciVisibilityPayload) push(event *ciVisibilityEvent) error { // A pointer to a newly initialized civisibilitypayload instance. func newCiVisibilityPayload() *ciVisibilityPayload { log.Debug("ciVisibilityPayload: creating payload instance") - return &ciVisibilityPayload{newPayload(), 0} + return &ciVisibilityPayload{payload: newPayload(traceProtocolV04), serializationTime: 0} } // getBuffer retrieves the complete body of the CI Visibility payload, including metadata. @@ -73,11 +72,11 @@ func newCiVisibilityPayload() *ciVisibilityPayload { // An error if reading from the buffer or encoding the payload fails. func (p *ciVisibilityPayload) getBuffer(config *config) (*bytes.Buffer, error) { startTime := time.Now() - log.Debug("ciVisibilityPayload: .getBuffer (count: %d)", p.itemCount()) + log.Debug("ciVisibilityPayload: .getBuffer (count: %d)", p.payload.stats().itemCount) // Create a buffer to read the current payload payloadBuf := new(bytes.Buffer) - if _, err := payloadBuf.ReadFrom(p.payloadV04); err != nil { + if _, err := payloadBuf.ReadFrom(p.payload); err != nil { return nil, err } @@ -90,7 +89,7 @@ func (p *ciVisibilityPayload) getBuffer(config *config) (*bytes.Buffer, error) { return nil, err } - telemetry.EndpointPayloadEventsCount(telemetry.TestCycleEndpointType, float64(p.itemCount())) + telemetry.EndpointPayloadEventsCount(telemetry.TestCycleEndpointType, float64(p.payload.stats().itemCount)) telemetry.EndpointPayloadBytes(telemetry.TestCycleEndpointType, float64(encodedBuf.Len())) telemetry.EndpointEventsSerializationMs(telemetry.TestCycleEndpointType, float64((p.serializationTime + time.Since(startTime)).Milliseconds())) return encodedBuf, nil @@ -150,3 +149,43 @@ func (p *ciVisibilityPayload) writeEnvelope(env string, events []byte) *ciTestCy return visibilityPayload } + +// stats returns the current stats of the payload. +func (p *ciVisibilityPayload) stats() payloadStats { + return p.payload.stats() +} + +// size returns the payload size in bytes (for backward compatibility). +func (p *ciVisibilityPayload) size() int { + return p.payload.size() +} + +// itemCount returns the number of items available in the stream (for backward compatibility). +func (p *ciVisibilityPayload) itemCount() int { + return p.payload.itemCount() +} + +// protocol returns the protocol version of the payload. +func (p *ciVisibilityPayload) protocol() float64 { + return p.payload.protocol() +} + +// clear empties the payload buffers. +func (p *ciVisibilityPayload) clear() { + p.payload.clear() +} + +// reset sets up the payload to be read a second time. +func (p *ciVisibilityPayload) reset() { + p.payload.reset() +} + +// Read implements io.Reader by reading from the underlying payload. +func (p *ciVisibilityPayload) Read(b []byte) (n int, err error) { + return p.payload.Read(b) +} + +// Close implements io.Closer by closing the underlying payload. +func (p *ciVisibilityPayload) Close() error { + return p.payload.Close() +} diff --git a/ddtrace/tracer/civisibility_payload_test.go b/ddtrace/tracer/civisibility_payload_test.go index 3083fb672e..5be19bed0d 100644 --- a/ddtrace/tracer/civisibility_payload_test.go +++ b/ddtrace/tracer/civisibility_payload_test.go @@ -11,7 +11,6 @@ import ( "io" "strconv" "strings" - "sync/atomic" "testing" "github.com/DataDog/dd-trace-go/v2/internal/civisibility/constants" @@ -55,8 +54,9 @@ func TestCiVisibilityPayloadIntegrity(t *testing.T) { want.Reset() err := msgp.Encode(want, allEvents) assert.NoError(err) - assert.Equal(want.Len(), p.size()) - assert.Equal(p.itemCount(), len(allEvents)) + stats := p.stats() + assert.Equal(want.Len(), stats.size) + assert.Equal(len(allEvents), stats.itemCount) got, err := io.ReadAll(p) assert.NoError(err) @@ -152,15 +152,12 @@ func benchmarkCiVisibilityPayloadThroughput(count int) func(*testing.B) { b.ReportAllocs() b.ResetTimer() reset := func() { - p.header = make([]byte, 8) - p.off = 8 - atomic.StoreUint32(&p.count, 0) - p.buf.Reset() + p = newCiVisibilityPayload() } for i := 0; i < b.N; i++ { reset() for _, event := range events { - for p.size() < payloadMaxLimit { + for p.stats().size < payloadMaxLimit { p.push(event) } } diff --git a/ddtrace/tracer/civisibility_transport.go b/ddtrace/tracer/civisibility_transport.go index d577a89459..a9fc05b7c9 100644 --- a/ddtrace/tracer/civisibility_transport.go +++ b/ddtrace/tracer/civisibility_transport.go @@ -128,8 +128,8 @@ func newCiVisibilityTransport(config *config) *ciVisibilityTransport { // Returns: // // An io.ReadCloser for reading the response body, and an error if the operation fails. -func (t *ciVisibilityTransport) send(p *payloadV04) (body io.ReadCloser, err error) { - ciVisibilityPayload := &ciVisibilityPayload{p, 0} +func (t *ciVisibilityTransport) send(p payload) (body io.ReadCloser, err error) { + ciVisibilityPayload := &ciVisibilityPayload{payload: p, serializationTime: 0} buffer, bufferErr := ciVisibilityPayload.getBuffer(t.config) if bufferErr != nil { return nil, fmt.Errorf("cannot create buffer payload: %v", bufferErr) diff --git a/ddtrace/tracer/civisibility_transport_test.go b/ddtrace/tracer/civisibility_transport_test.go index 644e3f75f5..7cfa416998 100644 --- a/ddtrace/tracer/civisibility_transport_test.go +++ b/ddtrace/tracer/civisibility_transport_test.go @@ -102,12 +102,12 @@ func runTransportTest(t *testing.T, agentless, shouldSetAPIKey bool) { p := newCiVisibilityPayload() for _, t := range tc.payload { for _, span := range t { - err := p.push(getCiVisibilityEvent(span)) + _, err := p.push(getCiVisibilityEvent(span)) assert.NoError(err) } } - _, err := transport.send(p.payloadV04) + _, err := transport.send(p.payload) assert.NoError(err) } assert.Equal(hits, len(testCases)) diff --git a/ddtrace/tracer/civisibility_writer.go b/ddtrace/tracer/civisibility_writer.go index bf85872026..038e8e6089 100644 --- a/ddtrace/tracer/civisibility_writer.go +++ b/ddtrace/tracer/civisibility_writer.go @@ -64,10 +64,11 @@ func (w *ciVisibilityTraceWriter) add(trace []*Span) { telemetry.EventsEnqueueForSerialization() for _, s := range trace { cvEvent := getCiVisibilityEvent(s) - if err := w.payload.push(cvEvent); err != nil { + size, err := w.payload.push(cvEvent) + if err != nil { log.Error("ciVisibilityTraceWriter: Error encoding msgpack: %s", err.Error()) } - if w.payload.size() > agentlessPayloadSizeLimit { + if size > agentlessPayloadSizeLimit { w.flush() } } @@ -82,7 +83,7 @@ func (w *ciVisibilityTraceWriter) stop() { // flush sends the current payload to the transport. It ensures that the payload is reset // and the resources are freed after the flush operation is completed. func (w *ciVisibilityTraceWriter) flush() { - if w.payload.itemCount() == 0 { + if w.payload.stats().itemCount == 0 { return } @@ -113,9 +114,10 @@ func (w *ciVisibilityTraceWriter) flush() { telemetry.EndpointPayloadRequests(telemetry.TestCycleEndpointType, requestCompressedType) for attempt := 0; attempt <= w.config.sendRetries; attempt++ { - size, count = p.size(), p.itemCount() + stats := p.stats() + size, count = stats.size, stats.itemCount log.Debug("ciVisibilityTraceWriter: sending payload: size: %d events: %d\n", size, count) - _, err = w.config.transport.send(p.payloadV04) + _, err = w.config.transport.send(p.payload) if err == nil { log.Debug("ciVisibilityTraceWriter: sent events after %d attempts", attempt+1) return diff --git a/ddtrace/tracer/civisibility_writer_test.go b/ddtrace/tracer/civisibility_writer_test.go index e8f4858ba5..113f6bd5ca 100644 --- a/ddtrace/tracer/civisibility_writer_test.go +++ b/ddtrace/tracer/civisibility_writer_test.go @@ -30,10 +30,10 @@ type failingCiVisibilityTransport struct { assert *assert.Assertions } -func (t *failingCiVisibilityTransport) send(p *payloadV04) (io.ReadCloser, error) { +func (t *failingCiVisibilityTransport) send(p payload) (io.ReadCloser, error) { t.sendAttempts++ - ciVisibilityPayload := &ciVisibilityPayload{p, 0} + ciVisibilityPayload := &ciVisibilityPayload{payload: p, serializationTime: 0} var events ciVisibilityEvents err := msgp.Decode(ciVisibilityPayload, &events) diff --git a/ddtrace/tracer/option.go b/ddtrace/tracer/option.go index 3095de6b1b..e241b8aa77 100644 --- a/ddtrace/tracer/option.go +++ b/ddtrace/tracer/option.go @@ -27,6 +27,8 @@ import ( "golang.org/x/mod/semver" pb "github.com/DataDog/datadog-agent/pkg/proto/pbgo/trace" + "github.com/tinylib/msgp/msgp" + "github.com/DataDog/dd-trace-go/v2/ddtrace/ext" "github.com/DataDog/dd-trace-go/v2/internal" appsecconfig "github.com/DataDog/dd-trace-go/v2/internal/appsec/config" @@ -40,7 +42,6 @@ import ( "github.com/DataDog/dd-trace-go/v2/internal/telemetry" "github.com/DataDog/dd-trace-go/v2/internal/traceprof" "github.com/DataDog/dd-trace-go/v2/internal/version" - "github.com/tinylib/msgp/msgp" "github.com/DataDog/datadog-go/v5/statsd" ) @@ -1505,7 +1506,7 @@ func (t *dummyTransport) ObfuscationVersion() int { return t.obfVersion } -func (t *dummyTransport) send(p *payloadV04) (io.ReadCloser, error) { +func (t *dummyTransport) send(p payload) (io.ReadCloser, error) { traces, err := decode(p) if err != nil { return nil, err @@ -1521,7 +1522,7 @@ func (t *dummyTransport) endpoint() string { return "http://localhost:9/v0.4/traces" } -func decode(p *payloadV04) (spanLists, error) { +func decode(p payloadReader) (spanLists, error) { var traces spanLists err := msgp.Decode(p, &traces) return traces, err diff --git a/ddtrace/tracer/payload.go b/ddtrace/tracer/payload.go index f6dca635d2..d7dbf10be4 100644 --- a/ddtrace/tracer/payload.go +++ b/ddtrace/tracer/payload.go @@ -5,6 +5,64 @@ package tracer +import ( + "io" + "sync" +) + +// payloadStats contains the statistics of a payload. +type payloadStats struct { + size int // size in bytes + itemCount int // number of items (traces) +} + +// payloadWriter defines the interface for writing data to a payload. +type payloadWriter interface { + io.Writer + + push(t spanList) (stats payloadStats, err error) + grow(n int) + reset() + clear() + + // recordItem records that an item was added and updates the header + recordItem() +} + +// payloadReader defines the interface for reading data from a payload. +type payloadReader interface { + io.Reader + io.Closer + + stats() payloadStats + size() int + itemCount() int + protocol() float64 +} + +// unsafePayload defines the interface for unsafe (non-thread-safe) payload implementations. +type unsafePayload interface { + io.Reader + io.Writer + io.Closer + + push(t spanList) (stats payloadStats, err error) + itemCount() int + size() int + reset() + clear() + grow(n int) + recordItem() + stats() payloadStats + protocol() float64 +} + +// payload combines both reading and writing operations for a payload. +type payload interface { + payloadWriter + payloadReader +} + // https://github.com/msgpack/msgpack/blob/master/spec.md#array-format-family const ( msgpackArrayFix byte = 144 // up to 15 items @@ -12,6 +70,104 @@ const ( msgpackArray32 byte = 0xdd // up to 2^32-1 items, followed by size in 4 bytes ) +// newPayload returns a ready to use thread-safe payload. +func newPayload(protocol float64) payload { + // TODO(hannahkm): add support for v1 protocol + // if protocol == traceProtocolV1 { + // } + return &safePayload{ + p: newPayloadV04(protocol), + } +} + +// safePayload provides a thread-safe wrapper around unsafePayload. +type safePayload struct { + mu sync.RWMutex + p unsafePayload +} + +// push pushes a new item into the stream in a thread-safe manner. +func (sp *safePayload) push(t spanList) (stats payloadStats, err error) { + sp.mu.Lock() + defer sp.mu.Unlock() + return sp.p.push(t) +} + +// itemCount returns the number of items available in the stream in a thread-safe manner. +func (sp *safePayload) itemCount() int { + return sp.p.itemCount() +} + +// size returns the payload size in bytes in a thread-safe manner. +func (sp *safePayload) size() int { + sp.mu.RLock() + defer sp.mu.RUnlock() + return sp.p.size() +} + +// reset sets up the payload to be read a second time in a thread-safe manner. +func (sp *safePayload) reset() { + sp.mu.Lock() + defer sp.mu.Unlock() + sp.p.reset() +} + +// clear empties the payload buffers in a thread-safe manner. +func (sp *safePayload) clear() { + sp.mu.Lock() + defer sp.mu.Unlock() + sp.p.clear() +} + +// Read implements io.Reader in a thread-safe manner. +func (sp *safePayload) Read(b []byte) (n int, err error) { + // Note: Read modifies internal state (off, reader), so we need full lock + sp.mu.Lock() + defer sp.mu.Unlock() + return sp.p.Read(b) +} + +// Close implements io.Closer in a thread-safe manner. +func (sp *safePayload) Close() error { + sp.mu.Lock() + defer sp.mu.Unlock() + return sp.p.Close() +} + +// Write implements io.Writer in a thread-safe manner. +func (sp *safePayload) Write(data []byte) (n int, err error) { + sp.mu.Lock() + defer sp.mu.Unlock() + return sp.p.Write(data) +} + +// grow grows the buffer to ensure it can accommodate n more bytes in a thread-safe manner. +func (sp *safePayload) grow(n int) { + sp.mu.Lock() + defer sp.mu.Unlock() + sp.p.grow(n) +} + +// recordItem records that an item was added and updates the header in a thread-safe manner. +func (sp *safePayload) recordItem() { + sp.mu.Lock() + defer sp.mu.Unlock() + sp.p.recordItem() +} + +// stats returns the current stats of the payload in a thread-safe manner. +func (sp *safePayload) stats() payloadStats { + sp.mu.RLock() + defer sp.mu.RUnlock() + return sp.p.stats() +} + +// protocol returns the protocol version of the payload in a thread-safe manner. +func (sp *safePayload) protocol() float64 { + // Protocol is immutable after creation - no lock needed + return sp.p.protocol() +} + // traceChunk represents a list of spans with the same trace ID, // i.e. a chunk of a trace type traceChunk struct { diff --git a/ddtrace/tracer/payload_test.go b/ddtrace/tracer/payload_test.go index 7f69ccfb30..702c2bd141 100644 --- a/ddtrace/tracer/payload_test.go +++ b/ddtrace/tracer/payload_test.go @@ -7,9 +7,11 @@ package tracer import ( "bytes" + "fmt" "io" "strconv" "strings" + "sync" "sync/atomic" "testing" @@ -37,18 +39,19 @@ func TestPayloadIntegrity(t *testing.T) { for _, n := range []int{10, 1 << 10, 1 << 17} { t.Run(strconv.Itoa(n), func(t *testing.T) { assert := assert.New(t) - p := newPayload() + p := newPayload(traceProtocolV04) lists := make(spanLists, n) for i := 0; i < n; i++ { list := newSpanList(i%5 + 1) lists[i] = list - p.push(list) + _, _ = p.push(list) } want.Reset() err := msgp.Encode(want, lists) assert.NoError(err) - assert.Equal(want.Len(), p.size()) - assert.Equal(p.itemCount(), n) + stats := p.stats() + assert.Equal(want.Len(), stats.size) + assert.Equal(n, stats.itemCount) got, err := io.ReadAll(p) assert.NoError(err) @@ -63,9 +66,9 @@ func TestPayloadDecode(t *testing.T) { for _, n := range []int{10, 1 << 10} { t.Run(strconv.Itoa(n), func(t *testing.T) { assert := assert.New(t) - p := newPayload() + p := newPayload(traceProtocolV04) for i := 0; i < n; i++ { - p.push(newSpanList(i%5 + 1)) + _, _ = p.push(newSpanList(i%5 + 1)) } var got spanLists err := msgp.Decode(p, &got) @@ -85,7 +88,7 @@ func BenchmarkPayloadThroughput(b *testing.B) { // payload is filled. func benchmarkPayloadThroughput(count int) func(*testing.B) { return func(b *testing.B) { - p := newPayload() + p := newPayloadV04(traceProtocolV04) s := newBasicSpan("X") s.meta["key"] = strings.Repeat("X", 10*1024) trace := make(spanList, count) @@ -102,9 +105,241 @@ func benchmarkPayloadThroughput(count int) func(*testing.B) { } for i := 0; i < b.N; i++ { reset() - for p.size() < payloadMaxLimit { - p.push(trace) + for p.stats().size < payloadMaxLimit { + _, _ = p.push(trace) } } } } + +// TestPayloadConcurrentAccess tests that payload operations are safe for concurrent use +func TestPayloadConcurrentAccess(t *testing.T) { + p := newPayload(traceProtocolV04) + + // Create some test spans + spans := make(spanList, 10) + for i := 0; i < 10; i++ { + spans[i] = newBasicSpan("test-span") + } + + var wg sync.WaitGroup + + // Start multiple goroutines that perform concurrent operations + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + // Push some spans + for j := 0; j < 5; j++ { + _, _ = p.push(spans) + } + + // Read size and item count concurrently + for j := 0; j < 10; j++ { + stats := p.stats() + _ = stats.size + _ = stats.itemCount + } + }() + } + + // Also perform operations from the main goroutine + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 20; i++ { + _ = p.stats().size + } + }() + + wg.Wait() + + // Verify the payload is in a consistent state + if p.stats().itemCount == 0 { + t.Error("Expected payload to have items after concurrent operations") + } + + if p.stats().size <= 0 { + t.Error("Expected payload size to be positive after concurrent operations") + } +} + +// TestPayloadConcurrentReadWrite tests concurrent read and write operations +func TestPayloadConcurrentReadWrite(t *testing.T) { + p := newPayload(traceProtocolV04) + + // Add some initial data + span := newBasicSpan("test") + spans := spanList{span} + _, _ = p.push(spans) + + var wg sync.WaitGroup + + // Concurrent writers + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + _, _ = p.push(spans) + } + }() + } + + // Concurrent readers + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 1024) + for j := 0; j < 10; j++ { + p.reset() + _, _ = p.Read(buf) + } + }() + } + + // Concurrent size/count checkers + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 20; j++ { + stats := p.stats() + _ = stats.size + _ = stats.itemCount + } + }() + } + + wg.Wait() + + // Verify final state + if p.stats().itemCount == 0 { + t.Error("Expected payload to have items") + } +} + +func BenchmarkPayloadPush(b *testing.B) { + sizes := []struct { + name string + numSpans int + spanSize int + }{ + {"1span_1KB", 1, 1}, + {"5span_1KB", 5, 1}, + {"10span_1KB", 10, 1}, + {"1span_10KB", 1, 10}, + {"5span_10KB", 5, 10}, + {"10span_50KB", 10, 50}, + } + + for _, size := range sizes { + b.Run(size.name, func(b *testing.B) { + spans := make(spanList, size.numSpans) + for i := 0; i < size.numSpans; i++ { + span := newBasicSpan("benchmark-span") + span.meta["data"] = strings.Repeat("x", size.spanSize*1024) + spans[i] = span + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + p := newPayloadV04(traceProtocolV04) + _, _ = p.push(spans) + } + }) + } +} + +func BenchmarkPayloadStats(b *testing.B) { + tests := []struct { + name string + numTraces int + spansPer int + }{ + {"empty", 0, 0}, + {"small_1trace_1span", 1, 1}, + {"medium_10trace_5span", 10, 5}, + {"large_100trace_10span", 100, 10}, + } + + for _, test := range tests { + b.Run(test.name, func(b *testing.B) { + p := newPayload(traceProtocolV04) + + for i := 0; i < test.numTraces; i++ { + spans := make(spanList, test.spansPer) + for j := 0; j < test.spansPer; j++ { + spans[j] = newBasicSpan("test-span") + } + _, _ = p.push(spans) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + stats := p.stats() + _ = stats.size + _ = stats.itemCount + } + }) + } +} + +func BenchmarkPayloadConcurrentAccess(b *testing.B) { + concurrencyLevels := []int{1, 2, 4, 8} + + for _, concurrency := range concurrencyLevels { + b.Run(fmt.Sprintf("concurrency_%d", concurrency), func(b *testing.B) { + p := newPayload(traceProtocolV04) + span := newBasicSpan("concurrent-test") + spans := spanList{span} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + + for j := 0; j < concurrency; j++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = p.push(spans) + }() + } + + for j := 0; j < concurrency; j++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = p.stats() + }() + } + + wg.Wait() + p.clear() + } + }) + } +} + +func TestMsgsizeAnalysis(t *testing.T) { + sizes := []int{1, 5, 10} + for _, numSpans := range sizes { + spans := make(spanList, numSpans) + for i := 0; i < numSpans; i++ { + span := newBasicSpan("test") + span.meta["data"] = strings.Repeat("x", 1024) + spans[i] = span + } + + msgsize := spans.Msgsize() + t.Logf("%d spans with 1KB each: msgsize=%d bytes", numSpans, msgsize) + } +} diff --git a/ddtrace/tracer/payload_v04.go b/ddtrace/tracer/payload_v04.go index 1cb9063d56..b136b9fe32 100644 --- a/ddtrace/tracer/payload_v04.go +++ b/ddtrace/tracer/payload_v04.go @@ -21,8 +21,8 @@ import ( // from the msgpack array spec: // https://github.com/msgpack/msgpack/blob/master/spec.md#array-format-family // -// payloadV04 implements io.Reader and can be used with the decoder directly. To create -// a new payload use the newPayload method. +// payloadV04 implements unsafePayload and can be used with the decoder directly. To create +// a new payload use the newPayloadV04 method. // // payloadV04 is not safe for concurrent use. // @@ -58,35 +58,30 @@ type payloadV04 struct { // reader is used for reading the contents of buf. reader *bytes.Reader - // protocol specifies the trace protocol to use. - protocol float64 + // protocolVersion specifies the trace protocol to use. + protocolVersion float64 } var _ io.Reader = (*payloadV04)(nil) -// newPayload returns a ready to use payload. -func newPayload() *payloadV04 { +// newPayloadV04 returns a ready to use payload. +func newPayloadV04(protocol float64) *payloadV04 { p := &payloadV04{ - header: make([]byte, 8), - off: 8, + header: make([]byte, 8), + off: 8, + protocolVersion: protocol, } return p } // push pushes a new item into the stream. -func (p *payloadV04) push(t []*Span) error { - // if p.protocol == traceProtocolV1 { - // // TODO: implement v1.0 encoding - // } else { - sl := spanList(t) - // } - p.buf.Grow(sl.Msgsize()) - if err := msgp.Encode(&p.buf, sl); err != nil { - return err +func (p *payloadV04) push(t spanList) (stats payloadStats, err error) { + p.buf.Grow(t.Msgsize()) + if err := msgp.Encode(&p.buf, t); err != nil { + return payloadStats{}, err } - atomic.AddUint32(&p.count, 1) - p.updateHeader() - return nil + p.recordItem() + return p.stats(), nil } // itemCount returns the number of items available in the stream. @@ -116,6 +111,30 @@ func (p *payloadV04) clear() { p.reader = nil } +// grow grows the buffer to ensure it can accommodate n more bytes. +func (p *payloadV04) grow(n int) { + p.buf.Grow(n) +} + +// recordItem records that an item was added and updates the header. +func (p *payloadV04) recordItem() { + atomic.AddUint32(&p.count, 1) + p.updateHeader() +} + +// stats returns the current stats of the payload. +func (p *payloadV04) stats() payloadStats { + return payloadStats{ + size: p.size(), + itemCount: int(atomic.LoadUint32(&p.count)), + } +} + +// protocol returns the protocol version of the payload. +func (p *payloadV04) protocol() float64 { + return p.protocolVersion +} + // updateHeader updates the payload header based on the number of items currently // present in the stream. func (p *payloadV04) updateHeader() { @@ -140,6 +159,11 @@ func (p *payloadV04) Close() error { return nil } +// Write implements io.Writer. It writes data directly to the buffer. +func (p *payloadV04) Write(data []byte) (n int, err error) { + return p.buf.Write(data) +} + // Read implements io.Reader. It reads from the msgpack-encoded stream. func (p *payloadV04) Read(b []byte) (n int, err error) { if p.off < len(p.header) { diff --git a/ddtrace/tracer/payload_v1.go b/ddtrace/tracer/payload_v1.go index d94734f492..77e6bc4df8 100644 --- a/ddtrace/tracer/payload_v1.go +++ b/ddtrace/tracer/payload_v1.go @@ -51,6 +51,10 @@ type payloadV1 struct { // a list of trace `chunks` chunks []traceChunk + + // fields needed to implement unsafePayload interface + protocolVersion float64 + itemsCount uint32 } // AnyValue is a representation of the `any` value. It can take the following types: @@ -86,3 +90,13 @@ type keyValue struct { } type keyValueList = []keyValue + +// newPayloadV1 returns a ready to use payloadV1. +func newPayloadV1(protocol float64) *payloadV1 { + return &payloadV1{ + protocolVersion: protocol, + strings: make([]string, 0), + attributes: make(map[uint32]anyValue), + chunks: make([]traceChunk, 0), + } +} diff --git a/ddtrace/tracer/span_test.go b/ddtrace/tracer/span_test.go index 6b42b80305..d383586965 100644 --- a/ddtrace/tracer/span_test.go +++ b/ddtrace/tracer/span_test.go @@ -174,7 +174,7 @@ func TestSpanFinishTwice(t *testing.T) { assert.Nil(err) defer stop() - assert.Equal(tracer.traceWriter.(*agentTraceWriter).payload.itemCount(), 0) + assert.Equal(tracer.traceWriter.(*agentTraceWriter).payload.stats().itemCount, 0) // the finish must be idempotent span := tracer.newRootSpan("pylons.request", "pylons", "/") diff --git a/ddtrace/tracer/tracer_test.go b/ddtrace/tracer/tracer_test.go index 77197e6a9e..20b3531ae9 100644 --- a/ddtrace/tracer/tracer_test.go +++ b/ddtrace/tracer/tracer_test.go @@ -106,7 +106,7 @@ loop: case <-timeout: tst.Fatalf("timed out waiting for payload to contain %d", n) default: - if t.traceWriter.(*agentTraceWriter).payload.itemCount() == n { + if t.traceWriter.(*agentTraceWriter).payload.stats().itemCount == n { break loop } time.Sleep(10 * time.Millisecond) @@ -1447,7 +1447,7 @@ func TestTracerEdgeSampler(t *testing.T) { span1.Finish() } - assert.Equal(tracer0.traceWriter.(*agentTraceWriter).payload.itemCount(), 0) + assert.Equal(tracer0.traceWriter.(*agentTraceWriter).payload.stats().itemCount, 0) tracer1.awaitPayload(t, count) } diff --git a/ddtrace/tracer/transport.go b/ddtrace/tracer/transport.go index c6746e084b..1046392dc9 100644 --- a/ddtrace/tracer/transport.go +++ b/ddtrace/tracer/transport.go @@ -17,6 +17,7 @@ import ( "time" pb "github.com/DataDog/datadog-agent/pkg/proto/pbgo/trace" + "github.com/DataDog/dd-trace-go/v2/ddtrace/internal/tracerstats" "github.com/DataDog/dd-trace-go/v2/internal" "github.com/DataDog/dd-trace-go/v2/internal/version" @@ -64,13 +65,16 @@ const ( defaultHTTPTimeout = 10 * time.Second // defines the current timeout before giving up with the send process traceCountHeader = "X-Datadog-Trace-Count" // header containing the number of traces in the payload obfuscationVersionHeader = "Datadog-Obfuscation-Version" // header containing the version of obfuscation used, if any + + tracesAPIPath = "/v0.4/traces" + statsAPIPath = "/v0.6/stats" ) // transport is an interface for communicating data to the agent. type transport interface { // send sends the payload p to the agent using the transport set up. // It returns a non-nil response body when no error occurred. - send(p *payloadV04) (body io.ReadCloser, err error) + send(p payload) (body io.ReadCloser, err error) // sendStats sends the given stats payload to the agent. // tracerObfuscationVersion is the version of obfuscation applied (0 if none was applied) sendStats(s *pb.ClientStatsPayload, tracerObfuscationVersion int) error @@ -111,8 +115,8 @@ func newHTTPTransport(url string, client *http.Client) *httpTransport { defaultHeaders["Datadog-External-Env"] = extEnv } return &httpTransport{ - traceURL: fmt.Sprintf("%s/v0.4/traces", url), - statsURL: fmt.Sprintf("%s/v0.6/stats", url), + traceURL: fmt.Sprintf("%s%s", url, tracesAPIPath), + statsURL: fmt.Sprintf("%s%s", url, statsAPIPath), client: client, headers: defaultHeaders, } @@ -135,10 +139,12 @@ func (t *httpTransport) sendStats(p *pb.ClientStatsPayload, tracerObfuscationVer } resp, err := t.client.Do(req) if err != nil { + reportAPIErrorsMetric(resp, err, statsAPIPath) return err } defer resp.Body.Close() if code := resp.StatusCode; code >= 400 { + reportAPIErrorsMetric(resp, err, statsAPIPath) // error, check the body for context information and // return a nice error. msg := make([]byte, 1000) @@ -153,16 +159,17 @@ func (t *httpTransport) sendStats(p *pb.ClientStatsPayload, tracerObfuscationVer return nil } -func (t *httpTransport) send(p *payloadV04) (body io.ReadCloser, err error) { +func (t *httpTransport) send(p payload) (body io.ReadCloser, err error) { req, err := http.NewRequest("POST", t.traceURL, p) if err != nil { return nil, fmt.Errorf("cannot create http request: %s", err.Error()) } - req.ContentLength = int64(p.size()) + stats := p.stats() + req.ContentLength = int64(stats.size) for header, value := range t.headers { req.Header.Set(header, value) } - req.Header.Set(traceCountHeader, strconv.Itoa(p.itemCount())) + req.Header.Set(traceCountHeader, strconv.Itoa(stats.itemCount)) req.Header.Set(headerComputedTopLevel, "yes") if t := getGlobalTracer(); t != nil { tc := t.TracerConf() @@ -186,11 +193,11 @@ func (t *httpTransport) send(p *payloadV04) (body io.ReadCloser, err error) { } response, err := t.client.Do(req) if err != nil { - reportAPIErrorsMetric(response, err) + reportAPIErrorsMetric(response, err, tracesAPIPath) return nil, err } if code := response.StatusCode; code >= 400 { - reportAPIErrorsMetric(response, err) + reportAPIErrorsMetric(response, err, tracesAPIPath) // error, check the body for context information and // return a nice error. msg := make([]byte, 1000) @@ -205,7 +212,7 @@ func (t *httpTransport) send(p *payloadV04) (body io.ReadCloser, err error) { return response.Body, nil } -func reportAPIErrorsMetric(response *http.Response, err error) { +func reportAPIErrorsMetric(response *http.Response, err error, endpoint string) { if t, ok := getGlobalTracer().(*tracer); ok { var reason string if err != nil { @@ -214,7 +221,8 @@ func reportAPIErrorsMetric(response *http.Response, err error) { if response != nil { reason = fmt.Sprintf("server_response_%d", response.StatusCode) } - t.statsd.Incr("datadog.tracer.api.errors", []string{"reason:" + reason}, 1) + tags := []string{"reason:" + reason, "endpoint:" + endpoint} + t.statsd.Incr("datadog.tracer.api.errors", tags, 1) } else { return } diff --git a/ddtrace/tracer/transport_bench_test.go b/ddtrace/tracer/transport_bench_test.go new file mode 100644 index 0000000000..e8b4255b4b --- /dev/null +++ b/ddtrace/tracer/transport_bench_test.go @@ -0,0 +1,101 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package tracer + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" +) + +func BenchmarkHTTPTransportSend(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"rate_by_service":{}}`)) + })) + defer server.Close() + + transport := newHTTPTransport(server.URL, defaultHTTPClient(5*time.Second, false)) + + payloadSizes := []struct { + name string + numSpans int + spanSize int + }{ + {"small_1span", 1, 1}, + {"medium_10spans", 10, 1}, + {"large_100spans", 100, 1}, + {"xlarge_1000spans", 1000, 1}, + } + + for _, size := range payloadSizes { + b.Run(size.name, func(b *testing.B) { + payload := newPayload(traceProtocolV04) + spans := make([]*Span, size.numSpans) + for i := 0; i < size.numSpans; i++ { + span := newBasicSpan("transport-test") + span.meta["data"] = strings.Repeat("x", size.spanSize*1024) + spans[i] = span + } + _, _ = payload.push(spans) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + payload.reset() + rc, err := transport.send(payload) + if err == nil { + rc.Close() + } + } + }) + } +} + +func BenchmarkTransportSendConcurrent(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"rate_by_service":{}}`)) + })) + defer server.Close() + + transport := newHTTPTransport(server.URL, defaultHTTPClient(5*time.Second, false)) + concurrencyLevels := []int{1, 2, 4, 8} + + for _, concurrency := range concurrencyLevels { + b.Run(fmt.Sprintf("concurrency_%d", concurrency), func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + + for j := 0; j < concurrency; j++ { + wg.Add(1) + go func() { + defer wg.Done() + + payload := newPayload(traceProtocolV04) + spans := []*Span{newBasicSpan("concurrent-transport-test")} + _, _ = payload.push(spans) + + rc, err := transport.send(payload) + if err == nil { + rc.Close() + } + }() + } + + wg.Wait() + } + }) + } +} diff --git a/ddtrace/tracer/transport_test.go b/ddtrace/tracer/transport_test.go index 385a8c95b3..0bf5b67e9b 100644 --- a/ddtrace/tracer/transport_test.go +++ b/ddtrace/tracer/transport_test.go @@ -59,10 +59,10 @@ func getTestTrace(traceN, size int) [][]*Span { return traces } -func encode(traces [][]*Span) (*payloadV04, error) { - p := newPayload() +func encode(traces [][]*Span) (payload, error) { + p := newPayload(traceProtocolV04) for _, t := range traces { - if err := p.push(t); err != nil { + if _, err := p.push(t); err != nil { return p, err } } @@ -160,7 +160,7 @@ func TestTransportResponse(t *testing.T) { })) defer srv.Close() transport := newHTTPTransport(srv.URL, defaultHTTPClient(0, false)) - rc, err := transport.send(newPayload()) + rc, err := transport.send(newPayload(traceProtocolV04)) if tt.err != "" { assert.Equal(tt.err, err.Error()) return @@ -271,7 +271,7 @@ func (t *OkTransport) RoundTrip(_ *http.Request) (*http.Response, error) { } func TestApiErrorsMetric(t *testing.T) { - t.Run("error", func(t *testing.T) { + t.Run("traces error", func(t *testing.T) { assert := assert.New(t) c := &http.Client{ Transport: &ErrTransport{}, @@ -291,10 +291,10 @@ func TestApiErrorsMetric(t *testing.T) { calls := statsdtest.FilterCallsByName(tg.IncrCalls(), "datadog.tracer.api.errors") assert.Len(calls, 1) call := calls[0] - assert.Equal([]string{"reason:network_failure"}, call.Tags()) + assert.Equal([]string{"reason:network_failure", "endpoint:" + tracesAPIPath}, call.Tags()) }) - t.Run("response with err code", func(t *testing.T) { + t.Run("traces response with err code", func(t *testing.T) { assert := assert.New(t) c := &http.Client{ Transport: &ErrResponseTransport{}, @@ -314,7 +314,45 @@ func TestApiErrorsMetric(t *testing.T) { calls := statsdtest.FilterCallsByName(tg.IncrCalls(), "datadog.tracer.api.errors") assert.Len(calls, 1) call := calls[0] - assert.Equal([]string{"reason:server_response_400"}, call.Tags()) + assert.Equal([]string{"reason:server_response_400", "endpoint:" + tracesAPIPath}, call.Tags()) + }) + t.Run("stats error", func(t *testing.T) { + assert := assert.New(t) + c := &http.Client{ + Transport: &ErrTransport{}, + } + var tg statsdtest.TestStatsdClient + trc, err := newTracer(WithHTTPClient(c), withStatsdClient(&tg)) + assert.NoError(err) + setGlobalTracer(trc) + defer trc.Stop() + + // We're expecting an error + err = trc.config.transport.sendStats(&pb.ClientStatsPayload{}, 1) + assert.Error(err) + calls := statsdtest.FilterCallsByName(tg.IncrCalls(), "datadog.tracer.api.errors") + assert.Len(calls, 1) + call := calls[0] + assert.Equal([]string{"reason:network_failure", "endpoint:" + statsAPIPath}, call.Tags()) + }) + t.Run("stats response with err code", func(t *testing.T) { + assert := assert.New(t) + c := &http.Client{ + Transport: &ErrResponseTransport{}, + } + var tg statsdtest.TestStatsdClient + trc, err := newTracer(WithHTTPClient(c), withStatsdClient(&tg)) + assert.NoError(err) + setGlobalTracer(trc) + defer trc.Stop() + + err = trc.config.transport.sendStats(&pb.ClientStatsPayload{}, 1) + assert.Error(err) + + calls := statsdtest.FilterCallsByName(tg.IncrCalls(), "datadog.tracer.api.errors") + assert.Len(calls, 1) + call := calls[0] + assert.Equal([]string{"reason:server_response_400", "endpoint:" + statsAPIPath}, call.Tags()) }) t.Run("successful send - no metric", func(t *testing.T) { assert := assert.New(t) diff --git a/ddtrace/tracer/writer.go b/ddtrace/tracer/writer.go index ff2b244231..ea5c1c55b6 100644 --- a/ddtrace/tracer/writer.go +++ b/ddtrace/tracer/writer.go @@ -36,8 +36,11 @@ type agentTraceWriter struct { // config holds the tracer configuration config *config + // mu synchronizes access to payload operations + mu sync.Mutex + // payload encodes and buffers traces in msgpack format - payload *payloadV04 + payload payload // climit limits the number of concurrent outgoing connections climit chan struct{} @@ -67,12 +70,21 @@ func newAgentTraceWriter(c *config, s *prioritySampler, statsdClient globalinter } func (h *agentTraceWriter) add(trace []*Span) { - if err := h.payload.push(trace); err != nil { + h.mu.Lock() + stats, err := h.payload.push(trace) + if err != nil { + h.mu.Unlock() h.statsd.Incr("datadog.tracer.traces_dropped", []string{"reason:encoding_error"}, 1) log.Error("Error encoding msgpack: %s", err.Error()) + return } - atomic.AddUint32(&h.tracesQueued, 1) // TODO: This does not differentiate between complete traces and partial chunks - if h.payload.size() > payloadSizeLimit { + // TODO: This does not differentiate between complete traces and partial chunks + atomic.AddUint32(&h.tracesQueued, 1) + + needsFlush := stats.size > payloadSizeLimit + h.mu.Unlock() + + if needsFlush { h.statsd.Incr("datadog.tracer.flush_triggered", []string{"reason:size"}, 1) h.flush() } @@ -85,22 +97,25 @@ func (h *agentTraceWriter) stop() { } // newPayload returns a new payload based on the trace protocol. -func (h *agentTraceWriter) newPayload() *payloadV04 { - p := newPayload() - p.protocol = h.config.traceProtocol - return p +func (h *agentTraceWriter) newPayload() payload { + return newPayload(h.config.traceProtocol) } // flush will push any currently buffered traces to the server. func (h *agentTraceWriter) flush() { - if h.payload.itemCount() == 0 { + h.mu.Lock() + oldp := h.payload + // Check after acquiring lock + if oldp.itemCount() == 0 { + h.mu.Unlock() return } - h.wg.Add(1) - h.climit <- struct{}{} - oldp := h.payload h.payload = h.newPayload() - go func(p *payloadV04) { + h.mu.Unlock() + + h.climit <- struct{}{} + h.wg.Add(1) + go func(p payload) { defer func(start time.Time) { // Once the payload has been used, clear the buffer for garbage // collection to avoid a memory leak when references to this object @@ -114,17 +129,16 @@ func (h *agentTraceWriter) flush() { h.wg.Done() }(time.Now()) - var count, size int + stats := p.stats() var err error for attempt := 0; attempt <= h.config.sendRetries; attempt++ { - size, count = p.size(), p.itemCount() - log.Debug("Attempt to send payload: size: %d traces: %d\n", size, count) + log.Debug("Attempt to send payload: size: %d traces: %d\n", stats.size, stats.itemCount) var rc io.ReadCloser rc, err = h.config.transport.send(p) if err == nil { log.Debug("sent traces after %d attempts", attempt+1) - h.statsd.Count("datadog.tracer.flush_bytes", int64(size), nil, 1) - h.statsd.Count("datadog.tracer.flush_traces", int64(count), nil, 1) + h.statsd.Count("datadog.tracer.flush_bytes", int64(stats.size), nil, 1) + h.statsd.Count("datadog.tracer.flush_traces", int64(stats.itemCount), nil, 1) if err := h.prioritySampling.readRatesJSON(rc); err != nil { h.statsd.Incr("datadog.tracer.decode_error", nil, 1) } @@ -137,8 +151,8 @@ func (h *agentTraceWriter) flush() { p.reset() time.Sleep(h.config.retryInterval) } - h.statsd.Count("datadog.tracer.traces_dropped", int64(count), []string{"reason:send_failed"}, 1) - log.Error("lost %d traces: %v", count, err.Error()) + h.statsd.Count("datadog.tracer.traces_dropped", int64(stats.itemCount), []string{"reason:send_failed"}, 1) + log.Error("lost %d traces: %v", stats.itemCount, err.Error()) }(oldp) } diff --git a/ddtrace/tracer/writer_bench_test.go b/ddtrace/tracer/writer_bench_test.go new file mode 100644 index 0000000000..f8a5f65f28 --- /dev/null +++ b/ddtrace/tracer/writer_bench_test.go @@ -0,0 +1,123 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package tracer + +import ( + "fmt" + "sync" + "testing" + + "github.com/DataDog/dd-trace-go/v2/internal/statsdtest" + "github.com/stretchr/testify/require" +) + +func BenchmarkAgentTraceWriterAdd(b *testing.B) { + traceSizes := []struct { + name string + numSpans int + }{ + {"1span", 1}, + {"5spans", 5}, + {"10spans", 10}, + {"50spans", 50}, + } + + for _, size := range traceSizes { + b.Run(size.name, func(b *testing.B) { + var statsd statsdtest.TestStatsdClient + cfg, err := newTestConfig() + require.NoError(b, err) + + writer := newAgentTraceWriter(cfg, nil, &statsd) + + trace := make([]*Span, size.numSpans) + for i := 0; i < size.numSpans; i++ { + trace[i] = newBasicSpan("benchmark-span") + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + writer.add(trace) + } + }) + } +} + +func BenchmarkAgentTraceWriterFlush(b *testing.B) { + var statsd statsdtest.TestStatsdClient + cfg, err := newTestConfig() + require.NoError(b, err) + + writer := newAgentTraceWriter(cfg, nil, &statsd) + trace := []*Span{newBasicSpan("flush-test")} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + writer.add(trace) + writer.flush() + writer.wg.Wait() + } +} + +func BenchmarkAgentTraceWriterConcurrent(b *testing.B) { + concurrencyLevels := []int{1, 2, 4, 8} + + for _, concurrency := range concurrencyLevels { + b.Run(fmt.Sprintf("concurrency_%d", concurrency), func(b *testing.B) { + var statsd statsdtest.TestStatsdClient + cfg, err := newTestConfig() + require.NoError(b, err) + + writer := newAgentTraceWriter(cfg, nil, &statsd) + trace := []*Span{newBasicSpan("concurrent-test")} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + + for j := 0; j < concurrency; j++ { + wg.Add(1) + go func() { + defer wg.Done() + writer.add(trace) + }() + } + + wg.Wait() + } + }) + } +} + +func BenchmarkAgentTraceWriterStats(b *testing.B) { + var statsd statsdtest.TestStatsdClient + cfg, err := newTestConfig() + require.NoError(b, err) + + writer := newAgentTraceWriter(cfg, nil, &statsd) + + for i := 0; i < 10; i++ { + trace := []*Span{newBasicSpan("stats-test")} + writer.add(trace) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + writer.mu.Lock() + stats := writer.payload.stats() + writer.mu.Unlock() + _ = stats.size + _ = stats.itemCount + } +} diff --git a/ddtrace/tracer/writer_test.go b/ddtrace/tracer/writer_test.go index 53856b061d..74a55558b7 100644 --- a/ddtrace/tracer/writer_test.go +++ b/ddtrace/tracer/writer_test.go @@ -13,6 +13,8 @@ import ( "io" "math" "strings" + "sync" + "sync/atomic" "testing" "time" @@ -345,7 +347,7 @@ type failingTransport struct { assert *assert.Assertions } -func (t *failingTransport) send(p *payloadV04) (io.ReadCloser, error) { +func (t *failingTransport) send(p payload) (io.ReadCloser, error) { t.sendAttempts++ traces, err := decode(p) @@ -419,7 +421,7 @@ func TestTraceWriterFlushRetries(t *testing.T) { assert.Nil(err) var statsd statsdtest.TestStatsdClient - h := newAgentTraceWriter(c, nil, &statsd) + h := newAgentTraceWriter(c, newPrioritySampler(), &statsd) h.add(ss) start := time.Now() h.flush() @@ -457,7 +459,7 @@ func TestTraceProtocol(t *testing.T) { cfg, err := newTestConfig() require.NoError(t, err) h := newAgentTraceWriter(cfg, nil, nil) - assert.Equal(traceProtocolV1, h.payload.protocol) + assert.Equal(traceProtocolV1, h.payload.protocol()) }) t.Run("v0.4", func(t *testing.T) { @@ -465,14 +467,14 @@ func TestTraceProtocol(t *testing.T) { cfg, err := newTestConfig() require.NoError(t, err) h := newAgentTraceWriter(cfg, nil, nil) - assert.Equal(traceProtocolV04, h.payload.protocol) + assert.Equal(traceProtocolV04, h.payload.protocol()) }) t.Run("default", func(t *testing.T) { cfg, err := newTestConfig() require.NoError(t, err) h := newAgentTraceWriter(cfg, nil, nil) - assert.Equal(traceProtocolV04, h.payload.protocol) + assert.Equal(traceProtocolV04, h.payload.protocol()) }) t.Run("invalid", func(t *testing.T) { @@ -480,7 +482,7 @@ func TestTraceProtocol(t *testing.T) { cfg, err := newTestConfig() require.NoError(t, err) h := newAgentTraceWriter(cfg, nil, nil) - assert.Equal(traceProtocolV04, h.payload.protocol) + assert.Equal(traceProtocolV04, h.payload.protocol()) }) } func BenchmarkJsonEncodeSpan(b *testing.B) { @@ -503,3 +505,148 @@ func BenchmarkJsonEncodeFloat(b *testing.B) { encodeFloat(bs, float64(1e-9)) } } + +func TestAgentWriterRaceCondition(t *testing.T) { + // This test reproduces a race condition between add() and flush() operations + // The race occurs when: + // 1. add() loads payload, flush() replaces it before add() can push to it + // 2. add() increments tracesQueued while flush() goroutine resets it to 0 + // + // Run with: go test -race -run TestAgentWriterRaceCondition + + assert := assert.New(t) + var tg statsdtest.TestStatsdClient + cfg, err := newTestConfig(withStatsdClient(&tg)) + require.NoError(t, err) + statsd, err := newStatsdClient(cfg) + require.NoError(t, err) + defer statsd.Close() + + writer := newAgentTraceWriter(cfg, newPrioritySampler(), &tg) + + const numGoroutines = 50 + const numOperations = 100 + + // Channel to coordinate goroutines + start := make(chan struct{}) + + var wg sync.WaitGroup + + // Spawn goroutines that continuously add traces + for i := 0; i < numGoroutines/2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start // Wait for coordination signal + + for j := 0; j < numOperations; j++ { + spans := []*Span{makeSpan(1)} + writer.add(spans) + } + }() + } + + // Spawn goroutines that continuously flush + for i := 0; i < numGoroutines/2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start // Wait for coordination signal + + for j := 0; j < numOperations; j++ { + writer.flush() + } + }() + } + + // Start all goroutines simultaneously to maximize race condition probability + close(start) + + // Wait for all operations to complete + wg.Wait() + + // Final flush to process any remaining traces + writer.flush() + writer.wg.Wait() + + // The race condition might cause: + // 1. Traces to be lost (added to old payload after it was flushed) + // 2. Incorrect trace counts due to counter races + // 3. Data races detected by Go's race detector + + assert.True(true, "Test completed - check for race conditions with -race flag") +} + +func TestAgentWriterTraceCountAccuracy(t *testing.T) { + // This test validates that trace counting remains accurate under concurrent operations + // It detects both data races and logical errors in trace counting + + assert := assert.New(t) + var tg statsdtest.TestStatsdClient + cfg, err := newTestConfig(withStatsdClient(&tg)) + require.NoError(t, err) + statsd, err := newStatsdClient(cfg) + require.NoError(t, err) + defer statsd.Close() + + writer := newAgentTraceWriter(cfg, newPrioritySampler(), &tg) + + const numAddGoroutines = 20 + const numFlushGoroutines = 10 + const numTracesPerGoroutine = 50 + const expectedTotalTraces = numAddGoroutines * numTracesPerGoroutine + + start := make(chan struct{}) + var wg sync.WaitGroup + + // Track traces added for verification + var tracesAdded int32 + + // Spawn goroutines that add traces + for i := 0; i < numAddGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + + for j := 0; j < numTracesPerGoroutine; j++ { + spans := []*Span{makeSpan(1)} + writer.add(spans) + atomic.AddInt32(&tracesAdded, 1) + } + }() + } + + // Spawn goroutines that flush occasionally + for i := 0; i < numFlushGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + + // Flush periodically while adds are happening + for j := 0; j < 10; j++ { + time.Sleep(time.Microsecond * 100) + writer.flush() + } + }() + } + + // Start all goroutines + close(start) + wg.Wait() + + // Final flush to ensure all traces are processed + writer.flush() + writer.wg.Wait() + + // Verify that the number of traces added matches our expectation + actualTracesAdded := atomic.LoadInt32(&tracesAdded) + assert.Equal(int32(expectedTotalTraces), actualTracesAdded, + "Expected %d traces to be added, but got %d", expectedTotalTraces, actualTracesAdded) + + // The race condition could cause: + // 1. Loss of traces if they're added to an old payload after flush starts + // 2. Incorrect metrics reporting due to counter races + // 3. Data corruption in payload structures +} diff --git a/profiler/profiler.go b/profiler/profiler.go index 06895a35c9..485e750187 100644 --- a/profiler/profiler.go +++ b/profiler/profiler.go @@ -259,7 +259,7 @@ func newProfiler(opts ...Option) (*profiler, error) { types = append(types, executionTrace) } for _, pt := range types { - isDelta := len(profileTypes[pt].DeltaValues) > 0 + isDelta := p.cfg.deltaProfiles && len(profileTypes[pt].DeltaValues) > 0 in, out := compressionStrategy(pt, isDelta, p.cfg.compressionConfig) compressor, err := newCompressionPipeline(in, out) if err != nil { diff --git a/profiler/profiler_test.go b/profiler/profiler_test.go index 2914c8aee6..87bf548157 100644 --- a/profiler/profiler_test.go +++ b/profiler/profiler_test.go @@ -7,6 +7,7 @@ package profiler import ( "bytes" + "compress/gzip" "context" "encoding/json" "fmt" @@ -35,6 +36,7 @@ import ( "github.com/DataDog/dd-trace-go/v2/internal/traceprof" "github.com/DataDog/dd-trace-go/v2/internal/version" + pprofile "github.com/google/pprof/profile" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -900,3 +902,42 @@ func TestMetricsProfileStopEarlyNoLog(t *testing.T) { } } } + +func gzipDecompress(data []byte) ([]byte, error) { + r, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, err + } + return io.ReadAll(r) +} + +func TestHeapProfileCompression(t *testing.T) { + t.Run("delta", func(t *testing.T) { testHeapProfileCompression(t, true) }) + t.Run("non-delta", func(t *testing.T) { testHeapProfileCompression(t, false) }) +} + +func testHeapProfileCompression(t *testing.T, delta bool) { + profiles := startTestProfiler(t, 1, + WithPeriod(10*time.Millisecond), WithProfileTypes(HeapProfile), WithDeltaProfiles(delta), + ) + p := <-profiles + attachment := "heap.pprof" + if delta { + attachment = "delta-heap.pprof" + } + data, ok := p.attachments[attachment] + if !ok { + t.Fatalf("no heap profile, got %s", p.event.Attachments) + } + decompressed, err := gzipDecompress(data) + if err != nil { + t.Fatalf("decompressing the heap profile failed: %s", err) + } + t.Logf("%x", decompressed[:16]) + // We assume the profile is gzip compressed. The pprof pacakge + // can parse gzip-compressed profiles (it checks for the magic number). + // So we should be able to parse the original data + if _, err := pprofile.ParseData(data); err != nil { + t.Fatalf("parsing profile data failed: %s", err) + } +}