diff --git a/modules/chroma/README.rst b/modules/chroma/README.rst new file mode 100644 index 00000000..f4e3199f --- /dev/null +++ b/modules/chroma/README.rst @@ -0,0 +1,2 @@ +.. autoclass:: testcontainers.chroma.ChromaContainer +.. title:: testcontainers.minio.ChromaContainer diff --git a/modules/chroma/testcontainers/chroma/__init__.py b/modules/chroma/testcontainers/chroma/__init__.py new file mode 100644 index 00000000..d4a1e6cb --- /dev/null +++ b/modules/chroma/testcontainers/chroma/__init__.py @@ -0,0 +1,79 @@ +from typing import TYPE_CHECKING + +from requests import ConnectionError, get + +from testcontainers.core.container import DockerContainer +from testcontainers.core.utils import raise_for_deprecated_parameter +from testcontainers.core.waiting_utils import wait_container_is_ready + +if TYPE_CHECKING: + from requests import Response + + +class ChromaContainer(DockerContainer): + """ + The example below spins up a ChromaDB container, performs a healthcheck and creates a collection. + The method :code:`get_client` can be used to create a client for the Chroma Python Client. + + Example: + + .. doctest:: + + >>> import io + >>> from testcontainers.chroma import ChromaContainer + + >>> with ChromaContainer() as chroma: + ... client = chroma.get_client() + ... client.heartbeat() + ... client.get_or_create_collection("test") + """ + + def __init__( + self, + image: str = "chromadb/chroma:latest", + port: int = 8000, + **kwargs, + ) -> None: + """ + Args: + image: Docker image to use for the MinIO container. + port: Port to expose on the container. + access_key: Access key for client connections. + secret_key: Secret key for client connections. + """ + raise_for_deprecated_parameter(kwargs, "port_to_expose", "port") + super().__init__(image, **kwargs) + self.port = port + + self.with_exposed_ports(self.port) + # self.with_command(f"server /data --address :{self.port}") + + def get_config(self) -> dict: + """This method returns the configuration of the Chroma container, + including the endpoint. + + Returns: + dict: {`endpoint`: str} + """ + host_ip = self.get_container_host_ip() + exposed_port = self.get_exposed_port(self.port) + return { + "endpoint": f"{host_ip}:{exposed_port}", + "host": host_ip, + "port": exposed_port, + } + + @wait_container_is_ready(ConnectionError) + def _healthcheck(self) -> None: + """This is an internal method used to check if the Chroma container + is healthy and ready to receive requests.""" + url = f"http://{self.get_config()['endpoint']}/api/v1/heartbeat" + response: Response = get(url) + response.raise_for_status() + + def start(self) -> "ChromaContainer": + """This method starts the Chroma container and runs the healthcheck + to verify that the container is ready to use.""" + super().start() + self._healthcheck() + return self diff --git a/modules/chroma/tests/test_chroma.py b/modules/chroma/tests/test_chroma.py new file mode 100644 index 00000000..fee55b78 --- /dev/null +++ b/modules/chroma/tests/test_chroma.py @@ -0,0 +1,9 @@ +from testcontainers.chroma import ChromaContainer +import chromadb + + +def test_docker_run_chroma(): + with ChromaContainer(image="chromadb/chroma:0.4.24") as chroma: + client = chromadb.HttpClient(host=chroma.get_config()["host"], port=chroma.get_config()["port"]) + col = client.get_or_create_collection("test") + assert col.name == "test" diff --git a/pyproject.toml b/pyproject.toml index c48dce7e..2d18f114 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,8 @@ packages = [ { include = "testcontainers", from = "modules/postgres" }, { include = "testcontainers", from = "modules/rabbitmq" }, { include = "testcontainers", from = "modules/redis" }, - { include = "testcontainers", from = "modules/selenium" } + { include = "testcontainers", from = "modules/selenium" }, + { include = "testcontainers", from = "modules/chroma" }, ] [tool.poetry.urls] @@ -84,6 +85,7 @@ cx_Oracle = { version = "*", optional = true } pika = { version = "*", optional = true } redis = { version = "*", optional = true } selenium = { version = "*", optional = true } +chroma = { version = "*", optional = true } [tool.poetry.extras] arangodb = ["python-arango"] @@ -108,6 +110,7 @@ postgres = [] rabbitmq = ["pika"] redis = ["redis"] selenium = ["selenium"] +chroma = ["chromadb-client"] [tool.poetry.group.dev.dependencies] mypy = "1.7.1"