Skip to content

Commit 214caf3

Browse files
Support FT.PROFILE + test
1 parent da0b3fa commit 214caf3

File tree

3 files changed

+113
-6
lines changed

3 files changed

+113
-6
lines changed

redis/commands/helpers.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ def delist(x):
3535

3636

3737
def parse_to_list(response):
38-
"""Optimistally parse the response to a list.
39-
"""
38+
"""Optimistically parse the response to a list."""
4039
res = []
4140
for item in response:
4241
try:
@@ -51,6 +50,40 @@ def parse_to_list(response):
5150
return res
5251

5352

53+
def parse_list_to_dict(response):
54+
res = {}
55+
for i in range(0, len(response), 2):
56+
if isinstance(response[i], list):
57+
res['Child iterators'].append(parse_list_to_dict(response[i]))
58+
elif isinstance(response[i+1], list):
59+
res['Child iterators'] = [parse_list_to_dict(response[i+1])]
60+
else:
61+
try:
62+
res[response[i]] = float(response[i+1])
63+
except (TypeError, ValueError):
64+
res[response[i]] = response[i+1]
65+
return res
66+
67+
68+
def parse_to_dict(response):
69+
if response is None:
70+
return {}
71+
72+
res = {}
73+
for det in response:
74+
if isinstance(det[1], list):
75+
res[det[0]] = parse_list_to_dict(det[1])
76+
else:
77+
try: # try to set the attribute. may be provided without value
78+
try: # try to convert the value to float
79+
res[det[0]] = float(det[1])
80+
except (TypeError, ValueError):
81+
res[det[0]] = det[1]
82+
except IndexError:
83+
pass
84+
return res
85+
86+
5487
def random_string(length=10):
5588
"""
5689
Returns a random N character long string.

redis/commands/search/commands.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ._util import to_string
88
from .aggregation import AggregateRequest, AggregateResult, Cursor
99
from .suggestion import SuggestionParser
10+
from ..helpers import parse_to_dict
1011

1112
NUMERIC = "NUMERIC"
1213

@@ -20,6 +21,7 @@
2021
EXPLAINCLI_CMD = "FT.EXPLAINCLI"
2122
DEL_CMD = "FT.DEL"
2223
AGGREGATE_CMD = "FT.AGGREGATE"
24+
PROFILE_CMD = "FT.PROFILE"
2325
CURSOR_CMD = "FT.CURSOR"
2426
SPELLCHECK_CMD = "FT.SPELLCHECK"
2527
DICT_ADD_CMD = "FT.DICTADD"
@@ -382,11 +384,11 @@ def explain_cli(self, query): # noqa
382384

383385
def aggregate(self, query):
384386
"""
385-
Issue an aggregation query
387+
Issue an aggregation query.
386388
387389
### Parameters
388390
389-
**query**: This can be either an `AggeregateRequest`, or a `Cursor`
391+
**query**: This can be either an `AggregateRequest`, or a `Cursor`
390392
391393
An `AggregateResult` object is returned. You can access the rows from
392394
its `rows` property, which will always yield the rows of the result.
@@ -401,6 +403,10 @@ def aggregate(self, query):
401403
raise ValueError("Bad query", query)
402404

403405
raw = self.execute_command(*cmd)
406+
return self._get_AggregateResult(raw, query, has_cursor)
407+
408+
def _get_AggregateResult(self, raw, query, has_cursor):
409+
# has_cursor = bool(query._cursor)
404410
if has_cursor:
405411
if isinstance(query, Cursor):
406412
query.cid = raw[1]
@@ -418,8 +424,51 @@ def aggregate(self, query):
418424
schema = None
419425
rows = raw[1:]
420426

421-
res = AggregateResult(rows, cursor, schema)
422-
return res
427+
return AggregateResult(rows, cursor, schema)
428+
429+
def profile(self, query, limited=False):
430+
"""
431+
Performs a search or aggregate command and collects performance
432+
information.
433+
434+
### Parameters
435+
436+
**query**: This can be either an `AggregateRequest`, `Query` or
437+
string.
438+
**limited**: If set to True, removes details of reader iterator.
439+
440+
"""
441+
st = time.time()
442+
cmd = [PROFILE_CMD, self.index_name, ""]
443+
if limited:
444+
cmd.append("LIMITED")
445+
cmd.append('QUERY')
446+
447+
if isinstance(query, AggregateRequest):
448+
cmd[2] = "AGGREGATE"
449+
cmd += query.build_args()
450+
elif isinstance(query, Query):
451+
cmd[2] = "SEARCH"
452+
cmd += query.get_args()
453+
elif isinstance(query, str):
454+
cmd[2] = "SEARCH"
455+
cmd.append(query)
456+
else:
457+
raise ValueError("Must provide AggregateRequest object, "
458+
"Query object or str.")
459+
460+
res = self.execute_command(*cmd)
461+
462+
if isinstance(query, AggregateRequest):
463+
result = self._get_AggregateResult(res[0], query, query._cursor)
464+
else:
465+
result = Result(res[0],
466+
not query._no_content,
467+
duration=(time.time() - st) * 1000.0,
468+
has_payload=query._with_payloads,
469+
with_scores=query._with_scores,)
470+
471+
return result, parse_to_dict(res[1])
423472

424473
def spellcheck(self, query, distance=None, include=None, exclude=None):
425474
"""

tests/test_search.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,3 +1313,28 @@ def test_json_with_jsonpath(client):
13131313
assert res.docs[0].id == "doc:1"
13141314
with pytest.raises(Exception):
13151315
res.docs[0].name_unsupported
1316+
1317+
1318+
@pytest.mark.redismod
1319+
def test_profile(client):
1320+
client.ft().create_index((TextField('t'),))
1321+
client.ft().client.hset('1', 't', 'hello')
1322+
client.ft().client.hset('2', 't', 'world')
1323+
1324+
# check using Query
1325+
q = Query('hello|world').no_content()
1326+
res, det = client.ft().profile(q)
1327+
assert det['Iterators profile']['Counter'] == 2.0
1328+
assert len(det['Iterators profile']['Child iterators']) == 2
1329+
assert det['Iterators profile']['Type'] == 'UNION'
1330+
assert det['Parsing time'] < 0.3
1331+
assert len(res.docs) == 2 # check also the search result
1332+
1333+
# check using AggregateRequest
1334+
req = aggregations.AggregateRequest("*").load("t")\
1335+
.apply(prefix="startswith(@t, 'hel')")
1336+
res, det = client.ft().profile(req)
1337+
assert det['Iterators profile']['Counter'] == 2.0
1338+
assert det['Iterators profile']['Type'] == 'WILDCARD'
1339+
assert det['Parsing time'] < 0.3
1340+
assert len(res.rows) == 2 # check also the search result

0 commit comments

Comments
 (0)