Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions responses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def _get_url_and_path(url: str) -> str:


def _handle_body(
body: Optional[Union[bytes, BufferedReader, str]]
body: Optional[Union[bytes, BufferedReader, str]],
) -> Union[BufferedReader, BytesIO]:
"""Generates `Response` body.

Expand Down Expand Up @@ -1003,7 +1003,7 @@ def activate(self, func: "_F" = ...) -> "_F":
"""Overload for scenario when 'responses.activate' is used."""

@overload
def activate( # type: ignore[misc]
def activate(
self,
*,
registry: Type[Any] = ...,
Expand Down Expand Up @@ -1096,9 +1096,11 @@ def _on_request(
if match is None:
if any(
[
p.match(request_url)
if isinstance(p, Pattern)
else request_url.startswith(p)
(
p.match(request_url)
if isinstance(p, Pattern)
else request_url.startswith(p)
)
for p in self.passthru_prefixes
]
):
Expand Down
38 changes: 29 additions & 9 deletions responses/_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,35 @@ def _remove_nones(d: "Any") -> "Any":
return d


def _remove_default_headers(data: "Any") -> "Any":
def _remove_default_headers(
data: "Any", additional_headers: "Optional[List[str]]" = None
) -> "Any":
"""
It would be too verbose to store these headers in the file generated by the
record functionality.
record functionality. If additional_headers is provided, those headers
will be preserved in addition to the normal behavior.
"""
if isinstance(data, dict):
keys_to_remove = [
default_keys_to_remove = [
"Content-Length",
"Content-Type",
"Date",
"Server",
"Connection",
"Content-Encoding",
]

for i, response in enumerate(data["responses"]):
for key in keys_to_remove:
if key in response["response"]["headers"]:
# Remove default headers as before, but preserve additional headers
keys_to_preserve = set(additional_headers) if additional_headers else set()

for key in default_keys_to_remove:
if (
key in response["response"]["headers"]
and key not in keys_to_preserve
):
del data["responses"][i]["response"]["headers"][key]

if not response["response"]["headers"]:
del data["responses"][i]["response"]["headers"]
return data
Expand All @@ -64,6 +75,7 @@ def _dump(
registered: "List[BaseResponse]",
destination: "Union[BinaryIO, TextIOWrapper]",
dumper: "Callable[[Union[Dict[Any, Any], List[Any]], Union[BinaryIO, TextIOWrapper]], Any]",
additional_headers: "Optional[List[str]]" = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be include_headers or capture_headers? On first read my interpretation was that these headers would be added into the recorded responses.

) -> None:
data: Dict[str, Any] = {"responses": []}
for rsp in registered:
Expand All @@ -88,7 +100,9 @@ def _dump(
"Probably you use custom Response object that is missing required attributes"
) from exc

dumper(_remove_default_headers(_remove_nones(data)), destination)
dumper(
_remove_default_headers(_remove_nones(data), additional_headers), destination
)


class Recorder(RequestsMock):
Expand All @@ -104,15 +118,20 @@ def reset(self) -> None:
self._registry = OrderedRegistry()

def record(
self, *, file_path: "Union[str, bytes, os.PathLike[Any]]" = "response.yaml"
self,
*,
file_path: "Union[str, bytes, os.PathLike[Any]]" = "response.yaml",
additional_headers: "Optional[List[str]]" = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here with include_headers, capture_headers or retain_headers 🤷

) -> "Union[Callable[[_F], _F], _F]":
def deco_record(function: "_F") -> "Callable[..., Any]":
@wraps(function)
def wrapper(*args: "Any", **kwargs: "Any") -> "Any": # type: ignore[misc]
with self:
ret = function(*args, **kwargs)
self.dump_to_file(
file_path=file_path, registered=self.get_registry().registered
file_path=file_path,
registered=self.get_registry().registered,
additional_headers=additional_headers,
)

return ret
Expand All @@ -126,12 +145,13 @@ def dump_to_file(
file_path: "Union[str, bytes, os.PathLike[Any]]",
*,
registered: "Optional[List[BaseResponse]]" = None,
additional_headers: "Optional[List[str]]" = None,
) -> None:
"""Dump the recorded responses to a file."""
if registered is None:
registered = self.get_registry().registered
with open(file_path, "w") as file:
_dump(registered, file, yaml.dump)
_dump(registered, file, yaml.dump, additional_headers)

def _on_request(
self,
Expand Down
2 changes: 1 addition & 1 deletion responses/tests/test_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ def run():

class TestHeaderWithRegex:
@property
def url(self): # type: ignore[misc]
def url(self):
return "http://example.com/"

def _register(self):
Expand Down
158 changes: 151 additions & 7 deletions responses/tests/test_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_data(host, port):
{
"response": {
"method": "GET",
"url": f"http://{host}:{port}/404",
"url": f"http://{host}:{port}/404", # noqa: E231
"headers": {"x": "foo"},
"body": "404 Not Found",
"status": 404,
Expand All @@ -33,7 +33,7 @@ def get_data(host, port):
{
"response": {
"method": "GET",
"url": f"http://{host}:{port}/status/wrong",
"url": f"http://{host}:{port}/status/wrong", # noqa: E231
"headers": {"x": "foo"},
"body": "Invalid status code",
"status": 400,
Expand All @@ -44,7 +44,7 @@ def get_data(host, port):
{
"response": {
"method": "GET",
"url": f"http://{host}:{port}/500",
"url": f"http://{host}:{port}/500", # noqa: E231
"headers": {"x": "foo"},
"body": "500 Internal Server Error",
"status": 500,
Expand All @@ -55,7 +55,7 @@ def get_data(host, port):
{
"response": {
"method": "PUT",
"url": f"http://{host}:{port}/202",
"url": f"http://{host}:{port}/202", # noqa: E231
"body": "OK",
"status": 202,
"content_type": "text/plain",
Expand Down Expand Up @@ -97,11 +97,13 @@ def run():
def test_recorder_toml(self, httpserver):
custom_recorder = _recorder.Recorder()

def dump_to_file(file_path, registered):
def dump_to_file(file_path, registered=None, additional_headers=None):
if registered is None:
registered = custom_recorder.get_registry().registered
with open(file_path, "wb") as file:
_dump(registered, file, tomli_w.dump) # type: ignore[arg-type]
_dump(registered, file, tomli_w.dump, additional_headers) # type: ignore[arg-type]

custom_recorder.dump_to_file = dump_to_file # type: ignore[assignment]
custom_recorder.dump_to_file = dump_to_file # type: ignore[method-assign]

url202, url400, url404, url500 = self.prepare_server(httpserver)

Expand Down Expand Up @@ -238,3 +240,145 @@ def _parse_resp_f(file_path):
assert responses.registered()[3].content_type == "text/plain"

run()


class TestRecorderAdditionalHeaders:
def setup_method(self):
self.out_file = Path("response_record_headers")
if self.out_file.exists():
self.out_file.unlink()
assert not self.out_file.exists()

def teardown_method(self):
if self.out_file.exists():
self.out_file.unlink()

def prepare_server_with_headers(self, httpserver):
httpserver.expect_request("/test").respond_with_data(
"Test Response",
status=200,
content_type="text/plain",
headers={
"Content-Length": "13",
"Server": "nginx/1.0",
"Connection": "keep-alive",
"Content-Encoding": "identity",
"Authorization": "Bearer token123",
"X-Custom-Header": "custom-value",
"User-Agent": "test-agent",
},
)
return httpserver.url_for("/test")

def test_recorder_with_additional_headers(self, httpserver):
url = self.prepare_server_with_headers(httpserver)

@_recorder.record(
file_path=self.out_file,
additional_headers=["Authorization", "X-Custom-Header", "Date"],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
additional_headers=["Authorization", "X-Custom-Header", "Date"],
include_headers=["Authorization", "X-Custom-Header", "Date"],

I think this reads better, what do you think?

)
def run():
requests.get(url)

run()

with open(self.out_file) as file:
data = yaml.safe_load(file)

response_headers = data["responses"][0]["response"]["headers"]

# Additional headers should be preserved
assert "Authorization" in response_headers
assert response_headers["Authorization"] == "Bearer token123"
assert "X-Custom-Header" in response_headers
assert response_headers["X-Custom-Header"] == "custom-value"

# Default headers should still be removed (not in additional_headers)
assert "Content-Length" not in response_headers
assert "Server" not in response_headers
assert "Connection" not in response_headers
assert "Content-Encoding" not in response_headers

# Other headers not in default removal list should remain
assert "User-Agent" in response_headers

def test_recorder_with_additional_headers_preserves_default_removal(
self, httpserver
):
url = self.prepare_server_with_headers(httpserver)

@_recorder.record(
file_path=self.out_file, additional_headers=["Content-Type", "Server"]
)
def run():
requests.get(url)

run()

with open(self.out_file) as file:
data = yaml.safe_load(file)

response_headers = data["responses"][0]["response"]["headers"]

# Headers in additional_headers should be preserved even if normally removed
assert "Content-Type" in response_headers
assert response_headers["Content-Type"] == "text/plain"
assert "Server" in response_headers
assert "nginx/1.0" in response_headers["Server"]

# Other default headers should still be removed
assert "Content-Length" not in response_headers
assert "Connection" not in response_headers
assert "Content-Encoding" not in response_headers

def test_recorder_without_additional_headers_default_behavior(self, httpserver):
url = self.prepare_server_with_headers(httpserver)

@_recorder.record(file_path=self.out_file)
def run():
requests.get(url)

run()

with open(self.out_file) as file:
data = yaml.safe_load(file)

response_headers = data["responses"][0]["response"]["headers"]

# Default headers should be removed
assert "Content-Length" not in response_headers
assert "Content-Type" not in response_headers
assert "Server" not in response_headers
assert "Connection" not in response_headers
assert "Content-Encoding" not in response_headers

# Non-default headers should remain
assert "Authorization" in response_headers
assert "X-Custom-Header" in response_headers
assert "User-Agent" in response_headers

def test_dump_to_file_with_additional_headers(self, httpserver):
url = self.prepare_server_with_headers(httpserver)

_recorder.recorder.start()
requests.get(url)
_recorder.recorder.stop()

_recorder.recorder.dump_to_file(
self.out_file, additional_headers=["Content-Length"]
)

with open(self.out_file) as file:
data = yaml.safe_load(file)

response_headers = data["responses"][0]["response"]["headers"]

# Additional headers should be preserved
assert "Content-Length" in response_headers

# Other default headers should be removed
assert "Server" not in response_headers
assert "Connection" not in response_headers
assert "Content-Encoding" not in response_headers

_recorder.recorder.reset()
8 changes: 4 additions & 4 deletions responses/tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,12 +926,12 @@ def test_function(a, b=None):


@pytest.fixture
def my_fruit(): # type: ignore[misc]
def my_fruit():
return "apple"


@pytest.fixture
def fruit_basket(my_fruit): # type: ignore[misc]
def fruit_basket(my_fruit):
return ["banana", my_fruit]


Expand Down Expand Up @@ -1333,7 +1333,7 @@ def test_handles_buffered_reader_body():

@responses.activate
def run():
responses.add(responses.GET, url, body=BufferedReader(BytesIO(b"test"))) # type: ignore
responses.add(responses.GET, url, body=BufferedReader(BytesIO(b"test")))

resp = requests.get(url)

Expand Down Expand Up @@ -1558,7 +1558,7 @@ def run():
responses.add(
responses.GET,
url,
body=BufferedReader(BytesIO(b"testing")), # type: ignore
body=BufferedReader(BytesIO(b"testing")),
auto_calculate_content_length=True,
)
resp = requests.get(url)
Expand Down
Loading