Skip to content

Commit 1caae1d

Browse files
typing of aggregation responses
1 parent 1d532b5 commit 1caae1d

File tree

7 files changed

+1904
-560
lines changed

7 files changed

+1904
-560
lines changed

elasticsearch_dsl/response/__init__.py

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Optional,
2727
Sequence,
2828
Tuple,
29+
TypedDict,
2930
Union,
3031
cast,
3132
)
@@ -195,21 +196,100 @@ def search_after(self) -> "SearchBase[_R]":
195196
return self._search.extra(search_after=self.hits[-1].meta.sort) # type: ignore
196197

197198

199+
_Aggregate = Union[
200+
"types.CardinalityAggregate",
201+
"types.HdrPercentilesAggregate",
202+
"types.HdrPercentileRanksAggregate",
203+
"types.TDigestPercentilesAggregate",
204+
"types.TDigestPercentileRanksAggregate",
205+
"types.PercentilesBucketAggregate",
206+
"types.MedianAbsoluteDeviationAggregate",
207+
"types.MinAggregate",
208+
"types.MaxAggregate",
209+
"types.SumAggregate",
210+
"types.AvgAggregate",
211+
"types.WeightedAvgAggregate",
212+
"types.ValueCountAggregate",
213+
"types.SimpleValueAggregate",
214+
"types.DerivativeAggregate",
215+
"types.BucketMetricValueAggregate",
216+
"types.StatsAggregate",
217+
"types.StatsBucketAggregate",
218+
"types.ExtendedStatsAggregate",
219+
"types.ExtendedStatsBucketAggregate",
220+
"types.GeoBoundsAggregate",
221+
"types.GeoCentroidAggregate",
222+
"types.HistogramAggregate",
223+
"types.DateHistogramAggregate",
224+
"types.AutoDateHistogramAggregate",
225+
"types.VariableWidthHistogramAggregate",
226+
"types.StringTermsAggregate",
227+
"types.LongTermsAggregate",
228+
"types.DoubleTermsAggregate",
229+
"types.UnmappedTermsAggregate",
230+
"types.LongRareTermsAggregate",
231+
"types.StringRareTermsAggregate",
232+
"types.UnmappedRareTermsAggregate",
233+
"types.MultiTermsAggregate",
234+
"types.MissingAggregate",
235+
"types.NestedAggregate",
236+
"types.ReverseNestedAggregate",
237+
"types.GlobalAggregate",
238+
"types.FilterAggregate",
239+
"types.ChildrenAggregate",
240+
"types.ParentAggregate",
241+
"types.SamplerAggregate",
242+
"types.UnmappedSamplerAggregate",
243+
"types.GeoHashGridAggregate",
244+
"types.GeoTileGridAggregate",
245+
"types.GeoHexGridAggregate",
246+
"types.RangeAggregate",
247+
"types.DateRangeAggregate",
248+
"types.GeoDistanceAggregate",
249+
"types.IpRangeAggregate",
250+
"types.IpPrefixAggregate",
251+
"types.FiltersAggregate",
252+
"types.AdjacencyMatrixAggregate",
253+
"types.SignificantLongTermsAggregate",
254+
"types.SignificantStringTermsAggregate",
255+
"types.UnmappedSignificantTermsAggregate",
256+
"types.CompositeAggregate",
257+
"types.FrequentItemSetsAggregate",
258+
"types.TimeSeriesAggregate",
259+
"types.ScriptedMetricAggregate",
260+
"types.TopHitsAggregate",
261+
"types.InferenceAggregate",
262+
"types.StringStatsAggregate",
263+
"types.BoxPlotAggregate",
264+
"types.TopMetricsAggregate",
265+
"types.TTestAggregate",
266+
"types.RateAggregate",
267+
"types.CumulativeCardinalityAggregate",
268+
"types.MatrixStatsAggregate",
269+
"types.GeoLineAggregate",
270+
]
271+
_AggResponseMeta = TypedDict(
272+
"_AggResponseMeta", {"search": "Request[_R]", "aggs": Mapping[str, _Aggregate]}
273+
)
274+
275+
198276
class AggResponse(AttrDict[Any], Generic[_R]):
199277
_meta: Dict[str, Any]
200278

201279
def __init__(self, aggs: "Agg[_R]", search: "Request[_R]", data: Dict[str, Any]):
202280
super(AttrDict, self).__setattr__("_meta", {"search": search, "aggs": aggs})
203281
super().__init__(data)
204282

205-
def __getitem__(self, attr_name: str) -> Any:
283+
def __getitem__(self, attr_name: str) -> _Aggregate:
206284
if attr_name in self._meta["aggs"]:
207285
# don't do self._meta['aggs'][attr_name] to avoid copying
208286
agg = self._meta["aggs"].aggs[attr_name]
209-
return agg.result(self._meta["search"], self._d_[attr_name])
210-
return super().__getitem__(attr_name)
287+
return cast(
288+
_Aggregate, agg.result(self._meta["search"], self._d_[attr_name])
289+
)
290+
return super().__getitem__(attr_name) # type: ignore
211291

212-
def __iter__(self) -> Iterator["Agg"]: # type: ignore[override]
292+
def __iter__(self) -> Iterator[_Aggregate]: # type: ignore[override]
213293
for name in self._meta["aggs"]:
214294
yield self[name]
215295

0 commit comments

Comments
 (0)