From c015c2c8f91c775f39a55cbadf2b88aa768ba4dc Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 8 Mar 2022 02:16:03 +0200 Subject: [PATCH] Add pipeline support for search --- redis/commands/search/__init__.py | 30 +++++++++++++++++++++++++----- redis/commands/search/commands.py | 31 ++++++++++++++++++++----------- tests/test_search.py | 18 ++++++++++++++++++ 3 files changed, 63 insertions(+), 16 deletions(-) diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index 94bc037c3d..e9763b60a8 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -1,3 +1,5 @@ +import redis + from .commands import SearchCommands @@ -17,7 +19,7 @@ def __init__(self, client, chunk_size=1000): self.client = client self.execute_command = client.execute_command - self.pipeline = client.pipeline(transaction=False, shard_hint=None) + self._pipeline = client.pipeline(transaction=False, shard_hint=None) self.total = 0 self.chunk_size = chunk_size self.current_chunk = 0 @@ -42,7 +44,7 @@ def add_document( """ self.client._add_document( doc_id, - conn=self.pipeline, + conn=self._pipeline, nosave=nosave, score=score, payload=payload, @@ -67,7 +69,7 @@ def add_document_hash( """ self.client._add_document_hash( doc_id, - conn=self.pipeline, + conn=self._pipeline, score=score, replace=replace, ) @@ -80,7 +82,7 @@ def commit(self): """ Manually commit and flush the batch indexing query """ - self.pipeline.execute() + self._pipeline.execute() self.current_chunk = 0 def __init__(self, client, index_name="idx"): @@ -90,7 +92,25 @@ def __init__(self, client, index_name="idx"): If conn is not None, we employ an already existing redis connection """ + self.MODULE_CALLBACKS = {} self.client = client self.index_name = index_name self.execute_command = client.execute_command - self.pipeline = client.pipeline + self._pipeline = client.pipeline + + def pipeline(self, transaction=True, shard_hint=None): + """Creates a pipeline for the SEARCH module, that can be used for executing + SEARCH commands, as well as classic core commands. + """ + p = Pipeline( + connection_pool=self.client.connection_pool, + response_callbacks=self.MODULE_CALLBACKS, + transaction=transaction, + shard_hint=shard_hint, + ) + p.index_name = self.index_name + return p + + +class Pipeline(SearchCommands, redis.client.Pipeline): + """Pipeline for the module.""" diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 39c599fbd5..158beec0de 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -2,6 +2,8 @@ import time from typing import Dict, Union +from redis.client import Pipeline + from ..helpers import parse_to_dict from ._util import to_string from .aggregation import AggregateRequest, AggregateResult, Cursor @@ -186,8 +188,6 @@ def _add_document( """ Internal add_document used for both batch and single doc indexing """ - if conn is None: - conn = self.client if partial or no_create: replace = True @@ -208,7 +208,11 @@ def _add_document( args += ["LANGUAGE", language] args.append("FIELDS") args += list(itertools.chain(*fields.items())) - return conn.execute_command(*args) + + if conn is not None: + return conn.execute_command(*args) + + return self.execute_command(*args) def _add_document_hash( self, @@ -221,8 +225,6 @@ def _add_document_hash( """ Internal add_document_hash used for both batch and single doc indexing """ - if conn is None: - conn = self.client args = [ADDHASH_CMD, self.index_name, doc_id, score] @@ -232,7 +234,10 @@ def _add_document_hash( if language: args += ["LANGUAGE", language] - return conn.execute_command(*args) + if conn is not None: + return conn.execute_command(*args) + + return self.execute_command(*args) def add_document( self, @@ -331,12 +336,13 @@ def delete_document(self, doc_id, conn=None, delete_actual_document=False): For more information: https://oss.redis.com/redisearch/Commands/#ftdel """ # noqa args = [DEL_CMD, self.index_name, doc_id] - if conn is None: - conn = self.client if delete_actual_document: args.append("DD") - return conn.execute_command(*args) + if conn is not None: + return conn.execute_command(*args) + + return self.execute_command(*args) def load_document(self, id): """ @@ -364,7 +370,7 @@ def get(self, *ids): For more information https://oss.redis.com/redisearch/Commands/#ftget """ - return self.client.execute_command(MGET_CMD, self.index_name, *ids) + return self.execute_command(MGET_CMD, self.index_name, *ids) def info(self): """ @@ -374,7 +380,7 @@ def info(self): For more information https://oss.redis.com/redisearch/Commands/#ftinfo """ - res = self.client.execute_command(INFO_CMD, self.index_name) + res = self.execute_command(INFO_CMD, self.index_name) it = map(to_string, res) return dict(zip(it, it)) @@ -423,6 +429,9 @@ def search( st = time.time() res = self.execute_command(SEARCH_CMD, *args) + if isinstance(res, Pipeline): + return res + return Result( res, not query._no_content, diff --git a/tests/test_search.py b/tests/test_search.py index 5ee17a2c36..9c879d0321 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1574,3 +1574,21 @@ def test_geo_params(modclient): assert "doc1" == res.docs[0].id assert "doc2" == res.docs[1].id assert "doc3" == res.docs[2].id + + +@pytest.mark.redismod +def test_search_commands_in_pipeline(client): + p = client.ft().pipeline() + p.create_index((TextField("txt"),)) + p.add_document("doc1", payload="foo baz", txt="foo bar") + p.add_document("doc2", txt="foo bar") + q = Query("foo bar").with_payloads() + p.search(q) + res = p.execute() + assert res[:3] == ["OK", "OK", "OK"] + assert 2 == res[3][0] + assert "doc1" == res[3][1] + assert "doc2" == res[3][4] + assert "foo baz" == res[3][2] + assert res[3][5] is None + assert res[3][3] == res[3][6] == ["txt", "foo bar"]