Skip to content

[CacheByContentTypePlugin] Prepare for content type parsing #1038

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 0 additions & 70 deletions helper/benchmark.sh

This file was deleted.

3 changes: 3 additions & 0 deletions proxy/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def _env_threadless_compliant() -> bool:
DEFAULT_DEVTOOLS_LOADER_ID = secrets.token_hex(8)

DEFAULT_DATA_DIRECTORY_PATH = os.path.join(str(pathlib.Path.home()), '.proxy')
DEFAULT_CACHE_DIRECTORY_PATH = os.path.join(
DEFAULT_DATA_DIRECTORY_PATH, 'cache',
)

# Cor plugins enabled by default or via flags
DEFAULT_ABC_PLUGINS = [
Expand Down
5 changes: 5 additions & 0 deletions proxy/common/flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,11 @@ def initialize(
)
os.makedirs(args.ca_cert_dir, exist_ok=True)

# FIXME: Necessary here until flags framework provides a way
# for flag owners to initialize
os.makedirs(args.cache_dir, exist_ok=True)
os.makedirs(os.path.join(args.cache_dir, 'response'), exist_ok=True)

return args

@staticmethod
Expand Down
8 changes: 8 additions & 0 deletions proxy/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys
import socket
import logging
import argparse
import functools
import ipaddress
import contextlib
Expand All @@ -34,6 +35,13 @@
logger = logging.getLogger(__name__)


def tls_interception_enabled(flags: argparse.Namespace) -> bool:
return flags.ca_key_file is not None and \
flags.ca_cert_dir is not None and \
flags.ca_signing_key_file is not None and \
flags.ca_cert_file is not None


def is_threadless(threadless: bool, threaded: bool) -> bool:
# if default is threadless then return true unless
# user has overridden mode using threaded flag.
Expand Down
6 changes: 0 additions & 6 deletions proxy/http/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from typing import Any

from ..common.types import Readables, Writables, Descriptors


Expand All @@ -19,10 +17,6 @@ class DescriptorsHandlerMixin:
include web and proxy plugins. By using DescriptorsHandlerMixin, class
becomes complaint with core event loop."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
# FIXME: Required for multi-level inheritance to work
super().__init__(*args, **kwargs) # type: ignore

# @abstractmethod
async def get_descriptors(self) -> Descriptors:
"""Implementations must return a list of descriptions that they wish to
Expand Down
30 changes: 0 additions & 30 deletions proxy/http/mixins.py

This file was deleted.

8 changes: 5 additions & 3 deletions proxy/http/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Union, Optional

from .mixins import TlsInterceptionPropertyMixin
from .parser import HttpParser
from .connection import HttpClientConnection
from ..core.event import EventQueue
from .descriptors import DescriptorsHandlerMixin
from ..common.utils import tls_interception_enabled


if TYPE_CHECKING: # pragma: no cover
Expand All @@ -26,7 +26,6 @@

class HttpProtocolHandlerPlugin(
DescriptorsHandlerMixin,
TlsInterceptionPropertyMixin,
ABC,
):
"""Base HttpProtocolHandler Plugin class.
Expand Down Expand Up @@ -59,7 +58,6 @@ def __init__(
event_queue: Optional[EventQueue] = None,
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
):
super().__init__(uid, flags, client, event_queue, upstream_conn_pool)
self.uid: str = uid
self.flags: argparse.Namespace = flags
self.client: HttpClientConnection = client
Expand Down Expand Up @@ -95,3 +93,7 @@ def on_client_connection_close(self) -> None:
perform any cleanup work here.
"""
pass # pragma: no cover

@property
def tls_interception_enabled(self) -> bool:
return tls_interception_enabled(self.flags)
6 changes: 2 additions & 4 deletions proxy/http/proxy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from abc import ABC
from typing import TYPE_CHECKING, Any, Dict, Tuple, Optional

from ..mixins import TlsInterceptionPropertyMixin
from ..parser import HttpParser
from ..connection import HttpClientConnection
from ...core.event import EventQueue
from ..descriptors import DescriptorsHandlerMixin
from ...common.utils import tls_interception_enabled


if TYPE_CHECKING: # pragma: no cover
Expand All @@ -25,7 +25,6 @@

class HttpProxyBasePlugin(
DescriptorsHandlerMixin,
TlsInterceptionPropertyMixin,
ABC,
):
"""Base HttpProxyPlugin Plugin class.
Expand All @@ -40,7 +39,6 @@ def __init__(
event_queue: EventQueue,
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
) -> None:
super().__init__(uid, flags, client, event_queue, upstream_conn_pool)
self.uid = uid # pragma: no cover
self.flags = flags # pragma: no cover
self.client = client # pragma: no cover
Expand Down Expand Up @@ -170,4 +168,4 @@ def do_intercept(self, _request: HttpParser) -> bool:
flags BUT only conditionally enable interception for
certain requests.
"""
return self.tls_interception_enabled
return tls_interception_enabled(self.flags)
2 changes: 2 additions & 0 deletions proxy/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
from .custom_dns_resolver import CustomDnsResolverPlugin
from .filter_by_client_ip import FilterByClientIpPlugin
from .filter_by_url_regex import FilterByURLRegexPlugin
from .cache_by_content_type import CacheByContentTypePlugin
from .modify_chunk_response import ModifyChunkResponsePlugin
from .redirect_to_custom_server import RedirectToCustomServerPlugin


__all__ = [
'CacheResponsesPlugin',
'CacheByContentTypePlugin',
'BaseCacheResponsesPlugin',
'FilterByUpstreamHostPlugin',
'ManInTheMiddlePlugin',
Expand Down
15 changes: 13 additions & 2 deletions proxy/plugin/cache/cache_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import os
import multiprocessing
from typing import Any
from typing import Any, Dict, Optional

from .base import BaseCacheResponsesPlugin
from .store.disk import OnDiskCacheStore
Expand All @@ -24,6 +25,16 @@ class CacheResponsesPlugin(BaseCacheResponsesPlugin):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.disk_store = OnDiskCacheStore(
uid=self.uid, cache_dir=self.flags.cache_dir,
uid=self.uid,
cache_dir=os.path.join(
self.flags.cache_dir,
'responses',
),
)
self.set_store(self.disk_store)

def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
context.update({
'cache_file_path': self.disk_store.cache_file_path,
})
return super().on_access_log(context)
6 changes: 3 additions & 3 deletions proxy/plugin/cache/store/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
"""
import os
import logging
import tempfile
from typing import BinaryIO, Optional

from .base import CacheStore
from ....common.flag import flags
from ....http.parser import HttpParser
from ....common.utils import text_
from ....common.constants import DEFAULT_CACHE_DIRECTORY_PATH


logger = logging.getLogger(__name__)
Expand All @@ -25,8 +25,8 @@
flags.add_argument(
'--cache-dir',
type=str,
default=tempfile.gettempdir(),
help='Default: A temporary directory. ' +
default=DEFAULT_CACHE_DIRECTORY_PATH,
help='Default: ' + DEFAULT_CACHE_DIRECTORY_PATH + '. ' +
'Flag only applicable when cache plugin is used with on-disk storage.',
)

Expand Down
32 changes: 32 additions & 0 deletions proxy/plugin/cache_by_content_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.

:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import logging
from typing import Any, Dict, Optional

from ..http.proxy import HttpProxyBasePlugin
from ..common.utils import tls_interception_enabled


logger = logging.getLogger(__name__)


class CacheByContentTypePlugin(HttpProxyBasePlugin):
"""This plugin is supposed to work with
:py:class:`~proxy.plugin.cache.cache_responses.CacheResponsesPlugin`. This plugin
must be put after the cache response plugin in the chain.

Plugin will try to extract out content type from the responses.
When found, data is stored under ``proxy.py`` instance data directory."""

def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if tls_interception_enabled(self.flags) and 'cache_file_path' in context:
print('cache file found')
return super().on_access_log(context)
6 changes: 4 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
DEFAULT_ENABLE_SSH_TUNNEL, DEFAULT_ENABLE_WEB_SERVER,
DEFAULT_DISABLE_HTTP_PROXY, PLUGIN_WEBSOCKET_TRANSPORT,
DEFAULT_CA_SIGNING_KEY_FILE, DEFAULT_CLIENT_RECVBUF_SIZE,
DEFAULT_SERVER_RECVBUF_SIZE, DEFAULT_ENABLE_REVERSE_PROXY,
DEFAULT_ENABLE_STATIC_SERVER, _env_threadless_compliant,
DEFAULT_SERVER_RECVBUF_SIZE, DEFAULT_CACHE_DIRECTORY_PATH,
DEFAULT_ENABLE_REVERSE_PROXY, DEFAULT_ENABLE_STATIC_SERVER,
_env_threadless_compliant,
)


Expand Down Expand Up @@ -79,6 +80,7 @@ def mock_default_args(mock_args: mock.Mock) -> None:
mock_args.enable_ssh_tunnel = DEFAULT_ENABLE_SSH_TUNNEL
mock_args.enable_reverse_proxy = DEFAULT_ENABLE_REVERSE_PROXY
mock_args.unix_socket_path = None
mock_args.cache_dir = DEFAULT_CACHE_DIRECTORY_PATH

@mock.patch('os.remove')
@mock.patch('os.path.exists')
Expand Down