# datadog_lambda/tracing.py -- definitions touched by this change.


def extract_context_from_sqs_or_sns_event_or_context(
    event, lambda_context, event_source
):
    """
    Extract Datadog trace context from an SQS or SNS event.

    Records are visited in order. The APM trace context comes from the first
    record that yields one. When DSM is enabled every record gets its own
    DSM checkpoint; when DSM is disabled only the first record is inspected.

    Falls back to lambda context if no trace data is found in the SQS message
    attributes.
    Set a DSM checkpoint if DSM is enabled and the method for context
    propagation is supported.
    """
    event_type = "sqs" if event_source.equals(EventTypes.SQS) else "sns"

    # EventBridge => SQS
    # NOTE(review): the body of this `try` lies in the gap between two diff
    # hunks and is not visible in the patch; it is reconstructed here and
    # must be confirmed against the file before applying.
    try:
        context = _extract_context_from_eventbridge_sqs_event(event)
        if _is_context_complete(context):
            return context
    except Exception:
        logger.debug("Failed extracting context as EventBridge to SQS.")

    apm_context = None
    # `or []` also covers an explicit `"Records": None`.
    for record in event.get("Records") or []:
        # Reset per record so a failure never reuses the previous record's
        # values when the checkpoint is set below.
        source_arn = ""
        dd_ctx = None
        try:
            source_arn = _get_sqs_or_sns_source_arn(record, event_type)
            dd_ctx = _extract_context_from_sqs_or_sns_record(record)
            if apm_context is None:
                apm_context = _extract_apm_context(dd_ctx, record)
        except Exception as e:
            logger.debug("The trace extractor returned with error %s", e)

        if not config.data_streams_enabled:
            # Without DSM there is nothing to do per record: only the first
            # record is considered for the trace context.
            break
        # Still want to set a DSM checkpoint even if DSM context not propagated
        _dsm_set_checkpoint(dd_ctx, event_type, source_arn)

    if apm_context:
        return apm_context
    return extract_context_from_lambda_context(lambda_context)


def _get_sqs_or_sns_source_arn(record, event_type):
    """
    Return the ARN used for the DSM checkpoint of one record.

    SQS records carry it in ``eventSourceARN``; SNS records carry it in
    ``Sns.TopicArn``. Returns ``""`` when the ARN is missing, including when
    ``Sns`` is absent or ``None``.
    """
    if event_type == "sqs":
        return record.get("eventSourceARN") or ""
    return (record.get("Sns") or {}).get("TopicArn") or ""


def _extract_context_from_sqs_or_sns_record(record):
    """
    Return the decoded ``_datadog`` message attribute of one record.

    Handles plain SQS records (``messageAttributes``), SNS records
    (``Sns.MessageAttributes``) and SNS notifications wrapped in an SQS body.

    Returns the JSON-decoded payload, or ``None`` when the record carries no
    usable ``_datadog`` attribute. Raises whatever ``base64.b64decode`` or
    ``json.loads`` raise on a malformed payload; callers are expected to
    handle that per record.
    """
    # logic to deal with SNS => SQS event
    if "body" in record:
        body_str = record.get("body")
        try:
            body = json.loads(body_str)
            if body.get("Type", "") == "Notification" and "TopicArn" in body:
                logger.debug("Found SNS message inside SQS event")
                record = get_first_record(create_sns_event(body))
        except Exception:
            # Best effort: a body that cannot be read as an SNS notification
            # simply means the record is used as-is.
            pass

    msg_attributes = record.get("messageAttributes")
    if msg_attributes is None:
        sns_record = record.get("Sns") or {}
        msg_attributes = sns_record.get("MessageAttributes") or {}

    dd_payload = msg_attributes.get("_datadog")
    if not dd_payload:
        return None

    # SQS uses dataType and binaryValue/stringValue
    # SNS uses Type and Value
    dd_json_data = None
    dd_json_data_type = dd_payload.get("Type") or dd_payload.get("dataType")
    if dd_json_data_type == "Binary":
        import base64

        dd_json_data = dd_payload.get("binaryValue") or dd_payload.get("Value")
        if dd_json_data:
            dd_json_data = base64.b64decode(dd_json_data)
    elif dd_json_data_type == "String":
        dd_json_data = dd_payload.get("stringValue") or dd_payload.get("Value")
    else:
        # The two literals are concatenated, so the first must end in a space.
        logger.debug(
            "Datadog Lambda Python only supports extracting trace "
            "context from String or Binary SQS/SNS message attributes"
        )

    if dd_json_data:
        return json.loads(dd_json_data)
    return None


def _extract_apm_context(dd_ctx, record):
    """
    Build the APM trace context for one SQS/SNS record.

    ``dd_ctx`` is the decoded ``_datadog`` payload (or ``None``). When it is
    present, Step Functions extraction is tried first for Step Functions
    payloads, then the propagator. When it is absent, the record's
    ``attributes.AWSTraceHeader`` is checked for Datadog-injected context.

    Returns the extracted context, or ``None`` if the record carries none.
    """
    if dd_ctx:
        if is_step_function_event(dd_ctx):
            try:
                return extract_context_from_step_functions(dd_ctx, None)
            except Exception:
                logger.debug(
                    "Failed to extract Step Functions context from SQS/SNS event."
                )
        # Deliberately not under an `else`: if Step Functions extraction
        # failed, the propagator still gets a chance at the payload.
        return propagator.extract(dd_ctx)

    # Handle case where trace context is injected into attributes.AWSTraceHeader
    # example: Root=1-654321ab-000000001234567890abcdef;Parent=0123456789abcdef;Sampled=1
    attrs = record.get("attributes")
    if not attrs:
        return None
    x_ray_header = attrs.get("AWSTraceHeader")
    if not x_ray_header:
        return None

    x_ray_context = parse_xray_header(x_ray_header)
    trace_id_parts = x_ray_context.get("trace_id", "").split("-")
    if len(trace_id_parts) > 2 and trace_id_parts[2].startswith(
        DD_TRACE_JAVA_TRACE_ID_PADDING
    ):
        # If it starts with eight 0's padding,
        # then this AWSTraceHeader contains Datadog injected trace context
        logger.debug("Found dd-trace injected trace context from AWSTraceHeader")
        return Context(
            # [8:] drops the eight-character padding checked just above
            trace_id=int(trace_id_parts[2][8:], 16),
            span_id=int(x_ray_context["parent_id"], 16),
            sampling_priority=float(x_ray_context["sampled"]),
        )
    return None


# (unchanged, not shown in the patch: _extract_context_from_eventbridge_sqs_event
#  and the definitions up to extract_context_from_kinesis_event)


def extract_context_from_kinesis_event(event, lambda_context):
    """
    Extract datadog trace context from a Kinesis Stream's base64 encoded data string

    Records are visited in order. The trace context comes from the first
    record that carries a ``_datadog`` entry. When DSM is enabled every
    record that has a ``kinesis`` section gets a DSM checkpoint; when DSM is
    disabled only the first record is inspected.

    Set a DSM checkpoint if DSM is enabled and the method for context
    propagation is supported.
    Falls back to lambda context if no record yields a trace context.
    """
    apm_context = None
    for record in event.get("Records") or []:
        # Reset per record so a failure never leaves these unbound or stale.
        source_arn = ""
        dd_ctx = None
        has_kinesis = False
        try:
            source_arn = record.get("eventSourceARN", "")
            kinesis = record.get("kinesis")
            has_kinesis = bool(kinesis)
            if has_kinesis:
                dd_ctx = _extract_dd_ctx_from_kinesis_payload(kinesis)
                if dd_ctx and apm_context is None:
                    apm_context = propagator.extract(dd_ctx)
        except Exception as e:
            logger.debug("The trace extractor returned with error %s", e)

        if not config.data_streams_enabled:
            break
        # A record without a `kinesis` section gets no checkpoint, but it no
        # longer aborts the batch: context already found is kept and the
        # remaining records are still processed.
        if has_kinesis:
            # Still want to set a DSM checkpoint even if DSM context not propagated
            _dsm_set_checkpoint(dd_ctx, "kinesis", source_arn)

    if apm_context:
        return apm_context
    return extract_context_from_lambda_context(lambda_context)


def _extract_dd_ctx_from_kinesis_payload(kinesis):
    """
    Return the ``_datadog`` entry of a record's ``kinesis`` section, or None.

    ``kinesis["data"]`` is base64-decoded and parsed as JSON. The decoded
    bytes go straight to ``json.loads``, so UTF-8 payloads are accepted as
    well as pure ASCII. Raises on malformed base64 or JSON.
    """
    data = kinesis.get("data")
    if not data:
        return None

    import base64

    data_obj = json.loads(base64.b64decode(data))
    return data_obj.get("_datadog")


# (unchanged, follows in the file: _deterministic_sha256_hash)
def test_sqs_invalid_datadog_message_attribute(self, mock_logger): # None indicates no DSM context propagation self.assertEqual(carrier_get("dd-pathway-ctx-base64"), None) + def test_sqs_batch_processing(self): + dd_data_1 = {"dd-pathway-ctx-base64": "record1"} + dd_data_2 = {"dd-pathway-ctx-base64": "record2"} + dd_json_data_1 = json.dumps(dd_data_1) + dd_json_data_2 = json.dumps(dd_data_2) + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "messageAttributes": { + "_datadog": { + "dataType": "String", + "stringValue": dd_json_data_1, + } + }, + "eventSource": "aws:sqs", + }, + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "messageAttributes": { + "_datadog": { + "dataType": "String", + "stringValue": dd_json_data_2, + } + }, + "eventSource": "aws:sqs", + }, + ] + } + + extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + self.assertEqual(self.mock_checkpoint.call_count, 2) + + args_1, _ = self.mock_checkpoint.call_args_list[0] + self.assertEqual(args_1[0], "sqs") + self.assertEqual(args_1[1], "arn:aws:sqs:us-east-1:123456789012:test-queue") + carrier_get_1 = args_1[2] + self.assertEqual(carrier_get_1("dd-pathway-ctx-base64"), "record1") + + args_2, _ = self.mock_checkpoint.call_args_list[1] + self.assertEqual(args_2[0], "sqs") + self.assertEqual(args_2[1], "arn:aws:sqs:us-east-1:123456789012:test-queue") + carrier_get_2 = args_2[2] + self.assertEqual(carrier_get_2("dd-pathway-ctx-base64"), "record2") + + def test_sqs_batch_processing_with_invalid_records(self): + dd_data_1 = {"dd-pathway-ctx-base64": "valid_record"} + dd_json_data_1 = json.dumps(dd_data_1) + + dd_data_3 = {"dd-pathway-ctx-base64": "another_valid_record"} + dd_json_data_3 = json.dumps(dd_data_3) + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "messageAttributes": { + "_datadog": { + "dataType": "String", 
+ "stringValue": dd_json_data_1, + } + }, + "eventSource": "aws:sqs", + }, + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "messageAttributes": { + "_datadog": { + "dataType": "Binary", + # This will cause extraction to fail + "binaryValue": "invalid-base64-data", + } + }, + "eventSource": "aws:sqs", + }, + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "messageAttributes": { + "_datadog": { + "dataType": "String", + "stringValue": dd_json_data_3, + } + }, + "eventSource": "aws:sqs", + }, + ] + } + + extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + self.assertEqual(self.mock_checkpoint.call_count, 3) + + args_1, _ = self.mock_checkpoint.call_args_list[0] + self.assertEqual(args_1[0], "sqs") + self.assertEqual(args_1[1], "arn:aws:sqs:us-east-1:123456789012:test-queue") + carrier_get_1 = args_1[2] + self.assertEqual(carrier_get_1("dd-pathway-ctx-base64"), "valid_record") + + args_2, _ = self.mock_checkpoint.call_args_list[1] + self.assertEqual(args_2[0], "sqs") + self.assertEqual(args_2[1], "arn:aws:sqs:us-east-1:123456789012:test-queue") + carrier_get_2 = args_2[2] + self.assertEqual(carrier_get_2("dd-pathway-ctx-base64"), None) + + args_3, _ = self.mock_checkpoint.call_args_list[2] + self.assertEqual(args_3[0], "sqs") + self.assertEqual(args_3[1], "arn:aws:sqs:us-east-1:123456789012:test-queue") + carrier_get_3 = args_3[2] + self.assertEqual(carrier_get_3("dd-pathway-ctx-base64"), "another_valid_record") + def test_sqs_source_arn_not_found(self): event = { "Records": [ @@ -3212,6 +3328,137 @@ def test_sns_to_sqs_invalid_datadog_message_attribute(self, mock_logger): # None indicates no DSM context propagation self.assertEqual(carrier_get("dd-pathway-ctx-base64"), None) + def test_sns_to_sqs_batch_processing(self): + dd_data_1 = {"dd-pathway-ctx-base64": "record1"} + dd_data_2 = {"dd-pathway-ctx-base64": "record2"} + dd_json_data_1 = 
json.dumps(dd_data_1) + dd_json_data_2 = json.dumps(dd_data_2) + + sns_message_1 = { + "Type": "Notification", + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + "MessageAttributes": { + "_datadog": {"Type": "String", "Value": dd_json_data_1} + }, + } + sns_message_2 = { + "Type": "Notification", + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + "MessageAttributes": { + "_datadog": {"Type": "String", "Value": dd_json_data_2} + }, + } + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "body": json.dumps(sns_message_1), + "eventSource": "aws:sqs", + }, + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "body": json.dumps(sns_message_2), + "eventSource": "aws:sqs", + }, + ] + } + + extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + self.assertEqual(self.mock_checkpoint.call_count, 2) + + args_1, _ = self.mock_checkpoint.call_args_list[0] + self.assertEqual(args_1[0], "sqs") + self.assertEqual(args_1[1], "arn:aws:sqs:us-east-1:123456789012:test-queue") + carrier_get_1 = args_1[2] + self.assertEqual(carrier_get_1("dd-pathway-ctx-base64"), "record1") + + args_2, _ = self.mock_checkpoint.call_args_list[1] + self.assertEqual(args_2[0], "sqs") + self.assertEqual(args_2[1], "arn:aws:sqs:us-east-1:123456789012:test-queue") + carrier_get_2 = args_2[2] + self.assertEqual(carrier_get_2("dd-pathway-ctx-base64"), "record2") + + def test_sns_to_sqs_batch_processing_with_invalid_records(self): + dd_data_1 = {"dd-pathway-ctx-base64": "valid_sns_record"} + dd_json_data_1 = json.dumps(dd_data_1) + + sns_message_1 = { + "Type": "Notification", + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + "MessageAttributes": { + "_datadog": {"Type": "String", "Value": dd_json_data_1} + }, + } + + sns_message_2 = { + "Type": "Notification", + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + 
"MessageAttributes": { + "_datadog": {"Type": "Binary", "Value": "invalid-base64-data"} + }, + } + + dd_data_3 = {"dd-pathway-ctx-base64": "another_valid_sns_record"} + dd_json_data_3 = json.dumps(dd_data_3) + + sns_message_3 = { + "Type": "Notification", + "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic", + "MessageAttributes": { + "_datadog": {"Type": "String", "Value": dd_json_data_3} + }, + } + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "body": json.dumps(sns_message_1), + "eventSource": "aws:sqs", + }, + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "body": json.dumps(sns_message_2), + "eventSource": "aws:sqs", + }, + { + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "body": json.dumps(sns_message_3), + "eventSource": "aws:sqs", + }, + ] + } + + extract_context_from_sqs_or_sns_event_or_context( + event, self.lambda_context, parse_event_source(event) + ) + + self.assertEqual(self.mock_checkpoint.call_count, 3) + + args_1, _ = self.mock_checkpoint.call_args_list[0] + self.assertEqual(args_1[0], "sqs") + self.assertEqual(args_1[1], "arn:aws:sqs:us-east-1:123456789012:test-queue") + carrier_get_1 = args_1[2] + self.assertEqual(carrier_get_1("dd-pathway-ctx-base64"), "valid_sns_record") + + args_2, _ = self.mock_checkpoint.call_args_list[1] + self.assertEqual(args_2[0], "sqs") + self.assertEqual(args_2[1], "arn:aws:sqs:us-east-1:123456789012:test-queue") + carrier_get_2 = args_2[2] + self.assertEqual(carrier_get_2("dd-pathway-ctx-base64"), None) + + args_3, _ = self.mock_checkpoint.call_args_list[2] + self.assertEqual(args_3[0], "sqs") + self.assertEqual(args_3[1], "arn:aws:sqs:us-east-1:123456789012:test-queue") + carrier_get_3 = args_3[2] + self.assertEqual( + carrier_get_3("dd-pathway-ctx-base64"), "another_valid_sns_record" + ) + def test_sns_to_sqs_source_arn_not_found(self): sns_notification = { "Type": "Notification", @@ -3379,6 +3626,109 
@@ def test_kinesis_invalid_datadog_message_attribute(self, mock_logger): # None indicates no DSM context propagation self.assertEqual(carrier_get("dd-pathway-ctx-base64"), None) + def test_kinesis_batch_processing(self): + dd_data_1 = {"dd-pathway-ctx-base64": "record1"} + dd_data_2 = {"dd-pathway-ctx-base64": "record2"} + + kinesis_data_1 = {"_datadog": dd_data_1, "message": "test1"} + kinesis_data_2 = {"_datadog": dd_data_2, "message": "test2"} + + encoded_data_1 = base64.b64encode(json.dumps(kinesis_data_1).encode()).decode() + encoded_data_2 = base64.b64encode(json.dumps(kinesis_data_2).encode()).decode() + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + "kinesis": {"data": encoded_data_1}, + }, + { + "eventSourceARN": "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + "kinesis": {"data": encoded_data_2}, + }, + ] + } + + extract_context_from_kinesis_event(event, self.lambda_context) + + self.assertEqual(self.mock_checkpoint.call_count, 2) + + # Verify first record call + args_1, _ = self.mock_checkpoint.call_args_list[0] + self.assertEqual(args_1[0], "kinesis") + self.assertEqual( + args_1[1], "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream" + ) + carrier_get_1 = args_1[2] + self.assertEqual(carrier_get_1("dd-pathway-ctx-base64"), "record1") + + # Verify second record call + args_2, _ = self.mock_checkpoint.call_args_list[1] + self.assertEqual(args_2[0], "kinesis") + self.assertEqual( + args_2[1], "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream" + ) + carrier_get_2 = args_2[2] + self.assertEqual(carrier_get_2("dd-pathway-ctx-base64"), "record2") + + def test_kinesis_batch_processing_with_invalid_records(self): + dd_data_1 = {"dd-pathway-ctx-base64": "valid_kinesis_record"} + kinesis_data_1 = {"_datadog": dd_data_1, "message": "test1"} + encoded_data_1 = base64.b64encode(json.dumps(kinesis_data_1).encode()).decode() + + dd_data_3 = {"dd-pathway-ctx-base64": 
"another_valid_kinesis_record"} + kinesis_data_3 = {"_datadog": dd_data_3, "message": "test3"} + encoded_data_3 = base64.b64encode(json.dumps(kinesis_data_3).encode()).decode() + + event = { + "Records": [ + { + "eventSourceARN": "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + "kinesis": {"data": encoded_data_1}, + }, + { + "eventSourceARN": "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + "kinesis": { + "data": "invalid-base64-data" + }, # This will cause extraction to fail + }, + { + "eventSourceARN": "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream", + "kinesis": {"data": encoded_data_3}, + }, + ] + } + + extract_context_from_kinesis_event(event, self.lambda_context) + + self.assertEqual(self.mock_checkpoint.call_count, 3) + + args_1, _ = self.mock_checkpoint.call_args_list[0] + self.assertEqual(args_1[0], "kinesis") + self.assertEqual( + args_1[1], "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream" + ) + carrier_get_1 = args_1[2] + self.assertEqual(carrier_get_1("dd-pathway-ctx-base64"), "valid_kinesis_record") + + args_2, _ = self.mock_checkpoint.call_args_list[1] + self.assertEqual(args_2[0], "kinesis") + self.assertEqual( + args_2[1], "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream" + ) + carrier_get_2 = args_2[2] + self.assertEqual(carrier_get_2("dd-pathway-ctx-base64"), None) + + args_3, _ = self.mock_checkpoint.call_args_list[2] + self.assertEqual(args_3[0], "kinesis") + self.assertEqual( + args_3[1], "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream" + ) + carrier_get_3 = args_3[2] + self.assertEqual( + carrier_get_3("dd-pathway-ctx-base64"), "another_valid_kinesis_record" + ) + def test_kinesis_source_arn_not_found(self): kinesis_data = {"message": "test"} kinesis_data_str = json.dumps(kinesis_data)