diff --git a/core/testcontainers/core/version.py b/core/testcontainers/core/version.py
new file mode 100644
index 000000000..cac51fc18
--- /dev/null
+++ b/core/testcontainers/core/version.py
@@ -0,0 +1,30 @@
+from typing import Callable
+
+from packaging.version import Version
+
+
+class ComparableVersion:
+    def __init__(self, version):
+        self.version = Version(version)
+
+    def __lt__(self, other: str):
+        return self._apply_op(other, lambda x, y: x < y)
+
+    def __le__(self, other: str):
+        return self._apply_op(other, lambda x, y: x <= y)
+
+    def __eq__(self, other: str):
+        return self._apply_op(other, lambda x, y: x == y)
+
+    def __ne__(self, other: str):
+        return self._apply_op(other, lambda x, y: x != y)
+
+    def __gt__(self, other: str):
+        return self._apply_op(other, lambda x, y: x > y)
+
+    def __ge__(self, other: str):
+        return self._apply_op(other, lambda x, y: x >= y)
+
+    def _apply_op(self, other: str, op: Callable[[Version, Version], bool]):
+        other = Version(other)
+        return op(self.version, other)
diff --git a/core/tests/test_version.py b/core/tests/test_version.py
new file mode 100644
index 000000000..397cd0523
--- /dev/null
+++ b/core/tests/test_version.py
@@ -0,0 +1,78 @@
+import pytest
+from packaging.version import InvalidVersion
+
+from testcontainers.core.version import ComparableVersion
+
+
+@pytest.fixture
+def version():
+    return ComparableVersion("1.0.0")
+
+
+@pytest.mark.parametrize("other_version, expected", [("0.9.0", False), ("1.0.0", False), ("1.1.0", True)])
+def test_lt(version, other_version, expected):
+    assert (version < other_version) == expected
+
+
+@pytest.mark.parametrize("other_version, expected", [("0.9.0", False), ("1.0.0", True), ("1.1.0", True)])
+def test_le(version, other_version, expected):
+    assert (version <= other_version) == expected
+
+
+@pytest.mark.parametrize("other_version, expected", [("0.9.0", False), ("1.0.0", True), ("1.1.0", False)])
+def test_eq(version, other_version, expected):
+    assert (version == other_version) == expected
+
+
+@pytest.mark.parametrize("other_version, expected", [("0.9.0", True), ("1.0.0", False), ("1.1.0", True)])
+def test_ne(version, other_version, expected):
+    assert (version != other_version) == expected
+
+
+@pytest.mark.parametrize("other_version, expected", [("0.9.0", True), ("1.0.0", False), ("1.1.0", False)])
+def test_gt(version, other_version, expected):
+    assert (version > other_version) == expected
+
+
+@pytest.mark.parametrize("other_version, expected", [("0.9.0", True), ("1.0.0", True), ("1.1.0", False)])
+def test_ge(version, other_version, expected):
+    assert (version >= other_version) == expected
+
+
+@pytest.mark.parametrize(
+    "invalid_version",
+    [
+        "invalid",
+        "1..0",
+    ],
+)
+def test_invalid_version_raises_error(invalid_version):
+    with pytest.raises(InvalidVersion):
+        ComparableVersion(invalid_version)
+
+
+@pytest.mark.parametrize(
+    "invalid_version",
+    [
+        "invalid",
+        "1..0",
+    ],
+)
+def test_comparison_with_invalid_version_raises_error(version, invalid_version):
+    with pytest.raises(InvalidVersion):
+        assert version < invalid_version
+
+    with pytest.raises(InvalidVersion):
+        assert version <= invalid_version
+
+    with pytest.raises(InvalidVersion):
+        assert version == invalid_version
+
+    with pytest.raises(InvalidVersion):
+        assert version != invalid_version
+
+    with pytest.raises(InvalidVersion):
+        assert version > invalid_version
+
+    with pytest.raises(InvalidVersion):
+        assert version >= invalid_version
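For reviewers, a minimal sketch of how the new helper reads in practice (the version strings below are arbitrary examples; the right-hand operand is parsed with `packaging.version.Version` inside `_apply_op`, so malformed strings raise `InvalidVersion`):

```python
from packaging.version import InvalidVersion

from testcontainers.core.version import ComparableVersion

tag = ComparableVersion("7.6.0")
assert tag >= "7.0.0"  # minimum-version checks read naturally
assert tag < "8.0.0"

try:
    tag < "not-a-version"  # comparison operand is parsed lazily
except InvalidVersion:
    pass
```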
diff --git a/modules/kafka/testcontainers/kafka/__init__.py b/modules/kafka/testcontainers/kafka/__init__.py
index 7dd71b633..ea837be37 100644
--- a/modules/kafka/testcontainers/kafka/__init__.py
+++ b/modules/kafka/testcontainers/kafka/__init__.py
@@ -3,8 +3,11 @@
 from io import BytesIO
 from textwrap import dedent
 
+from typing_extensions import Self
+
 from testcontainers.core.container import DockerContainer
 from testcontainers.core.utils import raise_for_deprecated_parameter
+from testcontainers.core.version import ComparableVersion
 from testcontainers.core.waiting_utils import wait_for_logs
 
 from testcontainers.kafka._redpanda import RedpandaContainer
@@ -26,18 +29,29 @@ class KafkaContainer(DockerContainer):
 
            >>> with KafkaContainer() as kafka:
            ...    connection = kafka.get_bootstrap_server()
+
+           # Using KRaft protocol
+           >>> with KafkaContainer().with_kraft() as kafka:
+           ...    connection = kafka.get_bootstrap_server()
     """
 
     TC_START_SCRIPT = "/tc-start.sh"
+    MIN_KRAFT_TAG = "7.0.0"
 
     def __init__(self, image: str = "confluentinc/cp-kafka:7.6.0", port: int = 9093, **kwargs) -> None:
         raise_for_deprecated_parameter(kwargs, "port_to_expose", "port")
         super().__init__(image, **kwargs)
         self.port = port
+        self.kraft_enabled = False
+        self.wait_for = r".*\[KafkaServer id=\d+\] started.*"
+        self.boot_command = ""
+        self.cluster_id = "MkU3OEVBNTcwNTJENDM2Qk"
+        self.listeners = f"PLAINTEXT://0.0.0.0:{self.port},BROKER://0.0.0.0:9092"
+        self.security_protocol_map = "BROKER:PLAINTEXT,PLAINTEXT:PLAINTEXT"
+
         self.with_exposed_ports(self.port)
-        listeners = f"PLAINTEXT://0.0.0.0:{self.port},BROKER://0.0.0.0:9092"
-        self.with_env("KAFKA_LISTENERS", listeners)
-        self.with_env("KAFKA_LISTENER_SECURITY_PROTOCOL_MAP", "BROKER:PLAINTEXT,PLAINTEXT:PLAINTEXT")
+        self.with_env("KAFKA_LISTENERS", self.listeners)
+        self.with_env("KAFKA_LISTENER_SECURITY_PROTOCOL_MAP", self.security_protocol_map)
 
         self.with_env("KAFKA_INTER_BROKER_LISTENER_NAME", "BROKER")
         self.with_env("KAFKA_BROKER_ID", "1")
@@ -46,6 +60,74 @@ def __init__(self, image: str = "confluentinc/cp-kafka:7.6.0", port: int = 9093,
         self.with_env("KAFKA_LOG_FLUSH_INTERVAL_MESSAGES", "10000000")
         self.with_env("KAFKA_GROUP_INITIAL_REBALANCE_DELAY_MS", "0")
 
+    def with_kraft(self) -> Self:
+        self._verify_min_kraft_version()
+        self.kraft_enabled = True
+        return self
+
+    def _verify_min_kraft_version(self):
+        actual_version = self.image.split(":")[-1]
+
+        if ComparableVersion(actual_version) < self.MIN_KRAFT_TAG:
+            raise ValueError(
+                f"Provided Confluent Platform's version {actual_version} "
+                f"is not supported in Kraft mode"
+                f" (must be {self.MIN_KRAFT_TAG} or above)"
+            )
+
+    def with_cluster_id(self, cluster_id: str) -> Self:
+        self.cluster_id = cluster_id
+        return self
+
+    def configure(self):
+        if self.kraft_enabled:
+            self._configure_kraft()
+        else:
+            self._configure_zookeeper()
+
+    def _configure_kraft(self) -> None:
+        self.wait_for = r".*Kafka Server started.*"
+
+        self.with_env("CLUSTER_ID", self.cluster_id)
+        self.with_env("KAFKA_NODE_ID", 1)
+        self.with_env(
+            "KAFKA_LISTENER_SECURITY_PROTOCOL_MAP",
+            f"{self.security_protocol_map},CONTROLLER:PLAINTEXT",
+        )
+        self.with_env(
+            "KAFKA_LISTENERS",
+            f"{self.listeners},CONTROLLER://0.0.0.0:9094",
+        )
+        self.with_env("KAFKA_PROCESS_ROLES", "broker,controller")
+
+        network_alias = self._get_network_alias()
+        controller_quorum_voters = f"1@{network_alias}:9094"
+        self.with_env("KAFKA_CONTROLLER_QUORUM_VOTERS", controller_quorum_voters)
+        self.with_env("KAFKA_CONTROLLER_LISTENER_NAMES", "CONTROLLER")
+
+        self.boot_command = f"""
+            sed -i '/KAFKA_ZOOKEEPER_CONNECT/d' /etc/confluent/docker/configure
+            echo 'kafka-storage format --ignore-formatted -t {self.cluster_id} -c /etc/kafka/kafka.properties' >> /etc/confluent/docker/configure
+        """
+
+    def _get_network_alias(self):
+        if self._network:
+            return next(
+                iter(self._network_aliases or [self._network.name or self._kwargs.get("network", [])]),
+                None,
+            )
+
+        return "localhost"
+
+    def _configure_zookeeper(self) -> None:
+        self.boot_command = """
+            echo 'clientPort=2181' > zookeeper.properties
+            echo 'dataDir=/var/lib/zookeeper/data' >> zookeeper.properties
+            echo 'dataLogDir=/var/lib/zookeeper/log' >> zookeeper.properties
+            zookeeper-server-start zookeeper.properties &
+            export KAFKA_ZOOKEEPER_CONNECT='localhost:2181'
+        """
+
     def get_bootstrap_server(self) -> str:
         host = self.get_container_host_ip()
         port = self.get_exposed_port(self.port)
@@ -59,11 +141,7 @@ def tc_start(self) -> None:
             dedent(
                 f"""
                 #!/bin/bash
-                echo 'clientPort=2181' > zookeeper.properties
-                echo 'dataDir=/var/lib/zookeeper/data' >> zookeeper.properties
-                echo 'dataLogDir=/var/lib/zookeeper/log' >> zookeeper.properties
-                zookeeper-server-start zookeeper.properties &
-                export KAFKA_ZOOKEEPER_CONNECT='localhost:2181'
+                {self.boot_command}
                 export KAFKA_ADVERTISED_LISTENERS={listeners}
                 . /etc/confluent/docker/bash-config
                 /etc/confluent/docker/configure
@@ -78,10 +156,11 @@ def tc_start(self) -> None:
     def start(self, timeout=30) -> "KafkaContainer":
         script = KafkaContainer.TC_START_SCRIPT
         command = f'sh -c "while [ ! -f {script} ]; do sleep 0.1; done; sh {script}"'
+        self.configure()
         self.with_command(command)
         super().start()
         self.tc_start()
-        wait_for_logs(self, r".*\[KafkaServer id=\d+\] started.*", timeout=timeout)
+        wait_for_logs(self, self.wait_for, timeout=timeout)
         return self
 
     def create_file(self, content: bytes, path: str) -> None:
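For context, a usage sketch of the KRaft path added above (the image tag is just an example; any `confluentinc/cp-kafka` tag at or above `MIN_KRAFT_TAG` passes the check, older tags raise `ValueError` from `_verify_min_kraft_version`):

```python
from testcontainers.kafka import KafkaContainer

# with_kraft() swaps the ZooKeeper boot commands for a kafka-storage format
# step and changes the startup log pattern to ".*Kafka Server started.*".
with KafkaContainer("confluentinc/cp-kafka:7.6.0").with_kraft() as kafka:
    bootstrap_server = kafka.get_bootstrap_server()

# Tags below 7.0.0 fail fast before the container is started:
# KafkaContainer("confluentinc/cp-kafka:6.2.0").with_kraft()  # -> ValueError
```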
diff --git a/modules/kafka/tests/test_kafka.py b/modules/kafka/tests/test_kafka.py
index 1f3826adf..eb1a48127 100644
--- a/modules/kafka/tests/test_kafka.py
+++ b/modules/kafka/tests/test_kafka.py
@@ -8,6 +8,12 @@ def test_kafka_producer_consumer():
         produce_and_consume_kafka_message(container)
 
 
+def test_kafka_with_kraft_producer_consumer():
+    with KafkaContainer().with_kraft() as container:
+        assert container.kraft_enabled
+        produce_and_consume_kafka_message(container)
+
+
 def test_kafka_producer_consumer_custom_port():
     with KafkaContainer(port=9888) as container:
         assert container.port == 9888
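`produce_and_consume_kafka_message` is an existing helper in this test module and is not part of this diff; a self-contained equivalent, assuming the kafka-python client (the function and topic names here are hypothetical), might look roughly like:

```python
from kafka import KafkaConsumer, KafkaProducer

from testcontainers.kafka import KafkaContainer


def roundtrip_message(container: KafkaContainer) -> None:
    # Hypothetical stand-in for produce_and_consume_kafka_message: write one
    # record to a topic and read it back via the advertised bootstrap server.
    bootstrap = container.get_bootstrap_server()
    topic = "verification-topic"

    producer = KafkaProducer(bootstrap_servers=[bootstrap])
    producer.send(topic, b"verification message")
    producer.flush()
    producer.close()

    consumer = KafkaConsumer(
        topic,
        bootstrap_servers=[bootstrap],
        auto_offset_reset="earliest",
        consumer_timeout_ms=10_000,
    )
    assert any(message.value == b"verification message" for message in consumer)
    consumer.close()
```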