Skip to content

Fixing errors reported by mypy. #3666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions redis/commands/search/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Optional, Union

from redis.commands.search.dialect import DEFAULT_DIALECT

Expand Down Expand Up @@ -27,9 +27,15 @@ class Reducer:
NAME = None

def __init__(self, *args: str) -> None:
<<<<<<< fixing_mypy_errors
self._args: tuple[str, ...] = args
self._field: Optional[str] = None
self._alias: Optional[str] = None
=======
self._args = args
self._field = None
self._alias = None
>>>>>>> master

def alias(self, alias: str) -> "Reducer":
"""
Expand All @@ -49,13 +55,14 @@ def alias(self, alias: str) -> "Reducer":
if alias is FIELDNAME:
if not self._field:
raise ValueError("Cannot use FIELDNAME alias with no field")
# Chop off initial '@'
alias = self._field[1:]
else:
# Chop off initial '@'
alias = self._field[1:]
self._alias = alias
return self

@property
def args(self) -> List[str]:
def args(self) -> tuple[str, ...]:
return self._args


Expand All @@ -64,7 +71,7 @@ class SortDirection:
This special class is used to indicate sort direction.
"""

DIRSTRING = None
DIRSTRING: Optional[str] = None

def __init__(self, field: str) -> None:
self.field = field
Expand Down Expand Up @@ -104,6 +111,19 @@ def __init__(self, query: str = "*") -> None:
All member methods (except `build_args()`)
return the object itself, making them useful for chaining.
"""
<<<<<<< fixing_mypy_errors
self._query: str = query
self._aggregateplan: List[str] = []
self._loadfields: List[str] = []
self._loadall: bool = False
self._max: int = 0
self._with_schema: bool = False
self._verbatim: bool = False
self._cursor: List[str] = []
self._dialect: int = DEFAULT_DIALECT
self._add_scores: bool = False
self._scorer: str = "TFIDF"
=======
self._query = query
self._aggregateplan = []
self._loadfields = []
Expand All @@ -115,6 +135,7 @@ def __init__(self, query: str = "*") -> None:
self._dialect = DEFAULT_DIALECT
self._add_scores = False
self._scorer = "TFIDF"
>>>>>>> master

def load(self, *fields: str) -> "AggregateRequest":
"""
Expand All @@ -133,7 +154,7 @@ def load(self, *fields: str) -> "AggregateRequest":
return self

def group_by(
self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
self, fields: Union[str, List[str]], *reducers: Reducer
) -> "AggregateRequest":
"""
Specify by which fields to group the aggregation.
Expand All @@ -147,7 +168,6 @@ def group_by(
`aggregation` module.
"""
fields = [fields] if isinstance(fields, str) else fields
reducers = [reducers] if isinstance(reducers, Reducer) else reducers

ret = ["GROUPBY", str(len(fields)), *fields]
for reducer in reducers:
Expand Down Expand Up @@ -251,12 +271,10 @@ def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
.sort_by(Desc("@paid"), max=10)
```
"""
if isinstance(fields, (str, SortDirection)):
fields = [fields]

fields_args = []
for f in fields:
if isinstance(f, SortDirection):
if isinstance(f, (Asc, Desc)):
fields_args += [f.field, f.DIRSTRING]
else:
fields_args += [f]
Expand Down Expand Up @@ -356,7 +374,7 @@ def build_args(self) -> List[str]:
ret.extend(self._loadfields)

if self._dialect:
ret.extend(["DIALECT", self._dialect])
ret.extend(["DIALECT", str(self._dialect)])

ret.extend(self._aggregateplan)

Expand Down Expand Up @@ -393,7 +411,7 @@ def __init__(self, rows, cursor: Cursor, schema) -> None:
self.cursor = cursor
self.schema = schema

def __repr__(self) -> (str, str):
def __repr__(self) -> str:
cid = self.cursor.cid if self.cursor else -1
return (
f"<{self.__class__.__name__} at 0x{id(self):x} "
Expand Down
4 changes: 2 additions & 2 deletions redis/commands/search/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def info(self):
return self._parse_results(INFO_CMD, res)

def get_params_args(
self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
self, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
):
if query_params is None:
return []
Expand Down Expand Up @@ -543,7 +543,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa
def aggregate(
self,
query: Union[AggregateRequest, Cursor],
query_params: Dict[str, Union[str, int, float]] = None,
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
):
"""
Issue an aggregation query.
Expand Down
24 changes: 12 additions & 12 deletions redis/commands/search/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

from redis.commands.search.dialect import DEFAULT_DIALECT

Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None:
self._with_scores: bool = False
self._scorer: Optional[str] = None
self._filters: List = list()
self._ids: Optional[List[str]] = None
self._ids: Optional[Tuple[str]] = None
self._slop: int = -1
self._timeout: Optional[float] = None
self._in_order: bool = False
Expand Down Expand Up @@ -81,7 +81,7 @@ def return_field(
self._return_fields += ("AS", as_field)
return self

def _mk_field_list(self, fields: List[str]) -> List:
def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List:
if not fields:
return []
return [fields] if isinstance(fields, str) else list(fields)
Expand Down Expand Up @@ -126,7 +126,7 @@ def summarize(

def highlight(
self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
) -> None:
) -> "Query":
"""
Apply specified markup to matched term(s) within the returned field(s).

Expand Down Expand Up @@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query":
self._scorer = scorer
return self

def get_args(self) -> List[str]:
def get_args(self) -> List[Union[str, int, float]]:
"""Format the redis arguments for this query and return them."""
args = [self._query_string]
args: List[Union[str, int, float]] = [self._query_string]
args += self._get_args_tags()
args += self._summarize_fields + self._highlight_fields
args += ["LIMIT", self._offset, self._num]
return args

def _get_args_tags(self) -> List[str]:
args = []
def _get_args_tags(self) -> List[Union[str, int, float]]:
args: List[Union[str, int, float]] = []
if self._no_content:
args.append("NOCONTENT")
if self._fields:
Expand Down Expand Up @@ -288,14 +288,14 @@ def with_scores(self) -> "Query":
self._with_scores = True
return self

def limit_fields(self, *fields: List[str]) -> "Query":
def limit_fields(self, *fields: str) -> "Query":
"""
Limit the search to specific TEXT fields only.

- **fields**: A list of strings, case sensitive field names
- **fields**: Each element should be a string, case sensitive field name
from the defined schema.
"""
self._fields = fields
self._fields = list(fields)
return self

def add_filter(self, flt: "Filter") -> "Query":
Expand Down Expand Up @@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query":


class Filter:
def __init__(self, keyword: str, field: str, *args: List[str]) -> None:
def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None:
self.args = [keyword, field] + list(args)


Expand Down
Loading