Skip to content

Add support for ephemeral services. #1302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions cms/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import errno
import ipaddress
import json
import logging
import os
import socket
import sys
from collections import namedtuple
from contextlib import closing

from .log import set_detailed_logs

Expand All @@ -44,6 +47,7 @@ class ServiceCoord(namedtuple("ServiceCoord", "name shard")):
service (thus identifying it).

"""

def __repr__(self):
    """Return the coordinates formatted as "name,shard"."""
    fields = (self.name, self.shard)
    return "%s,%d" % fields

Expand All @@ -53,6 +57,75 @@ class ConfigError(Exception):
pass


class EphemeralServiceConfig:
    """Configuration of an ephemeral service.

    An ephemeral service is a normal service whose shard is chosen
    depending on its address and port. The port is assigned inside a
    range and the address must be inside the subnet.

    """
    # Added to every shard computed by get_shard and subtracted back by
    # get_address.
    EPHEMERAL_SHARD_OFFSET = 10000

    def __init__(self, subnet, min_port, max_port):
        """Create the configuration of an ephemeral service.

        subnet (str): network the service addresses must belong to, in
            any form accepted by ipaddress.ip_network.
        min_port (int): first port of the range (inclusive).
        max_port (int): last port of the range (inclusive).

        raise (ValueError): if subnet is not a valid network.
        raise (ConfigError): if min_port is greater than max_port.

        """
        self.subnet = ipaddress.ip_network(subnet)
        self.min_port = min_port
        self.max_port = max_port
        if min_port > max_port:
            raise ConfigError("Invalid port range: [%s, %s]"
                              % (min_port, max_port))

    def get_shard(self, address, port):
        """Get the ephemeral shard for a service given its address and port.

        address (IPv4Address|IPv6Address): address of the service.
        port (int): port of the service.

        return (int): shard of the service.

        raise (ValueError): if the address is outside the subnet or the
            port is outside the configured range.

        """
        if address not in self.subnet:
            raise ValueError("The address is not inside the subnet")
        # A port outside the range would map to the shard of a different
        # (host, port) pair, so it is rejected instead of being encoded.
        if not self.min_port <= port <= self.max_port:
            raise ValueError("The port %s is not inside the range [%s, %s]"
                             % (port, self.min_port, self.max_port))
        # Keep only the host bits of the address.
        host_id = int(address) & int(self.subnet.hostmask)
        num_ports = self.max_port - self.min_port + 1
        shard = host_id * num_ports + (port - self.min_port)
        return shard + self.EPHEMERAL_SHARD_OFFSET

    def get_address(self, shard):
        """Get the address and port of a service given its shard.

        This is the inverse of get_shard.

        shard (int): shard of the service.

        return (Address): address and port of the service.

        raise (ValueError): if the shard does not map to an address
            inside the subnet.

        """
        num_ports = self.max_port - self.min_port + 1
        host_id, port_offset = divmod(
            shard - self.EPHEMERAL_SHARD_OFFSET, num_ports)

        addr = self.subnet.network_address + host_id
        if addr not in self.subnet:
            raise ValueError("The shard is not valid")
        return Address(str(addr), self.min_port + port_offset)

    def find_free_port(self, address):
        """Find the first port of the range that can be bound.

        The probing socket is closed before returning, so the port is
        only known to have been free at the time of the check.

        address (IPv4Address|IPv6Address): local address to bind to.

        return (int): the first port in [min_port, max_port] on which
            a bind on the given address succeeded.

        raise (ValueError): if no port of the range could be bound.

        """
        if address.version == 4:
            family = socket.AF_INET
        else:
            family = socket.AF_INET6
        host = str(address)
        for port in range(self.min_port, self.max_port + 1):
            with socket.socket(family, socket.SOCK_STREAM) as sock:
                try:
                    sock.bind((host, port))
                except OSError:
                    # Port busy (or not bindable): try the next one.
                    continue
                return port
        raise ValueError("No free port found in range [%s, %s] "
                         "for address %s"
                         % (self.min_port, self.max_port, address))


class AsyncConfig:
"""This class will contain the configuration for the
services. This needs to be populated at the initialization stage.
Expand All @@ -69,6 +142,7 @@ class AsyncConfig:
"""
core_services = {}
other_services = {}
ephemeral_services = {} # type: dict[str, EphemeralServiceConfig]


async_config = AsyncConfig()
Expand All @@ -81,6 +155,7 @@ class Config:
directory for information on the meaning of the fields.

"""

def __init__(self):
"""Default values for configuration, plus decide if this
instance is running from the system path or from the source
Expand Down Expand Up @@ -251,6 +326,19 @@ def _load_unique(self, path):
self.async_config.other_services[coord] = Address(*shard)
del data["other_services"]

if 'ephemeral_services' in data:
for service_name in data['ephemeral_services']:
if service_name.startswith("_"):
continue
service = data["ephemeral_services"][service_name]
self.async_config.ephemeral_services[service_name] = \
EphemeralServiceConfig(
service["subnet"],
service["min_port"],
service["max_port"],
)
del data["ephemeral_services"]

# Put everything else in self.
for key, value in data.items():
setattr(self, key, value)
Expand Down
2 changes: 2 additions & 0 deletions cms/io/web_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def __init__(self, listen_port, handlers, parameters, shard=0,
if num_proxies_used > 0:
self.wsgi_app = ProxyFix(self.wsgi_app, num_proxies_used)

logger.info("%s listening on '%s' at port %d",
type(self).__name__, listen_address, listen_port)
self.web_server = WSGIServer((listen_address, listen_port), self)

def __call__(self, environ, start_response):
Expand Down
9 changes: 7 additions & 2 deletions cms/server/contest/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from cms.io import WebService
from cms.locale import get_translations
from cms.server.contest.jinja2_toolbox import CWS_ENVIRONMENT
from cms.util import is_shard_ephemeral
from cmscommon.binary import hex_to_bin
from .handlers import HANDLERS
from .handlers.base import ContestListHandler
Expand Down Expand Up @@ -73,8 +74,12 @@ def __init__(self, shard, contest_id=None):
}

try:
listen_address = config.contest_listen_address[shard]
listen_port = config.contest_listen_port[shard]
if is_shard_ephemeral(shard):
index = 0
else:
index = shard
listen_address = config.contest_listen_address[index]
listen_port = config.contest_listen_port[index]
except IndexError:
raise ConfigError("Wrong shard number for %s, or missing "
"address/port configuration. Please check "
Expand Down
15 changes: 14 additions & 1 deletion cms/service/EvaluationService.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def enqueue(self, item, priority, timestamp):
item_entry = item.to_dict()
del item_entry["testcase_codename"]
item_entry["multiplicity"] = 1
entry = {"item": item_entry, "priority": priority, "timestamp": make_timestamp(timestamp)}
entry = {"item": item_entry, "priority": priority,
"timestamp": make_timestamp(timestamp)}
self.queue_status_cumulative[key] = entry
return success

Expand Down Expand Up @@ -197,6 +198,11 @@ def _remove_from_cumulative_status(self, queue_entry):
if self.queue_status_cumulative[key]["item"]["multiplicity"] == 0:
del self.queue_status_cumulative[key]

def add_worker(self, worker_coord):
    """Register a new worker in the pool, marking it as ephemeral.

    worker_coord (ServiceCoord): the coordinates of the worker.

    """
    pool = self.pool
    pool.add_worker(worker_coord, ephemeral=True)


def with_post_finish_lock(func):
"""Decorator for locking on self.post_finish_lock.
Expand Down Expand Up @@ -379,6 +385,13 @@ def workers_status(self):
"""
return self.get_executor().pool.get_status()

@rpc_method
def add_worker(self, coord):
    """RPC to register a new worker to the list of workers.

    coord (sequence): two items, the service name and the shard of
        the worker to add.

    """
    name, number = coord
    register = self.get_executor().add_worker
    register(ServiceCoord(name, number))

def check_workers_timeout(self):
"""We ask WorkerPool for the unresponsive workers, and we put
again their operations in the queue.
Expand Down
8 changes: 8 additions & 0 deletions cms/service/Worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import gevent.lock

from cms import ServiceCoord
from cms.db import SessionGen, Contest, enumerate_files
from cms.db.filecacher import FileCacher, TombstoneError
from cms.grading import JobException
Expand Down Expand Up @@ -64,6 +65,13 @@ def __init__(self, shard, fake_worker_time=None):

self._fake_worker_time = fake_worker_time

self.evaluation_service = self.connect_to(
ServiceCoord("EvaluationService", 0),
on_connect=self.on_es_connection)

def on_es_connection(self, address):
    """Announce this worker to EvaluationService.

    Registered as the on_connect callback of the connection to
    EvaluationService.

    address: passed by the connection callback; not used here.

    """
    announce = self.evaluation_service.add_worker
    announce(coord=self._my_coord)

@rpc_method
def precache_files(self, contest_id):
"""RPC to ask the worker to precache of files in the contest.
Expand Down
25 changes: 23 additions & 2 deletions cms/service/workerpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,20 @@ def wait_for_workers(self):
"""Wait until a worker might be available."""
self._workers_available_event.wait()

def add_worker(self, worker_coord):
def add_worker(self, worker_coord, ephemeral=False):
"""Add a new worker to the worker pool.

worker_coord (ServiceCoord): the coordinates of the worker.
ephemeral (bool): remove the worker from the pool after the
disconnection.

"""
shard = worker_coord.shard
# Instruct GeventLibrary to connect ES to the Worker.
self._worker[shard] = self._service.connect_to(
worker_coord,
on_connect=self.on_worker_connected)
on_connect=self.on_worker_connected,
on_disconnect=lambda: self.on_worker_disconnected(worker_coord, ephemeral))

# And we fill all data.
self._operations[shard] = WorkerPool.WORKER_INACTIVE
Expand Down Expand Up @@ -183,6 +186,24 @@ def on_worker_connected(self, worker_coord):
# so we wake up the consumers.
self._workers_available_event.set()

def on_worker_disconnected(self, worker_coord, ephemeral):
    """Handle the disconnection of a worker.

    Nothing is done for a non-ephemeral worker. An ephemeral worker
    is disabled (unless it already was), the operations it lost are
    put again in the queue, and then it is removed from the pool.

    worker_coord (ServiceCoord): the coordinates of the worker.
    ephemeral (bool): whether the worker was added as ephemeral.

    """
    if ephemeral:
        shard = worker_coord.shard
        needs_disabling = \
            self._operations[shard] != WorkerPool.WORKER_DISABLED
        if needs_disabling:
            # Disabling the worker gives back the operations it was
            # running, which are re-enqueued with their original
            # priority and timestamp.
            for lost in self.disable_worker(shard):
                logger.info("Operation %s put again in the queue because "
                            "the worker disconnected.", lost)
                priority, timestamp = lost.side_data
                self._service.enqueue(lost, priority, timestamp)
        del self._worker[shard]
        logger.info("Worker %s removed", worker_coord)

def acquire_worker(self, operations):
"""Tries to assign an operation to an available worker. If no workers
are available then this returns None, otherwise this returns
Expand Down
38 changes: 32 additions & 6 deletions cms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import argparse
import itertools
import ipaddress
import logging
import netifaces
import os
Expand All @@ -35,6 +36,7 @@
import gevent.socket

from cms import ServiceCoord, ConfigError, async_config, config
from cms.conf import EphemeralServiceConfig


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -136,8 +138,19 @@ def get_safe_shard(service, provided_shard):
raise (ValueError): if no safe shard can be returned.

"""
addrs = _find_local_addresses()
# Try to assign an ephemeral shard first. This needs to be done before
# autodetecting the shard using the IP since here we cannot detect if
# the service is already running on that port.
if provided_shard is None and service in config.async_config.ephemeral_services:
ephemeral_config = config.async_config.ephemeral_services[service]
for addr in addrs:
addr = ipaddress.ip_address(addr[1])
if addr in ephemeral_config.subnet:
port = ephemeral_config.find_free_port(addr)
shard = ephemeral_config.get_shard(addr, port)
return shard
if provided_shard is None:
addrs = _find_local_addresses()
computed_shard = _get_shard_from_addresses(service, addrs)
if computed_shard is None:
logger.critical("Couldn't autodetect shard number and "
Expand All @@ -157,17 +170,30 @@ def get_safe_shard(service, provided_shard):
return provided_shard


def is_shard_ephemeral(shard):
    """Tell whether a shard number is in the ephemeral range.

    shard (int): the shard to check.

    return (bool): True if the shard is at least
        EphemeralServiceConfig.EPHEMERAL_SHARD_OFFSET.

    """
    first_ephemeral = EphemeralServiceConfig.EPHEMERAL_SHARD_OFFSET
    return shard >= first_ephemeral


def get_service_address(key):
    """Give the Address of a ServiceCoord.

    The statically configured services (core first, then other) are
    looked up by the full coordinates; ephemeral services are looked
    up by name and their address is computed from the shard.

    key (ServiceCoord): the service needed.
    returns (Address): listening address of key.

    raise (KeyError): if the service is not configured anywhere.

    """
    service, shard = key
    static_registries = (async_config.core_services,
                         async_config.other_services)
    for registry in static_registries:
        if key in registry:
            return registry[key]

    ephemeral = async_config.ephemeral_services
    if service not in ephemeral:
        raise KeyError("Service not found.")
    return ephemeral[service].get_address(shard)

Expand All @@ -179,11 +205,11 @@ def get_service_shards(service):
returns (int): the number of shards defined in the configuration.

"""
for i in itertools.count():
try:
get_service_address(ServiceCoord(service, i))
except KeyError:
return i
count = 0
for services in (async_config.core_services, async_config.other_services):
count += len([0 for s in services if s.name == service])

return count


def default_argument_parser(description, cls, ask_contest=None):
Expand Down
Loading