Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions tests/agent_unittests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@
)


FAILING_ON_WINDOWS = pytest.mark.xfail(sys.platform == "win32", reason="TODO: Fix this test on Windows")


class FakeProtos:
Span = object()
SpanBatch = object()
Expand Down
85 changes: 41 additions & 44 deletions tests/agent_unittests/test_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from io import StringIO

import pytest
from conftest import FAILING_ON_WINDOWS
from testing_support.certs import CERT_PATH
from testing_support.mock_external_http_server import MockExternalHTTPServer

Expand All @@ -43,18 +42,29 @@

def echo_full_request(self):
self.server.connections.append(self.connection)
request_line = str(self.requestline).encode("utf-8")
headers = "\n".join(f"{k.lower()}: {v}" for k, v in self.headers.items())
self.send_response(200)
self.end_headers()
self.wfile.write(request_line)
self.wfile.write(b"\n")
self.wfile.write(headers.strip().encode("utf-8"))
self.wfile.write(b"\n")
request_line = str(self.requestline)
headers = list(self.headers.items())

content_length = int(self.headers.get("Content-Length", 0))
if content_length:
data = self.rfile.read(content_length)
self.wfile.write(data)
body = self.rfile.read(content_length).hex()
else:
body = ""

payload = [request_line, headers, body]
payload = json.dumps(payload).encode("utf-8")

self.send_response(200)
self.end_headers()
self.wfile.write(payload)


def decode_payload(data):
if isinstance(data, bytes):
data = data.decode("utf-8")
payload = json.loads(data)
payload[2] = bytes.fromhex(payload[2]) # Convert body back to bytes
return payload


class InsecureServer(MockExternalHTTPServer):
Expand Down Expand Up @@ -158,28 +168,23 @@ def test_http_no_payload(server, method):
status, data = client.send_request(method=method, headers={"foo": "bar"})

assert status == 200
data = ensure_str(data)
data = data.split("\n")
request_line, headers, _payload = decode_payload(data)

# Verify connection has been closed
assert client._connection_attr is None
assert connection.pool is None

# Verify request line
assert data[0].startswith(f"{method} /agent_listener/invoke_raw_method ")
assert request_line.startswith(f"{method} /agent_listener/invoke_raw_method ")

# Verify headers
user_agent_header = ""
foo_header = ""

for header in data[1:-1]:
if header.lower().startswith("user-agent:"):
_, value = header.split(":", 1)
value = value.strip()
for key, value in headers:
if key.lower() == "user-agent":
user_agent_header = value
elif header.startswith("foo:"):
_, value = header.split(":", 1)
value = value.strip()
if key.lower() == "foo":
foo_header = value

assert user_agent_header.startswith("NewRelic-PythonAgent/")
Expand Down Expand Up @@ -228,7 +233,6 @@ def test_http_close_connection_in_context_manager():
client.close_connection()


@FAILING_ON_WINDOWS
@pytest.mark.parametrize(
"client_cls,method,threshold",
(
Expand Down Expand Up @@ -267,8 +271,7 @@ def test_http_payload_compression(server, client_cls, method, threshold):
status, data = client.send_request(payload=payload, params={"method": "method2"})

assert status == 200
data = data.split(b"\n")
sent_payload = data[-1]
_request_line, headers, sent_payload = decode_payload(data)
payload_byte_len = len(sent_payload)
internal_metrics = dict(internal_metrics.metrics())
if client_cls is ApplicationModeClient:
Expand All @@ -292,7 +295,7 @@ def test_http_payload_compression(server, client_cls, method, threshold):
assert not internal_metrics

if threshold < 20:
expected_content_encoding = method.encode("utf-8")
expected_content_encoding = method
assert sent_payload != payload
if method == "deflate":
sent_payload = zlib.decompress(sent_payload)
Expand All @@ -301,12 +304,11 @@ def test_http_payload_compression(server, client_cls, method, threshold):
sent_payload = decompressor.decompress(sent_payload)
sent_payload += decompressor.flush()
else:
expected_content_encoding = b"Identity"
expected_content_encoding = "Identity"

for header in data[1:-1]:
if header.lower().startswith(b"content-encoding"):
_, content_encoding = header.split(b":", 1)
content_encoding = content_encoding.strip()
for key, value in headers:
if key.lower() == "content-encoding":
content_encoding = value
break
else:
raise AssertionError("Missing content-encoding header")
Expand Down Expand Up @@ -375,16 +377,13 @@ def test_ssl_via_ssl_proxy(server, auth):
status, data = client.send_request()

assert status == 200
data = data.decode("utf-8")
data = data.split("\n")
assert data[0].startswith("POST https://localhost:1/agent_listener/invoke_raw_method ")
request_line, headers, _payload = decode_payload(data)
assert request_line.startswith("POST https://localhost:1/agent_listener/invoke_raw_method ")

proxy_auth = None
for header in data[1:-1]:
if header.lower().startswith("proxy-authorization"):
_, proxy_auth = header.split(":", 1)
proxy_auth = proxy_auth.strip()
break
for key, value in headers:
if key.lower() == "proxy-authorization":
proxy_auth = value

if proxy_user:
auth_expected = proxy_user
Expand All @@ -410,9 +409,8 @@ def test_non_ssl_via_ssl_proxy(server):
status, data = client.send_request()

assert status == 200
data = data.decode("utf-8")
data = data.split("\n")
assert data[0].startswith("POST http://localhost:1/agent_listener/invoke_raw_method ")
request_line, _headers, _payload = decode_payload(data)
assert request_line.startswith("POST http://localhost:1/agent_listener/invoke_raw_method ")

assert server.httpd.connect_host is None

Expand All @@ -424,9 +422,8 @@ def test_non_ssl_via_non_ssl_proxy(insecure_server):
status, data = client.send_request()

assert status == 200
data = data.decode("utf-8")
data = data.split("\n")
assert data[0].startswith("POST http://localhost:1/agent_listener/invoke_raw_method ")
request_line, _headers, _payload = decode_payload(data)
assert request_line.startswith("POST http://localhost:1/agent_listener/invoke_raw_method ")

assert insecure_server.httpd.connect_host is None

Expand Down