From 16c003f2d1c47973a86b218fd8552969406c9437 Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Tue, 16 Sep 2025 18:36:08 +0200 Subject: [PATCH 1/6] fix --- dataframely/columns/_base.py | 10 +++++++++- tests/schema/test_serialization.py | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index 027109d..59ab7ef 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -319,7 +319,7 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]: param: ( _check_to_expr(getattr(self, param), expr) if param == "check" - else getattr(self, param) + else _series_to_list(getattr(self, param)) ) for param in inspect.signature(self.__class__.__init__).parameters if param not in ("self", "alias") @@ -443,6 +443,8 @@ def _check_to_expr(check: Check | None, expr: pl.Expr) -> Any | None: return [c(expr) for c in check] case Mapping(): return {key: c(expr) for key, c in check.items()} + case pl.Series(): + return None case _ if callable(check): return check(expr) @@ -459,3 +461,9 @@ def _check_from_expr(value: Any) -> Check | None: return lambda _: value case _: # pragma: no cover raise ValueError(f"Invalid type for check: {type(value)}") + + +def _series_to_list(value: Any) -> Any: + if isinstance(value, pl.Series): + return value.to_list() + return value diff --git a/tests/schema/test_serialization.py b/tests/schema/test_serialization.py index 682be0b..46e1204 100644 --- a/tests/schema/test_serialization.py +++ b/tests/schema/test_serialization.py @@ -40,6 +40,7 @@ def test_simple_serialization() -> None: create_schema("test", {"a": dy.Int64(check=[lambda expr: expr > 5])}), create_schema("test", {"a": dy.Int64(check={"x": lambda expr: expr > 5})}), create_schema("test", {"a": dy.Int64(alias="foo")}), + create_schema("test", {"a": dy.Enum(["a"])}), create_schema("test", {"a": dy.Decimal(scale=2, min=Decimal("1.5"))}), create_schema("test", {"a": dy.Date(min=dt.date(2020, 1, 1))}), create_schema("test", {"a": dy.Datetime(min=dt.datetime(2020, 1, 1))}), From 24f48b67c34f2a5c78cf0d25380b96a20c229045 Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Tue, 16 Sep 2025 18:59:50 +0200 Subject: [PATCH 2/6] Fix repr --- dataframely/columns/_base.py | 5 +++-- tests/schema/test_repr.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index 59ab7ef..bc1d8ae 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -396,7 +396,7 @@ def _attributes_match( def __repr__(self) -> str: parts = [ - f"{attribute}={repr(getattr(self, attribute))}" + f"{attribute}={repr(_series_to_list(getattr(self, attribute)))}" for attribute, param_details in inspect.signature( self.__class__.__init__ ).parameters.items() @@ -404,7 +404,7 @@ def __repr__(self) -> str: not in ["self", "alias"] # alias is always equal to the column name here and not ( # Do not include attributes that are set to their default value - getattr(self, attribute) == param_details.default + _series_to_list(getattr(self, attribute)) == param_details.default ) ] return f"{self.__class__.__name__}({', '.join(parts)})" @@ -464,6 +464,7 @@ def _check_from_expr(value: Any) -> Check | None: def _series_to_list(value: Any) -> Any: + """If passed a `pl.Series` value, converts it to a list.""" if isinstance(value, pl.Series): return value.to_list() return value diff --git a/tests/schema/test_repr.py b/tests/schema/test_repr.py index fff11d7..63958da 100644 --- a/tests/schema/test_repr.py +++ b/tests/schema/test_repr.py @@ -52,3 +52,14 @@ def test_repr_with_rules() -> None: - "my_rule": [(col("a")) < (dyn int: 100)] - "my_group_rule": [(col("a").sum()) > (dyn int: 50)] grouped by ['a'] """) + + +def test_repr_enum() -> None: + class SchemaNoRules(dy.Schema): + a = dy.Enum(["a"]) + + assert repr(SchemaNoRules) == textwrap.dedent("""\ + [Schema "SchemaNoRules"] + Columns: + - "a": Enum(categories=['a'], nullable=True) + """) From 4a9176319310c65b790401a233061393c7ca4fa7 Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Tue, 16 Sep 2025 19:03:04 +0200 Subject: [PATCH 3/6] Cleanup --- dataframely/columns/_base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index bc1d8ae..ee0ef4a 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -443,8 +443,6 @@ def _check_to_expr(check: Check | None, expr: pl.Expr) -> Any | None: return [c(expr) for c in check] case Mapping(): return {key: c(expr) for key, c in check.items()} - case pl.Series(): - return None case _ if callable(check): return check(expr) From 2e68141a66deed5d20608b38dbfc9284441eb499 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Tue, 23 Sep 2025 10:11:11 +0200 Subject: [PATCH 4/6] refactor: Store Enum categories as `list` instead of `pl.Series` --- dataframely/columns/_base.py | 13 ------------- dataframely/columns/enum.py | 12 ++++-------- docs/_api/dataframely.columns.rst | 16 ++++++++++++++++ docs/_api/dataframely.testing.rst | 8 ++++++++ 4 files changed, 28 insertions(+), 21 deletions(-) diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index 027109d..c8fb52e 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -381,15 +381,6 @@ def _attributes_match( if name == "check": return _compare_checks(lhs, rhs, column_expr) - lhs_is_series = isinstance(lhs, pl.Series) - rhs_is_series = isinstance(rhs, pl.Series) - - if lhs_is_series != rhs_is_series: - return False - - if lhs_is_series and rhs_is_series: - return _compare_series(lhs, rhs) - return lhs == rhs # -------------------------------- DUNDER METHODS -------------------------------- # @@ -413,10 +404,6 @@ def __str__(self) -> str: return self.__class__.__name__.lower() -def _compare_series(lhs: pl.Series, rhs: pl.Series) -> bool: - return (len(lhs) == len(rhs)) and lhs.equals(rhs) - - def _compare_checks(lhs: Check | None, rhs: Check | None, expr: pl.Expr) -> bool: match (lhs, rhs): case (None, None): diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index c4b4f9e..665e6ad 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -67,12 +67,8 @@ def __init__( metadata=metadata, ) if isclass(categories) and issubclass(categories, enum.Enum): - categories = pl.Series( - values=[getattr(v, "value", v) for v in categories.__members__.values()] - ) - elif not isinstance(categories, pl.Series): - categories = pl.Series(values=categories) - self.categories = categories + categories = (item.value for item in categories) + self.categories = list(categories) @property def dtype(self) -> pl.DataType: @@ -81,7 +77,7 @@ def dtype(self) -> pl.DataType: def validate_dtype(self, dtype: PolarsDataType) -> bool: if not isinstance(dtype, pl.Enum): return False - return self.categories.equals(dtype.categories) + return self.categories == dtype.categories.to_list() def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: category_lengths = [len(c) for c in self.categories] @@ -102,6 +98,6 @@ def pyarrow_dtype(self) -> pa.DataType: def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_choice( n, - choices=self.categories.to_list(), + choices=self.categories, null_probability=self._null_probability, ).cast(self.dtype) diff --git a/docs/_api/dataframely.columns.rst b/docs/_api/dataframely.columns.rst index e6fa538..01f0752 100644 --- a/docs/_api/dataframely.columns.rst +++ b/docs/_api/dataframely.columns.rst @@ -25,6 +25,14 @@ dataframely.columns.array module :show-inheritance: :undoc-members: +dataframely.columns.binary module +--------------------------------- + +.. automodule:: dataframely.columns.binary + :members: + :show-inheritance: + :undoc-members: + dataframely.columns.bool module ------------------------------- @@ -33,6 +41,14 @@ dataframely.columns.bool module :show-inheritance: :undoc-members: +dataframely.columns.categorical module +-------------------------------------- + +.. automodule:: dataframely.columns.categorical + :members: + :show-inheritance: + :undoc-members: + dataframely.columns.datetime module ----------------------------------- diff --git a/docs/_api/dataframely.testing.rst b/docs/_api/dataframely.testing.rst index 4ed627d..8c9509c 100644 --- a/docs/_api/dataframely.testing.rst +++ b/docs/_api/dataframely.testing.rst @@ -41,6 +41,14 @@ dataframely.testing.rules module :show-inheritance: :undoc-members: +dataframely.testing.storage module +---------------------------------- + +.. automodule:: dataframely.testing.storage + :members: + :show-inheritance: + :undoc-members: + dataframely.testing.typing module --------------------------------- From 69a52b1210229c3b8565bc4cb1e5e44e3339c527 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Tue, 23 Sep 2025 10:14:18 +0200 Subject: [PATCH 5/6] roll back accidental change --- docs/_api/dataframely.columns.rst | 16 ---------------- docs/_api/dataframely.testing.rst | 8 -------- 2 files changed, 24 deletions(-) diff --git a/docs/_api/dataframely.columns.rst b/docs/_api/dataframely.columns.rst index 01f0752..e6fa538 100644 --- a/docs/_api/dataframely.columns.rst +++ b/docs/_api/dataframely.columns.rst @@ -25,14 +25,6 @@ dataframely.columns.array module :show-inheritance: :undoc-members: -dataframely.columns.binary module ---------------------------------- - -.. automodule:: dataframely.columns.binary - :members: - :show-inheritance: - :undoc-members: - dataframely.columns.bool module ------------------------------- @@ -41,14 +33,6 @@ dataframely.columns.bool module :show-inheritance: :undoc-members: -dataframely.columns.categorical module --------------------------------------- - -.. automodule:: dataframely.columns.categorical - :members: - :show-inheritance: - :undoc-members: - dataframely.columns.datetime module ----------------------------------- diff --git a/docs/_api/dataframely.testing.rst b/docs/_api/dataframely.testing.rst index 8c9509c..4ed627d 100644 --- a/docs/_api/dataframely.testing.rst +++ b/docs/_api/dataframely.testing.rst @@ -41,14 +41,6 @@ dataframely.testing.rules module :show-inheritance: :undoc-members: -dataframely.testing.storage module ----------------------------------- - -.. automodule:: dataframely.testing.storage - :members: - :show-inheritance: - :undoc-members: - dataframely.testing.typing module --------------------------------- From bb4fc17c7b65291681f806473aa3331e82e034f2 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Tue, 23 Sep 2025 10:23:11 +0200 Subject: [PATCH 6/6] remove list conversion --- dataframely/columns/_base.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index e761a7c..c8fb52e 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -319,7 +319,7 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]: param: ( _check_to_expr(getattr(self, param), expr) if param == "check" - else _series_to_list(getattr(self, param)) + else getattr(self, param) ) for param in inspect.signature(self.__class__.__init__).parameters if param not in ("self", "alias") @@ -387,7 +387,7 @@ def _attributes_match( def __repr__(self) -> str: parts = [ - f"{attribute}={repr(_series_to_list(getattr(self, attribute)))}" + f"{attribute}={repr(getattr(self, attribute))}" for attribute, param_details in inspect.signature( self.__class__.__init__ ).parameters.items() @@ -395,7 +395,7 @@ def __repr__(self) -> str: not in ["self", "alias"] # alias is always equal to the column name here and not ( # Do not include attributes that are set to their default value - _series_to_list(getattr(self, attribute)) == param_details.default + getattr(self, attribute) == param_details.default ) ] return f"{self.__class__.__name__}({', '.join(parts)})" @@ -446,10 +446,3 @@ def _check_from_expr(value: Any) -> Check | None: return lambda _: value case _: # pragma: no cover raise ValueError(f"Invalid type for check: {type(value)}") - - -def _series_to_list(value: Any) -> Any: - """If passed a `pl.Series` value, converts it to a list.""" - if isinstance(value, pl.Series): - return value.to_list() - return value