Skip to content
Merged
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 27 additions & 77 deletions redis/commands/search/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,75 +82,6 @@ class Desc(SortDirection):
DIRSTRING = "DESC"


class Group:
"""
This object automatically created in the `AggregateRequest.group_by()`
"""

def __init__(self, fields, reducers):
if not reducers:
raise ValueError("Need at least one reducer")

fields = [fields] if isinstance(fields, str) else fields
reducers = [reducers] if isinstance(reducers, Reducer) else reducers

self.fields = fields
self.reducers = reducers
self.limit = Limit()

def build_args(self):
ret = ["GROUPBY", str(len(self.fields))]
ret.extend(self.fields)
for reducer in self.reducers:
ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
ret.extend(reducer.args)
if reducer._alias is not None:
ret += ["AS", reducer._alias]
return ret


class Projection:
"""
This object automatically created in the `AggregateRequest.apply()`
"""

def __init__(self, projector, alias=None):
self.alias = alias
self.projector = projector

def build_args(self):
ret = ["APPLY", self.projector]
if self.alias is not None:
ret += ["AS", self.alias]

return ret


class SortBy:
"""
This object automatically created in the `AggregateRequest.sort_by()`
"""

def __init__(self, fields, max=0):
self.fields = fields
self.max = max

def build_args(self):
fields_args = []
for f in self.fields:
if isinstance(f, SortDirection):
fields_args += [f.field, f.DIRSTRING]
else:
fields_args += [f]

ret = ["SORTBY", str(len(fields_args))]
ret.extend(fields_args)
if self.max > 0:
ret += ["MAX", str(self.max)]

return ret


class AggregateRequest:
"""
Aggregation request which can be passed to `Client.aggregate`.
Expand Down Expand Up @@ -202,9 +133,17 @@ def group_by(self, fields, *reducers):
- **reducers**: One or more reducers. Reducers may be found in the
`aggregation` module.
"""
group = Group(fields, reducers)
self._aggregateplan.extend(group.build_args())
fields = [fields] if isinstance(fields, str) else fields
reducers = [reducers] if isinstance(reducers, Reducer) else reducers

ret = ["GROUPBY", str(len(fields)), *fields]
for reducer in reducers:
ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
ret.extend(reducer.args)
if reducer._alias is not None:
ret += ["AS", reducer._alias]

self._aggregateplan.extend(ret)
return self

def apply(self, **kwexpr):
Expand All @@ -218,8 +157,10 @@ def apply(self, **kwexpr):
expression itself, for example `apply(square_root="sqrt(@foo)")`
"""
for alias, expr in kwexpr.items():
projection = Projection(expr, alias)
self._aggregateplan.extend(projection.build_args())
ret = ["APPLY", expr]
if alias is not None:
ret += ["AS", alias]
self._aggregateplan.extend(ret)

return self

Expand Down Expand Up @@ -265,8 +206,7 @@ def limit(self, offset, num):
`sort_by()` instead.

"""
limit = Limit(offset, num)
self._limit = limit
self._limit = Limit(offset, num)
return self

def sort_by(self, *fields, **kwargs):
Expand Down Expand Up @@ -300,10 +240,20 @@ def sort_by(self, *fields, **kwargs):
if isinstance(fields, (str, SortDirection)):
fields = [fields]

fields_args = []
for f in fields:
if isinstance(f, SortDirection):
fields_args += [f.field, f.DIRSTRING]
else:
fields_args += [f]

ret = ["SORTBY", str(len(fields_args))]
ret.extend(fields_args)
max = kwargs.get("max", 0)
sortby = SortBy(fields, max)
if max > 0:
ret += ["MAX", str(max)]

self._aggregateplan.extend(sortby.build_args())
self._aggregateplan.extend(ret)
return self

def filter(self, expressions):
Expand Down