diff --git a/elasticsearch_dsl/connections.py b/elasticsearch_dsl/connections.py index 7bfabeb0..d871c1ac 100644 --- a/elasticsearch_dsl/connections.py +++ b/elasticsearch_dsl/connections.py @@ -58,7 +58,7 @@ def add_connection(self, alias, conn): """ Add a connection object, it will be passed through as-is. """ - self._conns[alias] = conn + self._conns[alias] = self._with_user_agent(conn) def remove_connection(self, alias): """ @@ -82,7 +82,7 @@ def create_connection(self, alias="default", **kwargs): """ kwargs.setdefault("serializer", serializer) conn = self._conns[alias] = self.elasticsearch_class(**kwargs) - return conn + return self._with_user_agent(conn) def get_connection(self, alias="default"): """ @@ -96,7 +96,7 @@ def get_connection(self, alias="default"): # do not check isinstance(Elasticsearch) so that people can wrap their # clients if not isinstance(alias, str): - return alias + return self._with_user_agent(alias) # connection already established try: @@ -111,6 +111,21 @@ def get_connection(self, alias="default"): # no connection and no kwargs to set one up raise KeyError(f"There is no connection with alias {alias!r}.") + def _with_user_agent(self, conn): + from . import __versionstr__ # this is here to avoid circular imports + + # try to inject our user agent + if hasattr(conn, "_headers"): + is_frozen = conn._headers.frozen + if is_frozen: + conn._headers = conn._headers.copy() + conn._headers.update( + {"user-agent": f"elasticsearch-dsl-py/{__versionstr__}"} + ) + if is_frozen: + conn._headers.freeze() + return conn + connections = Connections() configure = connections.configure diff --git a/tests/test_connections.py b/tests/test_connections.py index b22e6eba..2b218ce5 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -107,3 +107,29 @@ def test_create_connection_adds_our_serializer(): c_serializers = c.get_connection("testing").transport.serializers assert c_serializers.serializers["application/json"] is serializer.serializer + + +def test_connection_has_correct_user_agent(): + c = connections.Connections(elasticsearch_class=Elasticsearch) + + c.create_connection("testing", hosts=["https://es.com:9200"]) + assert ( + c.get_connection("testing") + ._headers["user-agent"] + .startswith("elasticsearch-dsl-py/") + ) + + my_client = Elasticsearch(hosts=["http://localhost:9200"]) + my_client = my_client.options(headers={"user-agent": "my-user-agent/1.0"}) + c.add_connection("default", my_client) + assert c.get_connection()._headers["user-agent"].startswith("elasticsearch-dsl-py/") + + my_client = Elasticsearch(hosts=["http://localhost:9200"]) + assert ( + c.get_connection(my_client) + ._headers["user-agent"] + .startswith("elasticsearch-dsl-py/") + ) + + not_a_client = object() + assert c.get_connection(not_a_client) == not_a_client