diff --git a/dev_requirements.txt b/dev_requirements.txt index 37a107d16d..adfa99e80c 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -9,7 +9,7 @@ packaging>=20.4 pytest pytest-asyncio>=0.23.0,<0.24.0 pytest-cov -pytest-profiling +pytest-profiling==1.7.0 pytest-timeout ujson>=4.2.0 uvloop diff --git a/doctests/query_agg.py b/doctests/query_agg.py new file mode 100644 index 0000000000..4fa8f14b84 --- /dev/null +++ b/doctests/query_agg.py @@ -0,0 +1,103 @@ +# EXAMPLE: query_agg +# HIDE_START +import json +import redis +from redis.commands.json.path import Path +from redis.commands.search import Search +from redis.commands.search.aggregation import AggregateRequest +from redis.commands.search.field import NumericField, TagField +from redis.commands.search.indexDefinition import IndexDefinition, IndexType +import redis.commands.search.reducers as reducers + +r = redis.Redis(decode_responses=True) + +# create index +schema = ( + TagField("$.condition", as_name="condition"), + NumericField("$.price", as_name="price"), +) + +index = r.ft("idx:bicycle") +index.create_index( + schema, + definition=IndexDefinition(prefix=["bicycle:"], index_type=IndexType.JSON), +) + +# load data +with open("data/query_em.json") as f: + bicycles = json.load(f) + +pipeline = r.pipeline(transaction=False) +for bid, bicycle in enumerate(bicycles): + pipeline.json().set(f'bicycle:{bid}', Path.root_path(), bicycle) +pipeline.execute() +# HIDE_END + +# STEP_START agg1 +search = Search(r, index_name="idx:bicycle") +aggregate_request = AggregateRequest(query='@condition:{new}') \ + .load('__key', 'price') \ + .apply(discounted='@price - (@price * 0.1)') +res = search.aggregate(aggregate_request) +print(len(res.rows)) # >>> 5 +print(res.rows) # >>> [['__key', 'bicycle:0', ... 
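+# Each 'discounted' value below is the loaded @price reduced by 10%,
+# per the APPLY expression above (e.g. 270 - (270 * 0.1) = 243):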
+#[['__key', 'bicycle:0', 'price', '270', 'discounted', '243'], +# ['__key', 'bicycle:5', 'price', '810', 'discounted', '729'], +# ['__key', 'bicycle:6', 'price', '2300', 'discounted', '2070'], +# ['__key', 'bicycle:7', 'price', '430', 'discounted', '387'], +# ['__key', 'bicycle:8', 'price', '1200', 'discounted', '1080']] +# REMOVE_START +assert len(res.rows) == 5 +# REMOVE_END +# STEP_END + +# STEP_START agg2 +search = Search(r, index_name="idx:bicycle") +aggregate_request = AggregateRequest(query='*') \ + .load('price') \ + .apply(price_category='@price<1000') \ + .group_by('@condition', reducers.sum('@price_category').alias('num_affordable')) +res = search.aggregate(aggregate_request) +print(len(res.rows)) # >>> 3 +print(res.rows) # >>> +#[['condition', 'refurbished', 'num_affordable', '1'], +# ['condition', 'used', 'num_affordable', '1'], +# ['condition', 'new', 'num_affordable', '3']] +# REMOVE_START +assert len(res.rows) == 3 +# REMOVE_END +# STEP_END + +# STEP_START agg3 +search = Search(r, index_name="idx:bicycle") +aggregate_request = AggregateRequest(query='*') \ + .apply(type="'bicycle'") \ + .group_by('@type', reducers.count().alias('num_total')) +res = search.aggregate(aggregate_request) +print(len(res.rows)) # >>> 1 +print(res.rows) # >>> [['type', 'bicycle', 'num_total', '10']] +# REMOVE_START +assert len(res.rows) == 1 +# REMOVE_END +# STEP_END + +# STEP_START agg4 +search = Search(r, index_name="idx:bicycle") +aggregate_request = AggregateRequest(query='*') \ + .load('__key') \ + .group_by('@condition', reducers.tolist('__key').alias('bicycles')) +res = search.aggregate(aggregate_request) +print(len(res.rows)) # >>> 3 +print(res.rows) # >>> +#[['condition', 'refurbished', 'bicycles', ['bicycle:9']], +# ['condition', 'used', 'bicycles', ['bicycle:1', 'bicycle:2', 'bicycle:3', 'bicycle:4']], +# ['condition', 'new', 'bicycles', ['bicycle:5', 'bicycle:6', 'bicycle:7', 'bicycle:0', 'bicycle:8']]] +# REMOVE_START +assert len(res.rows) == 3 +# REMOVE_END +# STEP_END + +# REMOVE_START +# destroy index and data +r.ft("idx:bicycle").dropindex(delete_documents=True) +# REMOVE_END diff --git a/doctests/query_combined.py b/doctests/query_combined.py new file mode 100644 index 0000000000..a17f19417c --- /dev/null +++ b/doctests/query_combined.py @@ -0,0 +1,124 @@ +# EXAMPLE: query_combined +# HIDE_START +import json +import numpy as np +import redis +import warnings +from redis.commands.json.path import Path +from redis.commands.search.field import NumericField, TagField, TextField, VectorField +from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.query import Query +from sentence_transformers import SentenceTransformer + + +def embed_text(model, text): + return np.array(model.encode(text)).astype(np.float32).tobytes() + +warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces.*") +model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') +query = "Bike for small kids" +query_vector = embed_text(model, query) + +r = redis.Redis(decode_responses=True) + +# create index +schema = ( + TextField("$.description", no_stem=True, as_name="model"), + TagField("$.condition", as_name="condition"), + NumericField("$.price", as_name="price"), + VectorField( + "$.description_embeddings", + "FLAT", + { + "TYPE": "FLOAT32", + "DIM": 384, + "DISTANCE_METRIC": "COSINE", + }, + as_name="vector", + ), +) + +index = r.ft("idx:bicycle") +index.create_index( + schema, + 
definition=IndexDefinition(prefix=["bicycle:"], index_type=IndexType.JSON),
+)
+
+# load data
+with open("data/query_vector.json") as f:
+    bicycles = json.load(f)
+
+pipeline = r.pipeline(transaction=False)
+for bid, bicycle in enumerate(bicycles):
+    pipeline.json().set(f'bicycle:{bid}', Path.root_path(), bicycle)
+pipeline.execute()
+# HIDE_END
+
+# STEP_START combined1
+q = Query("@price:[500 1000] @condition:{new}")
+res = index.search(q)
+print(res.total) # >>> 1
+# REMOVE_START
+assert res.total == 1
+# REMOVE_END
+# STEP_END
+
+# STEP_START combined2
+q = Query("kids @price:[500 1000] @condition:{used}")
+res = index.search(q)
+print(res.total) # >>> 1
+# REMOVE_START
+assert res.total == 1
+# REMOVE_END
+# STEP_END
+
+# STEP_START combined3
+q = Query("(kids | small) @condition:{used}")
+res = index.search(q)
+print(res.total) # >>> 2
+# REMOVE_START
+assert res.total == 2
+# REMOVE_END
+# STEP_END
+
+# STEP_START combined4
+q = Query("@description:(kids | small) @condition:{used}")
+res = index.search(q)
+print(res.total) # >>> 0
+# REMOVE_START
+assert res.total == 0
+# REMOVE_END
+# STEP_END
+
+# STEP_START combined5
+q = Query("@description:(kids | small) @condition:{new | used}")
+res = index.search(q)
+print(res.total) # >>> 0
+# REMOVE_START
+assert res.total == 0
+# REMOVE_END
+# STEP_END
+
+# STEP_START combined6
+q = Query("@price:[500 1000] -@condition:{new}")
+res = index.search(q)
+print(res.total) # >>> 2
+# REMOVE_START
+assert res.total == 2
+# REMOVE_END
+# STEP_END
+
+# STEP_START combined7
+q = Query("(@price:[500 1000] -@condition:{new})=>[KNN 3 @vector $query_vector]").dialect(2)
+# pass the precomputed query vector as a query parameter
+res = index.search(q, {"query_vector": query_vector})
+print(res.total) # >>> 2
+# REMOVE_START
+assert res.total == 2
+# REMOVE_END
+# STEP_END
+
+# REMOVE_START
+# destroy index and data
+r.ft("idx:bicycle").dropindex(delete_documents=True)
+# REMOVE_END
diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py
index 42c3547b0b..5638f1d662 100644
--- a/redis/commands/search/aggregation.py
+++ b/redis/commands/search/aggregation.py
@@ -112,6 +112,7 @@ def __init__(self, query: str = "*") -> None:
         self._cursor = []
         self._dialect = None
         self._add_scores = False
+        self._scorer = "TFIDF"
 
     def load(self, *fields: List[str]) -> "AggregateRequest":
         """
@@ -300,6 +301,17 @@ def add_scores(self) -> "AggregateRequest":
         self._add_scores = True
         return self
 
+    def scorer(self, scorer: str) -> "AggregateRequest":
+        """
+        Use a different scoring function to evaluate document relevance.
+        Default is `TFIDF`.
+
+        :param scorer: The scoring function to use
+            (e.g. `TFIDF.DOCNORM` or `BM25`)
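+
+        Example (a minimal sketch; assumes ``client`` is a connected Redis
+        client and a search index already exists — the names are
+        illustrative)::
+
+            req = AggregateRequest("*").scorer("BM25")
+            res = client.ft().aggregate(req)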
+        """
+        self._scorer = scorer
+        return self
+
     def verbatim(self) -> "AggregateRequest":
         self._verbatim = True
         return self
@@ -323,6 +335,9 @@ def build_args(self) -> List[str]:
         if self._verbatim:
             ret.append("VERBATIM")
 
+        if self._scorer:
+            ret.extend(["SCORER", self._scorer])
+
         if self._add_scores:
             ret.append("ADDSCORES")
 
@@ -332,6 +347,7 @@
         if self._loadall:
             ret.append("LOAD")
             ret.append("*")
+
         elif self._loadfields:
             ret.append("LOAD")
             ret.append(str(len(self._loadfields)))
diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py
index 0e6fe22131..fb813b0bc7 100644
--- a/tests/test_asyncio/test_search.py
+++ b/tests/test_asyncio/test_search.py
@@ -1556,6 +1556,61 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis):
     assert res.rows[1] == ["__score", "0.2"]
 
 
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.10.05", "search")
+async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
+    assert await decoded_r.ft().create_index(
+        (
+            TextField("name", sortable=True, weight=5.0),
+            TextField("description", sortable=True, weight=5.0),
+            VectorField(
+                "vector",
+                "HNSW",
+                {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
+            ),
+        )
+    )
+
+    assert await decoded_r.hset(
+        "doc1",
+        mapping={
+            "name": "cat book",
+            "description": "an animal book about cats",
+            "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
+        },
+    )
+    assert await decoded_r.hset(
+        "doc2",
+        mapping={
+            "name": "dog book",
+            "description": "an animal book about dogs",
+            "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
+        },
+    )
+
+    query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
+    req = (
+        aggregations.AggregateRequest(query_string)
+        .scorer("BM25")
+        .add_scores()
+        .apply(hybrid_score="@__score + @dist")
+        .load("*")
+        .dialect(4)
+    )
+
+    res = await decoded_r.ft().aggregate(
+        req,
+        query_params={"vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes()},
+    )
+
+    if isinstance(res, dict):
+        assert len(res["results"]) == 2
+    else:
+        assert len(res.rows) == 2
+        for row in res.rows:
+            assert len(row) == 6
+
+
 @pytest.mark.redismod
 @skip_if_redis_enterprise()
 async def test_search_commands_in_pipeline(decoded_r: redis.Redis):
diff --git a/tests/test_search.py b/tests/test_search.py
index dde59f0f87..0f0e7bb309 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -1466,6 +1466,61 @@ def test_aggregations_add_scores(client):
     assert res.rows[1] == ["__score", "0.2"]
 
 
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.10.05", "search")
+def test_aggregations_hybrid_scoring(client):
+    client.ft().create_index(
+        (
+            TextField("name", sortable=True, weight=5.0),
+            TextField("description", sortable=True, weight=5.0),
+            VectorField(
+                "vector",
+                "HNSW",
+                {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
+            ),
+        )
+    )
+
+    client.hset(
+        "doc1",
+        mapping={
+            "name": "cat book",
+            "description": "an animal book about cats",
+            "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
+        },
+    )
+    client.hset(
+        "doc2",
+        mapping={
+            "name": "dog book",
+            "description": "an animal book about dogs",
+            "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
+        },
+    )
+
+    query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
+    req = (
+        aggregations.AggregateRequest(query_string)
+        .scorer("BM25")
+        .add_scores()
+        .apply(hybrid_score="@__score + @dist")
+        .load("*")
+        .dialect(4)
+    )
+
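+    # bind the $vec_param placeholder from the query string to a concrete
+    # query vector (the values are illustrative test data)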
query_params={"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()}, + ) + + if isinstance(res, dict): + assert len(res["results"]) == 2 + else: + assert len(res.rows) == 2 + for row in res.rows: + len(row) == 6 + + @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") def test_index_definition(client):