diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 0aa38613..d38bb7d3 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -96,6 +96,167 @@ def _wrap(*args, **kwargs): return _wrapper +_test_extract_dd_trace_context = ( + ("api-gateway", Context(trace_id=12345, span_id=67890, sampling_priority=2)), + ( + "api-gateway-no-apiid", + Context(trace_id=12345, span_id=67890, sampling_priority=2), + ), + ( + "api-gateway-non-proxy", + Context(trace_id=12345, span_id=67890, sampling_priority=2), + ), + ( + "api-gateway-non-proxy-async", + Context(trace_id=12345, span_id=67890, sampling_priority=2), + ), + ( + "api-gateway-websocket-connect", + Context(trace_id=12345, span_id=67890, sampling_priority=2), + ), + ( + "api-gateway-websocket-default", + Context(trace_id=12345, span_id=67890, sampling_priority=2), + ), + ( + "api-gateway-websocket-disconnect", + Context(trace_id=12345, span_id=67890, sampling_priority=2), + ), + ( + "authorizer-request-api-gateway-v1", + Context( + trace_id=13478705995797221209, + span_id=8471288263384216896, + sampling_priority=1, + ), + ), + ("authorizer-request-api-gateway-v1-cached", None), + ( + "authorizer-request-api-gateway-v2", + Context( + trace_id=14356983619852933354, + span_id=12658621083505413809, + sampling_priority=1, + ), + ), + ("authorizer-request-api-gateway-v2-cached", None), + ( + "authorizer-request-api-gateway-websocket-connect", + Context( + trace_id=5351047404834723189, + span_id=18230460631156161837, + sampling_priority=1, + ), + ), + ("authorizer-request-api-gateway-websocket-message", None), + ( + "authorizer-token-api-gateway-v1", + Context( + trace_id=17874798268144902712, + span_id=16184667399315372101, + sampling_priority=1, + ), + ), + ("authorizer-token-api-gateway-v1-cached", None), + ("cloudfront", None), + ("cloudwatch-events", None), + ("cloudwatch-logs", None), + ("custom", None), + ("dynamodb", None), + ("eventbridge-custom", Context(trace_id=12345, span_id=67890, sampling_priority=2)), + ( + "eventbridge-sqs", + Context( + trace_id=7379586022458917877, + span_id=2644033662113726488, + sampling_priority=1, + ), + ), + ("http-api", Context(trace_id=12345, span_id=67890, sampling_priority=2)), + ( + "kinesis", + Context( + trace_id=4948377316357291421, + span_id=2876253380018681026, + sampling_priority=1, + ), + ), + ( + "kinesis-batch", + Context( + trace_id=4948377316357291421, + span_id=2876253380018681026, + sampling_priority=1, + ), + ), + ("lambda-url", None), + ("s3", None), + ( + "sns-b64-msg-attribute", + Context( + trace_id=4948377316357291421, + span_id=6746998015037429512, + sampling_priority=1, + ), + ), + ( + "sns-batch", + Context( + trace_id=4948377316357291421, + span_id=6746998015037429512, + sampling_priority=1, + ), + ), + ( + "sns-string-msg-attribute", + Context( + trace_id=4948377316357291421, + span_id=6746998015037429512, + sampling_priority=1, + ), + ), + ( + "sqs-batch", + Context( + trace_id=2684756524522091840, + span_id=7431398482019833808, + sampling_priority=1, + ), + ), + ( + "sqs-java-upstream", + Context( + trace_id=7925498337868555493, + span_id=5245570649555658903, + sampling_priority=1, + ), + ), + ( + "sqs-string-msg-attribute", + Context( + trace_id=2684756524522091840, + span_id=7431398482019833808, + sampling_priority=1, + ), + ), + ({"headers": None}, None), +) + + +@pytest.mark.parametrize("event,expect", _test_extract_dd_trace_context) +def test_extract_dd_trace_context(event, expect): + if isinstance(event, str): + with open(f"{event_samples}{event}.json") as f: + event = json.load(f) + ctx = get_mock_context() + + actual, _, _ = extract_dd_trace_context(event, ctx) + assert (expect is None) is (actual is None) + assert (expect is None) or actual.trace_id == expect.trace_id + assert (expect is None) or actual.span_id == expect.span_id + assert (expect is None) or actual.sampling_priority == expect.sampling_priority + + class TestExtractAndGetDDTraceContext(unittest.TestCase): def setUp(self): global dd_tracing_enabled @@ -1773,127 +1934,6 @@ def test_create_inferred_span(mock_span_finish, source, expect): class TestInferredSpans(unittest.TestCase): - def test_extract_context_from_eventbridge_event(self): - event_sample_source = "eventbridge-custom" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - ctx = get_mock_context() - context, source, event_type = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 12345) - self.assertEqual(context.span_id, 67890), - self.assertEqual(context.sampling_priority, 2) - - def test_extract_dd_trace_context_for_eventbridge(self): - event_sample_source = "eventbridge-custom" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - ctx = get_mock_context() - context, source, event_type = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 12345) - self.assertEqual(context.span_id, 67890) - - def test_extract_context_from_eventbridge_sqs_event(self): - event_sample_source = "eventbridge-sqs" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - - ctx = get_mock_context() - context, source, event_type = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 7379586022458917877) - self.assertEqual(context.span_id, 2644033662113726488) - self.assertEqual(context.sampling_priority, 1) - - def test_extract_context_from_sqs_event_with_string_msg_attr(self): - event_sample_source = "sqs-string-msg-attribute" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - ctx = get_mock_context() - context, source, event_type = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 2684756524522091840) - self.assertEqual(context.span_id, 7431398482019833808) - self.assertEqual(context.sampling_priority, 1) - - def test_extract_context_from_sqs_batch_event(self): - event_sample_source = "sqs-batch" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - ctx = get_mock_context() - context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 2684756524522091840) - self.assertEqual(context.span_id, 7431398482019833808) - self.assertEqual(context.sampling_priority, 1) - - def test_extract_context_from_sqs_java_upstream_event(self): - event_sample_source = "sqs-java-upstream" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - ctx = get_mock_context() - context, source, event_type = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 7925498337868555493) - self.assertEqual(context.span_id, 5245570649555658903) - self.assertEqual(context.sampling_priority, 1) - - def test_extract_context_from_sns_event_with_string_msg_attr(self): - event_sample_source = "sns-string-msg-attribute" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - ctx = get_mock_context() - context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 4948377316357291421) - self.assertEqual(context.span_id, 6746998015037429512) - self.assertEqual(context.sampling_priority, 1) - - def test_extract_context_from_sns_event_with_b64_msg_attr(self): - event_sample_source = "sns-b64-msg-attribute" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - ctx = get_mock_context() - context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 4948377316357291421) - self.assertEqual(context.span_id, 6746998015037429512) - self.assertEqual(context.sampling_priority, 1) - - def test_extract_context_from_sns_batch_event(self): - event_sample_source = "sns-batch" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - ctx = get_mock_context() - context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 4948377316357291421) - self.assertEqual(context.span_id, 6746998015037429512) - self.assertEqual(context.sampling_priority, 1) - - def test_extract_context_from_kinesis_event(self): - event_sample_source = "kinesis" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - ctx = get_mock_context() - context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 4948377316357291421) - self.assertEqual(context.span_id, 2876253380018681026) - self.assertEqual(context.sampling_priority, 1) - - def test_extract_context_from_kinesis_batch_event(self): - event_sample_source = "kinesis-batch" - test_file = event_samples + event_sample_source + ".json" - with open(test_file, "r") as event: - event = json.load(event) - ctx = get_mock_context() - context, source, event_source = extract_dd_trace_context(event, ctx) - self.assertEqual(context.trace_id, 4948377316357291421) - self.assertEqual(context.span_id, 2876253380018681026) - self.assertEqual(context.sampling_priority, 1) - @patch("datadog_lambda.tracing.submit_errors_metric") def test_mark_trace_as_error_for_5xx_responses_getting_400_response_code( self, mock_submit_errors_metric @@ -1915,14 +1955,6 @@ def test_mark_trace_as_error_for_5xx_responses_sends_error_metric_and_set_error_ mock_submit_errors_metric.assert_called_once() self.assertEqual(1, mock_span.error) - def test_no_error_with_nonetype_headers(self): - lambda_ctx = get_mock_context() - ctx, source, event_type = extract_dd_trace_context( - {"headers": None}, - lambda_ctx, - ) - self.assertEqual(ctx, None) - class TestStepFunctionsTraceContext(unittest.TestCase): def test_deterministic_m5_hash(self):