Skip to content

Add a --unix-socket-path flag #697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 7, 2021
14 changes: 8 additions & 6 deletions examples/pubsub_eventing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
process_publisher_request_id = '12345'
num_events_received = [0, 0]

logger = logging.getLogger(__name__)


def publisher_process(
shutdown_event: multiprocessing.synchronize.Event,
dispatcher_queue: EventQueue,
) -> None:
print('publisher starting')
logger.info('publisher starting')
try:
while not shutdown_event.is_set():
dispatcher_queue.publish(
Expand All @@ -40,7 +42,7 @@ def publisher_process(
)
except KeyboardInterrupt:
pass
print('publisher shutdown')
logger.info('publisher shutdown')


def on_event(payload: Dict[str, Any]) -> None:
Expand All @@ -50,7 +52,6 @@ def on_event(payload: Dict[str, Any]) -> None:
num_events_received[0] += 1
else:
num_events_received[1] += 1
# print(payload)


if __name__ == '__main__':
Expand Down Expand Up @@ -86,7 +87,7 @@ def on_event(payload: Dict[str, Any]) -> None:
publisher_id='eventing_pubsub_main',
)
except KeyboardInterrupt:
print('bye!!!')
logger.info('bye!!!')
finally:
# Stop publisher
publisher_shutdown_event.set()
Expand All @@ -95,8 +96,9 @@ def on_event(payload: Dict[str, Any]) -> None:
subscriber.unsubscribe()
# Signal dispatcher to shutdown
event_manager.stop_event_dispatcher()
print(
logger.info(
'Received {0} events from main thread, {1} events from another process, in {2} seconds'.format(
num_events_received[0], num_events_received[1], time.time() - start_time,
num_events_received[0], num_events_received[1], time.time(
) - start_time,
),
)
6 changes: 5 additions & 1 deletion examples/ssl_echo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import logging

from proxy.core.connection import TcpServerConnection
from proxy.common.constants import DEFAULT_BUFFER_SIZE

logger = logging.getLogger(__name__)

if __name__ == '__main__':
client = TcpServerConnection('::', 12345)
client.connect()
Expand All @@ -24,6 +28,6 @@
data = client.recv(DEFAULT_BUFFER_SIZE)
if data is None:
break
print(data.tobytes())
logger.info(data.tobytes())
finally:
client.close()
6 changes: 5 additions & 1 deletion examples/tcp_echo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import logging

from proxy.common.utils import socket_connection
from proxy.common.constants import DEFAULT_BUFFER_SIZE

logger = logging.getLogger(__name__)

if __name__ == '__main__':
with socket_connection(('::', 12345)) as client:
while True:
client.send(b'hello')
data = client.recv(DEFAULT_BUFFER_SIZE)
if data is None:
break
print(data)
logger.info(data)
7 changes: 5 additions & 2 deletions examples/websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@
:license: BSD, see LICENSE for more details.
"""
import time
from proxy.http.websocket import WebsocketClient, WebsocketFrame, websocketOpcodes
import logging

from proxy.http.websocket import WebsocketClient, WebsocketFrame, websocketOpcodes

# globals
client: WebsocketClient
last_dispatch_time: float
static_frame = memoryview(WebsocketFrame.text(b'hello'))
num_echos = 10

logger = logging.getLogger(__name__)


def on_message(frame: WebsocketFrame) -> None:
"""WebsocketClient on_message callback."""
global client, num_echos, last_dispatch_time
print(
logger.info(
'Received %r after %d millisec' %
(frame.data, (time.time() - last_dispatch_time) * 1000),
)
Expand Down
13 changes: 12 additions & 1 deletion proxy/common/flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,18 @@ def initialize(
IpAddress,
opts.get('hostname', ipaddress.ip_address(args.hostname)),
)
args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET
# AF_UNIX is not available on Windows
# See https://bugs.python.org/issue33408
if os.name != 'nt':
args.family = socket.AF_UNIX if args.unix_socket_path else (
socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET
)
else:
# FIXME: Not true for tests, as this value will be mock
# It's a problem only on Windows. Instead of a proper
# test level fix, simply commenting this for now.
# assert args.unix_socket_path is None
args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET
args.port = cast(int, opts.get('port', args.port))
args.backlog = cast(int, opts.get('backlog', args.backlog))
num_workers = opts.get('num_workers', args.num_workers)
Expand Down
11 changes: 8 additions & 3 deletions proxy/common/pki.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,15 +288,20 @@ def run_openssl_command(command: List[str], timeout: int) -> bool:

# Validation
if args.action not in available_actions:
print('Invalid --action. Valid values ' + ', '.join(available_actions))
logger.error(
'Invalid --action. Valid values ' +
', '.join(available_actions),
)
sys.exit(1)
if args.action in ('gen_private_key', 'gen_public_key'):
if args.private_key_path is None:
print('--private-key-path is required for ' + args.action)
logger.error('--private-key-path is required for ' + args.action)
sys.exit(1)
if args.action == 'gen_public_key':
if args.public_key_path is None:
print('--public-key-file is required for private key generation')
logger.error(
'--public-key-file is required for private key generation',
)
sys.exit(1)

# Execute
Expand Down
10 changes: 7 additions & 3 deletions proxy/core/acceptor/acceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,20 @@ def shutdown_threadless_process(self) -> None:
self.threadless_process.join()
self.threadless_client_queue.close()

def _start_threadless_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None:
def _start_threadless_work(self, conn: socket.socket, addr: Optional[Tuple[str, int]]) -> None:
assert self.threadless_process and self.threadless_client_queue
self.threadless_client_queue.send(addr)
# Accepted client address is empty string for
# unix socket domain, avoid sending empty string
if not self.flags.unix_socket_path:
self.threadless_client_queue.send(addr)
send_handle(
self.threadless_client_queue,
conn.fileno(),
self.threadless_process.pid,
)
conn.close()

def _start_threaded_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None:
def _start_threaded_work(self, conn: socket.socket, addr: Optional[Tuple[str, int]]) -> None:
work = self.work_klass(
TcpClientConnection(conn, addr),
flags=self.flags,
Expand All @@ -145,6 +148,7 @@ def run_once(self) -> None:
if len(events) == 0:
return
conn, addr = self.sock.accept()
addr = None if addr == '' else addr
if (
self.flags.threadless and
self.threadless_client_queue and
Expand Down
27 changes: 24 additions & 3 deletions proxy/core/acceptor/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import os
import argparse
import logging
import multiprocessing
Expand Down Expand Up @@ -61,6 +62,14 @@
help='Defaults to number of CPU cores.',
)

flags.add_argument(
'--unix-socket-path',
type=str,
default=None,
help='Default: None. Unix socket path to use. ' +
'When provided --host and --port flags are ignored',
)


class AcceptorPool:
"""AcceptorPool is a helper class which pre-spawns `Acceptor` processes
Expand Down Expand Up @@ -108,8 +117,11 @@ def __exit__(
self.shutdown()

def setup(self) -> None:
"""Listen on port and setup acceptors."""
self._listen()
"""Setup socket and acceptors."""
if self.flags.unix_socket_path:
self._listen_unix_socket()
else:
self._listen_server_port()
# Override flags.port to match the actual port
# we are listening upon. This is necessary to preserve
# the server port when `--port=0` is used.
Expand All @@ -133,9 +145,18 @@ def shutdown(self) -> None:
acceptor.running.set()
for acceptor in self.acceptors:
acceptor.join()
if self.flags.unix_socket_path:
os.remove(self.flags.unix_socket_path)
logger.debug('Acceptors shutdown')

def _listen(self) -> None:
def _listen_unix_socket(self) -> None:
self.socket = socket.socket(self.flags.family, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(self.flags.unix_socket_path)
self.socket.listen(self.flags.backlog)
self.socket.setblocking(False)

def _listen_server_port(self) -> None:
self.socket = socket.socket(self.flags.family, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind((str(self.flags.hostname), self.flags.port))
Expand Down
6 changes: 5 additions & 1 deletion proxy/core/acceptor/threadless.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ def fromfd(self, fileno: int) -> socket.socket:
)

def accept_client(self) -> None:
addr = self.client_queue.recv()
# Acceptor will not send address for
# unix socket domain environments.
addr = None
if not self.flags.unix_socket_path:
addr = self.client_queue.recv()
fileno = recv_handle(self.client_queue)
self.works[fileno] = self.work_klass(
TcpClientConnection(conn=self.fromfd(fileno), addr=addr),
Expand Down
19 changes: 13 additions & 6 deletions proxy/core/base/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,14 @@ class BaseTcpServerHandler(Work):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.must_flush_before_shutdown = False
logger.debug('Connection accepted from {0}'.format(self.work.addr))
if self.flags.unix_socket_path:
logger.debug(
'Connection accepted from {0}'.format(self.work.address),
)
else:
logger.debug(
'Connection accepted from {0}'.format(self.work.address),
)

@abstractmethod
def handle_data(self, data: memoryview) -> Optional[bool]:
Expand Down Expand Up @@ -79,7 +86,7 @@ def handle_events(
if teardown:
logger.debug(
'Shutting down client {0} connection'.format(
self.work.addr,
self.work.address,
),
)
return teardown
Expand All @@ -88,7 +95,7 @@ def handle_writables(self, writables: Writables) -> bool:
teardown = False
if self.work.connection in writables and self.work.has_buffer():
logger.debug(
'Flushing buffer to client {0}'.format(self.work.addr),
'Flushing buffer to client {0}'.format(self.work.address),
)
self.work.flush()
if self.must_flush_before_shutdown is True:
Expand All @@ -104,7 +111,7 @@ def handle_readables(self, readables: Readables) -> bool:
if data is None:
logger.debug(
'Connection closed by client {0}'.format(
self.work.addr,
self.work.address,
),
)
teardown = True
Expand All @@ -113,13 +120,13 @@ def handle_readables(self, readables: Readables) -> bool:
if isinstance(r, bool) and r is True:
logger.debug(
'Implementation signaled shutdown for client {0}'.format(
self.work.addr,
self.work.address,
),
)
if self.work.has_buffer():
logger.debug(
'Client {0} has pending buffer, will be flushed before shutting down'.format(
self.work.addr,
self.work.address,
),
)
self.must_flush_before_shutdown = True
Expand Down
9 changes: 6 additions & 3 deletions proxy/core/base/tcp_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
:license: BSD, see LICENSE for more details.
"""
import socket
import logging
import selectors

from abc import abstractmethod
Expand All @@ -21,6 +22,8 @@
from ..connection import TcpServerConnection
from .tcp_server import BaseTcpServerHandler

logger = logging.getLogger(__name__)


class BaseTcpTunnelHandler(BaseTcpServerHandler):
"""BaseTcpTunnelHandler build on-top of BaseTcpServerHandler work klass.
Expand All @@ -47,7 +50,7 @@ def initialize(self) -> None:

def shutdown(self) -> None:
if self.upstream:
print(
logger.debug(
'Connection closed with upstream {0}:{1}'.format(
text_(self.request.host), self.request.port,
),
Expand Down Expand Up @@ -84,7 +87,7 @@ def handle_events(
data = self.upstream.recv()
if data is None:
# Server closed connection
print('Connection closed by server')
logger.debug('Connection closed by server')
return True
# tunnel data to client
self.work.queue(data)
Expand All @@ -98,7 +101,7 @@ def connect_upstream(self) -> None:
text_(self.request.host), self.request.port,
)
self.upstream.connect()
print(
logger.debug(
'Connection established with upstream {0}:{1}'.format(
text_(self.request.host), self.request.port,
),
Expand Down
9 changes: 7 additions & 2 deletions proxy/core/connection/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ class TcpClientConnection(TcpConnection):
def __init__(
self,
conn: Union[ssl.SSLSocket, socket.socket],
addr: Tuple[str, int],
# optional for unix socket servers
addr: Optional[Tuple[str, int]] = None,
) -> None:
super().__init__(tcpConnectionTypes.CLIENT)
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = conn
self.addr: Tuple[str, int] = addr
self.addr: Optional[Tuple[str, int]] = addr

@property
def address(self) -> str:
return 'unix:client' if not self.addr else '{0}:{1}'.format(self.addr[0], self.addr[1])

@property
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
Expand Down
Loading