|
| 1 | +from typing import TYPE_CHECKING |
| 2 | + |
| 3 | +from requests import ConnectionError, get |
| 4 | + |
| 5 | +from testcontainers.core.container import DockerContainer |
| 6 | +from testcontainers.core.utils import raise_for_deprecated_parameter |
| 7 | +from testcontainers.core.waiting_utils import wait_container_is_ready |
| 8 | + |
| 9 | +if TYPE_CHECKING: |
| 10 | + from requests import Response |
| 11 | + |
| 12 | + |
| 13 | +class ChromaContainer(DockerContainer): |
| 14 | + """ |
| 15 | + The example below spins up a ChromaDB container, performs a healthcheck and creates a collection. |
| 16 | + The method :code:`get_client` can be used to create a client for the Chroma Python Client. |
| 17 | +
|
| 18 | + Example: |
| 19 | +
|
| 20 | + .. doctest:: |
| 21 | +
|
| 22 | + >>> import chromadb |
| 23 | + >>> from testcontainers.chroma import ChromaContainer |
| 24 | +
|
| 25 | + >>> with ChromaContainer() as chroma: |
| 26 | + ... config = chroma.get_config() |
| 27 | + ... client = chromadb.HttpClient(host=config["host"], port=config["port"]) |
| 28 | + ... col = client.get_or_create_collection("test") |
| 29 | + ... col.name |
| 30 | + 'test' |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__( |
| 34 | + self, |
| 35 | + image: str = "chromadb/chroma:latest", |
| 36 | + port: int = 8000, |
| 37 | + **kwargs, |
| 38 | + ) -> None: |
| 39 | + """ |
| 40 | + Args: |
| 41 | + image: Docker image to use for the MinIO container. |
| 42 | + port: Port to expose on the container. |
| 43 | + access_key: Access key for client connections. |
| 44 | + secret_key: Secret key for client connections. |
| 45 | + """ |
| 46 | + raise_for_deprecated_parameter(kwargs, "port_to_expose", "port") |
| 47 | + super().__init__(image, **kwargs) |
| 48 | + self.port = port |
| 49 | + |
| 50 | + self.with_exposed_ports(self.port) |
| 51 | + # self.with_command(f"server /data --address :{self.port}") |
| 52 | + |
| 53 | + def get_config(self) -> dict: |
| 54 | + """This method returns the configuration of the Chroma container, |
| 55 | + including the endpoint. |
| 56 | +
|
| 57 | + Returns: |
| 58 | + dict: {`endpoint`: str} |
| 59 | + """ |
| 60 | + host_ip = self.get_container_host_ip() |
| 61 | + exposed_port = self.get_exposed_port(self.port) |
| 62 | + return { |
| 63 | + "endpoint": f"{host_ip}:{exposed_port}", |
| 64 | + "host": host_ip, |
| 65 | + "port": exposed_port, |
| 66 | + } |
| 67 | + |
| 68 | + @wait_container_is_ready(ConnectionError) |
| 69 | + def _healthcheck(self) -> None: |
| 70 | + """This is an internal method used to check if the Chroma container |
| 71 | + is healthy and ready to receive requests.""" |
| 72 | + url = f"http://{self.get_config()['endpoint']}/api/v1/heartbeat" |
| 73 | + response: Response = get(url) |
| 74 | + response.raise_for_status() |
| 75 | + |
| 76 | + def start(self) -> "ChromaContainer": |
| 77 | + """This method starts the Chroma container and runs the healthcheck |
| 78 | + to verify that the container is ready to use.""" |
| 79 | + super().start() |
| 80 | + self._healthcheck() |
| 81 | + return self |
0 commit comments