Skip to content

Commit 9b2e7bb

Browse files
authored
Expose UpstreamConnectionPool to web & proxy plugins (#946)
* Expose conn pool to plugins * Fix reusable state handling * Separate `release` and `retain` methods * Fix conn pool tests * Fix tests
1 parent a2e1fc6 commit 9b2e7bb

File tree

7 files changed

+69
-42
lines changed

7 files changed

+69
-42
lines changed

proxy/core/connection/pool.py

+44-33
Original file line numberDiff line numberDiff line change
@@ -77,57 +77,50 @@ def __init__(self) -> None:
7777
self.connections: Dict[int, TcpServerConnection] = {}
7878
self.pools: Dict[Tuple[str, int], Set[TcpServerConnection]] = {}
7979

80-
def add(self, addr: Tuple[str, int]) -> TcpServerConnection:
81-
"""Creates and add a new connection to the pool."""
82-
new_conn = TcpServerConnection(addr[0], addr[1])
83-
new_conn.connect()
84-
self._add(new_conn)
85-
return new_conn
86-
8780
def acquire(self, addr: Tuple[str, int]) -> Tuple[bool, TcpServerConnection]:
8881
"""Returns a reusable connection from the pool.
8982
9083
If none exists, will create and return a new connection."""
84+
created, conn = False, None
9185
if addr in self.pools:
9286
for old_conn in self.pools[addr]:
9387
if old_conn.is_reusable():
94-
old_conn.mark_inuse()
88+
conn = old_conn
9589
logger.debug(
9690
'Reusing connection#{2} for upstream {0}:{1}'.format(
9791
addr[0], addr[1], id(old_conn),
9892
),
9993
)
100-
return False, old_conn
101-
new_conn = self.add(addr)
102-
logger.debug(
103-
'Created new connection#{2} for upstream {0}:{1}'.format(
104-
addr[0], addr[1], id(new_conn),
105-
),
106-
)
107-
return True, new_conn
94+
break
95+
if conn is None:
96+
created, conn = True, self.add(addr)
97+
conn.mark_inuse()
98+
return created, conn
10899

109100
def release(self, conn: TcpServerConnection) -> None:
110101
"""Release a previously acquired connection.
111102
112-
If the connection has not been closed,
113-
then it will be retained in the pool for reusability.
103+
Releasing a connection will shutdown and close the socket
104+
including internal pool cleanup.
114105
"""
115106
assert not conn.is_reusable()
116-
if conn.closed:
117-
logger.debug(
118-
'Removing connection#{2} from pool from upstream {0}:{1}'.format(
119-
conn.addr[0], conn.addr[1], id(conn),
120-
),
121-
)
122-
self._remove(conn.connection.fileno())
123-
else:
124-
logger.debug(
125-
'Retaining connection#{2} to upstream {0}:{1}'.format(
126-
conn.addr[0], conn.addr[1], id(conn),
127-
),
128-
)
129-
# Reset for reusability
130-
conn.reset()
107+
logger.debug(
108+
'Removing connection#{2} from pool from upstream {0}:{1}'.format(
109+
conn.addr[0], conn.addr[1], id(conn),
110+
),
111+
)
112+
self._remove(conn.connection.fileno())
113+
114+
def retain(self, conn: TcpServerConnection) -> None:
115+
"""Retained previously acquired connection in the pool for reusability."""
116+
assert not conn.closed
117+
logger.debug(
118+
'Retaining connection#{2} to upstream {0}:{1}'.format(
119+
conn.addr[0], conn.addr[1], id(conn),
120+
),
121+
)
122+
# Reset for reusability
123+
conn.reset()
131124

132125
async def get_events(self) -> SelectableEvents:
133126
"""Returns read event flag for all reusable connections in the pool."""
@@ -152,10 +145,28 @@ async def handle_events(self, readables: Readables, _writables: Writables) -> bo
152145
self._remove(fileno)
153146
return False
154147

148+
def add(self, addr: Tuple[str, int]) -> TcpServerConnection:
149+
"""Creates, connects and adds a new connection to the pool.
150+
151+
Returns newly created connection.
152+
153+
NOTE: You must not use the returned connection, instead use `acquire`.
154+
"""
155+
new_conn = TcpServerConnection(addr[0], addr[1])
156+
new_conn.connect()
157+
self._add(new_conn)
158+
logger.debug(
159+
'Created new connection#{2} for upstream {0}:{1}'.format(
160+
addr[0], addr[1], id(new_conn),
161+
),
162+
)
163+
return new_conn
164+
155165
def _add(self, conn: TcpServerConnection) -> None:
156166
"""Adds a new connection to internal data structure."""
157167
if conn.addr not in self.pools:
158168
self.pools[conn.addr] = set()
169+
conn._reusable = True
159170
self.pools[conn.addr].add(conn)
160171
self.connections[conn.connection.fileno()] = conn
161172

proxy/http/plugin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252
flags: argparse.Namespace,
5353
client: TcpClientConnection,
5454
request: HttpParser,
55-
event_queue: Optional[EventQueue],
55+
event_queue: Optional[EventQueue] = None,
5656
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
5757
):
5858
self.uid: str = uid

proxy/http/proxy/plugin.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@
1111
import argparse
1212

1313
from abc import ABC
14-
from typing import Any, Dict, Optional, Tuple
14+
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
1515

1616
from ..parser import HttpParser
1717
from ..descriptors import DescriptorsHandlerMixin
1818

1919
from ...core.event import EventQueue
2020
from ...core.connection import TcpClientConnection
2121

22+
if TYPE_CHECKING:
23+
from ...core.connection import UpstreamConnectionPool
24+
2225

2326
class HttpProxyBasePlugin(DescriptorsHandlerMixin, ABC):
2427
"""Base HttpProxyPlugin Plugin class.
@@ -31,11 +34,13 @@ def __init__(
3134
flags: argparse.Namespace,
3235
client: TcpClientConnection,
3336
event_queue: EventQueue,
37+
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
3438
) -> None:
3539
self.uid = uid # pragma: no cover
3640
self.flags = flags # pragma: no cover
3741
self.client = client # pragma: no cover
3842
self.event_queue = event_queue # pragma: no cover
43+
self.upstream_conn_pool = upstream_conn_pool
3944

4045
def name(self) -> str:
4146
"""A unique name for your plugin.

proxy/http/proxy/server.py

+1
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def __init__(
159159
self.flags,
160160
self.client,
161161
self.event_queue,
162+
self.upstream_conn_pool,
162163
)
163164
self.plugins[instance.name()] = instance
164165

proxy/http/server/plugin.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import argparse
1212

1313
from abc import ABC, abstractmethod
14-
from typing import Any, Dict, List, Optional, Tuple
14+
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
1515

1616
from ..websocket import WebsocketFrame
1717
from ..parser import HttpParser
@@ -20,6 +20,9 @@
2020
from ...core.connection import TcpClientConnection
2121
from ...core.event import EventQueue
2222

23+
if TYPE_CHECKING:
24+
from ...core.connection import UpstreamConnectionPool
25+
2326

2427
class HttpWebServerBasePlugin(DescriptorsHandlerMixin, ABC):
2528
"""Web Server Plugin for routing of requests."""
@@ -30,11 +33,13 @@ def __init__(
3033
flags: argparse.Namespace,
3134
client: TcpClientConnection,
3235
event_queue: EventQueue,
36+
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
3337
):
3438
self.uid = uid
3539
self.flags = flags
3640
self.client = client
3741
self.event_queue = event_queue
42+
self.upstream_conn_pool = upstream_conn_pool
3843

3944
def name(self) -> str:
4045
"""A unique name for your plugin.

proxy/http/server/web.py

+1
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def _initialize_web_plugins(self) -> None:
106106
self.flags,
107107
self.client,
108108
self.event_queue,
109+
self.upstream_conn_pool,
109110
)
110111
self.plugins[instance.name()] = instance
111112
for (protocol, route) in instance.routes():

tests/core/test_conn_pool.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -21,30 +21,34 @@
2121
class TestConnectionPool(unittest.TestCase):
2222

2323
@mock.patch('proxy.core.connection.pool.TcpServerConnection')
24-
def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None:
24+
def test_acquire_and_retain_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None:
2525
pool = UpstreamConnectionPool()
2626
# Mock
2727
mock_conn = mock_tcp_server_connection.return_value
2828
addr = mock_conn.addr
2929
mock_conn.is_reusable.side_effect = [
30-
False, True, True,
30+
True,
3131
]
3232
mock_conn.closed = False
3333
# Acquire
3434
created, conn = pool.acquire(addr)
35-
self.assertTrue(created)
3635
mock_tcp_server_connection.assert_called_once_with(addr[0], addr[1])
36+
mock_conn.mark_inuse.assert_called_once()
37+
mock_conn.reset.assert_not_called()
38+
self.assertTrue(created)
3739
self.assertEqual(conn, mock_conn)
3840
self.assertEqual(len(pool.pools[addr]), 1)
3941
self.assertTrue(conn in pool.pools[addr])
40-
# Release (connection must be retained because not closed)
41-
pool.release(conn)
42+
self.assertEqual(len(pool.connections), 1)
43+
self.assertEqual(pool.connections[conn.connection.fileno()], mock_conn)
44+
# Retail
45+
pool.retain(conn)
4246
self.assertEqual(len(pool.pools[addr]), 1)
4347
self.assertTrue(conn in pool.pools[addr])
48+
mock_conn.reset.assert_called_once()
4449
# Reacquire
4550
created, conn = pool.acquire(addr)
4651
self.assertFalse(created)
47-
mock_conn.reset.assert_called_once()
4852
self.assertEqual(conn, mock_conn)
4953
self.assertEqual(len(pool.pools[addr]), 1)
5054
self.assertTrue(conn in pool.pools[addr])

0 commit comments

Comments
 (0)