diff --git a/dask_match/expr.py b/dask_match/expr.py index 30a518612..90a5c1581 100644 --- a/dask_match/expr.py +++ b/dask_match/expr.py @@ -198,6 +198,30 @@ def index(self): def size(self): return Size(self) + def _statistics(self): + return {} + + def statistics(self) -> dict: + """Known quantities of an expression, like length or min/max + + To define this on a class create a `._statistics` method that returns a + dictionary of new statistics known by that class. If nothing is known it + is ok to return None. Superclasses will also be consulted. + + Examples + -------- + >>> df.statistics() + {"length": 1000000} + """ + out = {} + for typ in type(self).mro()[::-1]: + if not issubclass(typ, Expr): + continue + d = typ._statistics(self) # TODO: maybe this should be cached + if d: + out.update(d) # TODO: this is fragile + return out + def __getitem__(self, other): if isinstance(other, Expr): return Filter(self, other) # df[df.x > 1] @@ -468,7 +492,10 @@ class Elemwise(Blockwise): optimizations, like `len` will care about which operations preserve length """ - pass + def _statistics(self): + for dep in self.dependencies(): + if dep.npartitions == self.npartitions and "length" in dep.statistics(): + return {"length": dep.statistics()["length"]} class AsType(Elemwise): diff --git a/dask_match/io/io.py b/dask_match/io/io.py index 269c79313..65f39e668 100644 --- a/dask_match/io/io.py +++ b/dask_match/io/io.py @@ -58,4 +58,7 @@ def _task(self, index: int | None = None): def __str__(self): return "df" + def _statistics(self): + return {"length": len(self.frame)} + __repr__ = __str__ diff --git a/dask_match/tests/test_collection.py b/dask_match/tests/test_collection.py index 7606adaaa..ca9d8e053 100644 --- a/dask_match/tests/test_collection.py +++ b/dask_match/tests/test_collection.py @@ -288,3 +288,8 @@ def test_simple_graphs(df): graph = expr.__dask_graph__() assert graph[(expr._name, 0)] == (operator.add, (df.expr._name, 0), 1) + + +def test_statistics(df, pdf): + assert (df + 1).statistics()["length"] == len(pdf) + assert df[df.x > 5].statistics().get("length") is None