Skip to content

Commit 5961cbf

Browse files
committed
add docstring, some arg checks
1 parent 2d8a94d commit 5961cbf

File tree

1 file changed

+77
-10
lines changed

1 file changed

+77
-10
lines changed

econml/data/dgps.py

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,74 @@ def _process_ihdp_sim_data():
9696
return T, X
9797

9898

99-
class StandardDGP():
99+
class StandardDGP:
100+
"""
101+
A class to generate synthetic causal datasets
102+
103+
Parameters
104+
----------
105+
n: int
106+
Number of observations to generate
107+
108+
d_t: int
109+
Dimensionality of treatment
110+
111+
d_y: int
112+
Dimensionality of outcome
113+
114+
d_x: int
115+
Dimensionality of features
116+
117+
d_z: int
118+
Dimensionality of instrument
119+
120+
discrete_treatment: bool
121+
Dimensionality of treatment
122+
123+
discrete_isntrument: bool
124+
Dimensionality of instrument
125+
126+
squeeze_T: bool
127+
Whether to squeeze the final T array on output
128+
129+
squeeze_Y: bool
130+
Whether to squeeze the final Y array on output
131+
132+
nuisance_Y: func or dict
133+
Nuisance function. Describes how the covariates affect the outcome.
134+
If a function, this function will be used on features X to partially generate Y.
135+
If a dict, must include 'support' and 'degree' keys.
136+
137+
nuisance_T: func or dict
138+
Nuisance function. Describes how the covariates affect the treatment.
139+
If a function, this function will be used on features X to partially generate T.
140+
If a dict, must include 'support' and 'degree' keys.
141+
142+
nuisance_TZ: func or dict
143+
Nuisance function. Describes how the instrument affects the treatment.
144+
If a function, this function will be used on instrument Z to partially generate T.
145+
If a dict, must include 'support' and 'degree' keys.
146+
147+
theta: func or dict
148+
Describes how the features affects the treatment effect heterogenity.
149+
If a function, this function will be used on features X to calculate treatment effect heterogenity.
150+
If a dict, must include 'support' and 'degree' keys.
151+
152+
y_of_t: func or dict
153+
Describes how the treatment affects the outcome.
154+
If a function, this function will be used directly.
155+
If a dict, must include 'support' and 'degree' keys.
156+
157+
x_eps: float
158+
Noise parameter for feature generation
159+
160+
y_eps: func or dict
161+
Noise parameter for outcome generation
162+
163+
t_eps: func or dict
164+
Noise parameter for treatment generation
165+
166+
"""
100167
def __init__(self,
101168
n=1000,
102169
d_t=1,
@@ -114,7 +181,8 @@ def __init__(self,
114181
y_of_t=None,
115182
x_eps=1,
116183
y_eps=1,
117-
t_eps=1
184+
t_eps=1,
185+
random_state=None
118186
):
119187
self.n = n
120188
self.d_t = d_t
@@ -132,15 +200,15 @@ def __init__(self,
132200
else: # else must be dict
133201
if nuisance_Y is None:
134202
nuisance_Y = {'support': self.d_x, 'degree': 1}
135-
nuisance_Y['k'] = self.d_x
203+
assert isinstance(nuisance_Y, dict), f"nuisance_Y must be a callable or dict, but got {type(nuisance_Y)}"
136204
self.nuisance_Y, self.nuisance_Y_coefs = self.gen_nuisance(**nuisance_Y)
137205

138206
if callable(nuisance_T):
139207
self.nuisance_T = nuisance_T
140208
else: # else must be dict
141209
if nuisance_T is None:
142210
nuisance_T = {'support': self.d_x, 'degree': 1}
143-
nuisance_T['k'] = self.d_x
211+
assert isinstance(nuisance_T, dict), f"nuisance_T must be a callable or dict, but got {type(nuisance_T)}"
144212
self.nuisance_T, self.nuisance_T_coefs = self.gen_nuisance(**nuisance_T)
145213

146214
if self.d_z:
@@ -149,7 +217,9 @@ def __init__(self,
149217
else: # else must be dict
150218
if nuisance_TZ is None:
151219
nuisance_TZ = {'support': self.d_z, 'degree': 1}
152-
nuisance_TZ['k'] = self.d_z
220+
assert isinstance(
221+
nuisance_TZ, dict), f"nuisance_TZ must be a callable or dict, but got {type(nuisance_TZ)}"
222+
nuisance_TZ = {**nuisance_TZ, 'k': self.d_z}
153223
self.nuisance_TZ, self.nuisance_TZ_coefs = self.gen_nuisance(**nuisance_TZ)
154224
else:
155225
self.nuisance_TZ = lambda x: 0
@@ -159,14 +229,15 @@ def __init__(self,
159229
else: # else must be dict
160230
if theta is None:
161231
theta = {'support': self.d_x, 'degree': 1, 'bounds': [1, 2], 'intercept': True}
162-
theta['k'] = self.d_x
232+
assert isinstance(theta, dict), f"theta must be a callable or dict, but got {type(theta)}"
163233
self.theta, self.theta_coefs = self.gen_nuisance(**theta)
164234

165235
if callable(y_of_t):
166236
self.y_of_t = y_of_t
167237
else: # else must be dict
168238
if y_of_t is None:
169239
y_of_t = {'support': self.d_t, 'degree': 1, 'bounds': [1, 1]}
240+
assert isinstance(y_of_t, dict), f"y_of_t must be a callable or dict, but got {type(y_of_t)}"
170241
y_of_t['k'] = self.d_t
171242
self.y_of_t, self.y_of_t_coefs = self.gen_nuisance(**y_of_t)
172243

@@ -199,9 +270,6 @@ def gen_T(self):
199270
def gen_Z(self):
200271
if self.d_z:
201272
if self.discrete_instrument:
202-
# prob_Z = expit(np.random.normal(size=(self.n, self.d_z)))
203-
# self.Z = np.random.binomial(1, prob_Z, size=(self.n, 1))
204-
# self.Z = np.random.binomial(1, prob_Z)
205273
self.Z = np.random.binomial(1, 0.5, size=(self.n, self.d_z))
206274
return self.Z
207275

@@ -224,7 +292,6 @@ def gen_nuisance(self, k=None, support=1, bounds=[-1, 1], degree=1, intercept=Fa
224292
mask[supports] = 1
225293
coefs = coefs * mask
226294

227-
# orders = np.random.randint(1, degree, k) if degree!=1 else np.ones(shape=(k,))
228295
orders = np.ones(shape=(k,)) * degree # enforce all to be the degree for now
229296

230297
if intercept:

0 commit comments

Comments
 (0)