From 5d0d9bef4c25f3f9a8950264e2c97e2e3c902bcf Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Fri, 12 Apr 2024 16:41:35 +0100 Subject: [PATCH] Added inner_hits option to kNN search --- elasticsearch_dsl/search_base.py | 4 ++++ tests/_async/test_search.py | 2 ++ tests/_sync/test_search.py | 2 ++ 3 files changed, 8 insertions(+) diff --git a/elasticsearch_dsl/search_base.py b/elasticsearch_dsl/search_base.py index c8816d59..ee6f49bf 100644 --- a/elasticsearch_dsl/search_base.py +++ b/elasticsearch_dsl/search_base.py @@ -514,6 +514,7 @@ def knn( boost=None, filter=None, similarity=None, + inner_hits=None, ): """ Add a k-nearest neighbor (kNN) search. @@ -526,6 +527,7 @@ def knn( :arg boost: A floating-point boost factor for kNN scores :arg filter: query to filter the documents that can match :arg similarity: the minimum similarity required for a document to be considered a match, as a float value + :arg inner_hits: retrieve hits from nested field Example:: @@ -560,6 +562,8 @@ def knn( s._knn[-1]["filter"] = filter if similarity is not None: s._knn[-1]["similarity"] = similarity + if inner_hits is not None: + s._knn[-1]["inner_hits"] = inner_hits return s def rank(self, rrf=None): diff --git a/tests/_async/test_search.py b/tests/_async/test_search.py index 6abc8bca..aabbce6f 100644 --- a/tests/_async/test_search.py +++ b/tests/_async/test_search.py @@ -266,6 +266,7 @@ def test_knn(): query_vector_builder={ "text_embedding": {"model_id": "foo", "model_text": "search text"} }, + inner_hits={"size": 1}, ) assert { "knn": [ @@ -283,6 +284,7 @@ def test_knn(): "text_embedding": {"model_id": "foo", "model_text": "search text"} }, "boost": 0.8, + "inner_hits": {"size": 1}, }, ] } == s.to_dict() diff --git a/tests/_sync/test_search.py b/tests/_sync/test_search.py index 255b1eeb..001ce704 100644 --- a/tests/_sync/test_search.py +++ b/tests/_sync/test_search.py @@ -266,6 +266,7 @@ def test_knn(): query_vector_builder={ "text_embedding": {"model_id": "foo", "model_text": "search text"} }, + inner_hits={"size": 1}, ) assert { "knn": [ @@ -283,6 +284,7 @@ def test_knn(): "text_embedding": {"model_id": "foo", "model_text": "search text"} }, "boost": 0.8, + "inner_hits": {"size": 1}, }, ] } == s.to_dict()