Skip to content

Commit a0c49e2

Browse files
committed
Convert pre-treatment references to class instances
1 parent 09829c0 commit a0c49e2

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

pensieve/statistics.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,22 @@ def transform(self, df: DataFrame, metric: str) -> "StatisticResultCollection":
107107
@classmethod
108108
def from_dict(cls, config_dict: Dict[str, Any]):
109109
"""Create a class instance with the specified config parameters."""
110+
if "pre_treatments" in config_dict:
111+
# convert pre-treatments to class instances
112+
pre_treatments = []
113+
114+
for pre_treatment_name in config_dict["pre_treatments"]:
115+
found = False
116+
for pre_treatment in PreTreatment.__subclasses__():
117+
if pre_treatment.name() == pre_treatment_name:
118+
pre_treatments.append(pre_treatment)
119+
found = True
120+
121+
if not found:
122+
raise ValueError(f"Could not find pre-treatment {pre_treatment_name}")
123+
124+
config_dict["pre_treatments"] = pre_treatments
125+
110126
return cls(**config_dict)
111127

112128

pensieve/tests/integration/test_analysis_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def setup(self, client):
7777
client.query(source_file.read()).result()
7878
yield
7979

80-
# client.delete_dataset(self.test_dataset, delete_contents=True, not_found_ok=True)
80+
client.delete_dataset(self.test_dataset, delete_contents=True, not_found_ok=True)
8181

8282
def test_metrics(self, client):
8383
experiment = Experiment(

0 commit comments

Comments
 (0)