diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 00435f626b..0fb44ab54a 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Optional, Union from redis.commands.search.dialect import DEFAULT_DIALECT @@ -27,9 +27,15 @@ class Reducer: NAME = None def __init__(self, *args: str) -> None: +<<<<<<< fixing_mypy_errors + self._args: tuple[str, ...] = args + self._field: Optional[str] = None + self._alias: Optional[str] = None +======= self._args = args self._field = None self._alias = None +>>>>>>> master def alias(self, alias: str) -> "Reducer": """ @@ -49,13 +55,14 @@ def alias(self, alias: str) -> "Reducer": if alias is FIELDNAME: if not self._field: raise ValueError("Cannot use FIELDNAME alias with no field") - # Chop off initial '@' - alias = self._field[1:] + else: + # Chop off initial '@' + alias = self._field[1:] self._alias = alias return self @property - def args(self) -> List[str]: + def args(self) -> tuple[str, ...]: return self._args @@ -64,7 +71,7 @@ class SortDirection: This special class is used to indicate sort direction. """ - DIRSTRING = None + DIRSTRING: Optional[str] = None def __init__(self, field: str) -> None: self.field = field @@ -104,6 +111,19 @@ def __init__(self, query: str = "*") -> None: All member methods (except `build_args()`) return the object itself, making them useful for chaining. """ +<<<<<<< fixing_mypy_errors + self._query: str = query + self._aggregateplan: List[str] = [] + self._loadfields: List[str] = [] + self._loadall: bool = False + self._max: int = 0 + self._with_schema: bool = False + self._verbatim: bool = False + self._cursor: List[str] = [] + self._dialect: int = DEFAULT_DIALECT + self._add_scores: bool = False + self._scorer: str = "TFIDF" +======= self._query = query self._aggregateplan = [] self._loadfields = [] @@ -115,6 +135,7 @@ def __init__(self, query: str = "*") -> None: self._dialect = DEFAULT_DIALECT self._add_scores = False self._scorer = "TFIDF" +>>>>>>> master def load(self, *fields: str) -> "AggregateRequest": """ @@ -133,7 +154,7 @@ def load(self, *fields: str) -> "AggregateRequest": return self def group_by( - self, fields: List[str], *reducers: Union[Reducer, List[Reducer]] + self, fields: Union[str, List[str]], *reducers: Reducer ) -> "AggregateRequest": """ Specify by which fields to group the aggregation. @@ -147,7 +168,6 @@ def group_by( `aggregation` module. """ fields = [fields] if isinstance(fields, str) else fields - reducers = [reducers] if isinstance(reducers, Reducer) else reducers ret = ["GROUPBY", str(len(fields)), *fields] for reducer in reducers: @@ -251,12 +271,10 @@ def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest": .sort_by(Desc("@paid"), max=10) ``` """ - if isinstance(fields, (str, SortDirection)): - fields = [fields] fields_args = [] for f in fields: - if isinstance(f, SortDirection): + if isinstance(f, (Asc, Desc)): fields_args += [f.field, f.DIRSTRING] else: fields_args += [f] @@ -356,7 +374,7 @@ def build_args(self) -> List[str]: ret.extend(self._loadfields) if self._dialect: - ret.extend(["DIALECT", self._dialect]) + ret.extend(["DIALECT", str(self._dialect)]) ret.extend(self._aggregateplan) @@ -393,7 +411,7 @@ def __init__(self, rows, cursor: Cursor, schema) -> None: self.cursor = cursor self.schema = schema - def __repr__(self) -> (str, str): + def __repr__(self) -> str: cid = self.cursor.cid if self.cursor else -1 return ( f"<{self.__class__.__name__} at 0x{id(self):x} " diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 80d9b35728..a4d3c663a2 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -464,7 +464,7 @@ def info(self): return self._parse_results(INFO_CMD, res) def get_params_args( - self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None] + self, query_params: Optional[Dict[str, Union[str, int, float, bytes]]] ): if query_params is None: return [] @@ -543,7 +543,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa def aggregate( self, query: Union[AggregateRequest, Cursor], - query_params: Dict[str, Union[str, int, float]] = None, + query_params: Optional[Dict[str, Union[str, int, float]]] = None, ): """ Issue an aggregation query. diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index a8312a2ad2..615e6d10fa 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from redis.commands.search.dialect import DEFAULT_DIALECT @@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None: self._with_scores: bool = False self._scorer: Optional[str] = None self._filters: List = list() - self._ids: Optional[List[str]] = None + self._ids: Optional[Tuple[str]] = None self._slop: int = -1 self._timeout: Optional[float] = None self._in_order: bool = False @@ -81,7 +81,7 @@ def return_field( self._return_fields += ("AS", as_field) return self - def _mk_field_list(self, fields: List[str]) -> List: + def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List: if not fields: return [] return [fields] if isinstance(fields, str) else list(fields) @@ -126,7 +126,7 @@ def summarize( def highlight( self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None - ) -> None: + ) -> "Query": """ Apply specified markup to matched term(s) within the returned field(s). @@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query": self._scorer = scorer return self - def get_args(self) -> List[str]: + def get_args(self) -> List[Union[str, int, float]]: """Format the redis arguments for this query and return them.""" - args = [self._query_string] + args: List[Union[str, int, float]] = [self._query_string] args += self._get_args_tags() args += self._summarize_fields + self._highlight_fields args += ["LIMIT", self._offset, self._num] return args - def _get_args_tags(self) -> List[str]: - args = [] + def _get_args_tags(self) -> List[Union[str, int, float]]: + args: List[Union[str, int, float]] = [] if self._no_content: args.append("NOCONTENT") if self._fields: @@ -288,14 +288,14 @@ def with_scores(self) -> "Query": self._with_scores = True return self - def limit_fields(self, *fields: List[str]) -> "Query": + def limit_fields(self, *fields: str) -> "Query": """ Limit the search to specific TEXT fields only. - - **fields**: A list of strings, case sensitive field names + - **fields**: Each element should be a string, case sensitive field name from the defined schema. """ - self._fields = fields + self._fields = list(fields) return self def add_filter(self, flt: "Filter") -> "Query": @@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query": class Filter: - def __init__(self, keyword: str, field: str, *args: List[str]) -> None: + def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None: self.args = [keyword, field] + list(args)