@@ -96,7 +96,74 @@ def _process_ihdp_sim_data():
96
96
return T , X
97
97
98
98
99
- class StandardDGP ():
99
+ class StandardDGP :
100
+ """
101
+ A class to generate synthetic causal datasets
102
+
103
+ Parameters
104
+ ----------
105
+ n: int
106
+ Number of observations to generate
107
+
108
+ d_t: int
109
+ Dimensionality of treatment
110
+
111
+ d_y: int
112
+ Dimensionality of outcome
113
+
114
+ d_x: int
115
+ Dimensionality of features
116
+
117
+ d_z: int
118
+ Dimensionality of instrument
119
+
120
+ discrete_treatment: bool
121
+ Whether the treatment is discrete
122
+
123
+ discrete_instrument: bool
124
+ Whether the instrument is discrete
125
+
126
+ squeeze_T: bool
127
+ Whether to squeeze the final T array on output
128
+
129
+ squeeze_Y: bool
130
+ Whether to squeeze the final Y array on output
131
+
132
+ nuisance_Y: func or dict
133
+ Nuisance function. Describes how the covariates affect the outcome.
134
+ If a function, this function will be used on features X to partially generate Y.
135
+ If a dict, must include 'support' and 'degree' keys.
136
+
137
+ nuisance_T: func or dict
138
+ Nuisance function. Describes how the covariates affect the treatment.
139
+ If a function, this function will be used on features X to partially generate T.
140
+ If a dict, must include 'support' and 'degree' keys.
141
+
142
+ nuisance_TZ: func or dict
143
+ Nuisance function. Describes how the instrument affects the treatment.
144
+ If a function, this function will be used on instrument Z to partially generate T.
145
+ If a dict, must include 'support' and 'degree' keys.
146
+
147
+ theta: func or dict
148
+ Describes how the features affect the treatment effect heterogeneity.
149
+ If a function, this function will be used on features X to calculate treatment effect heterogeneity.
150
+ If a dict, must include 'support' and 'degree' keys.
151
+
152
+ y_of_t: func or dict
153
+ Describes how the treatment affects the outcome.
154
+ If a function, this function will be used directly.
155
+ If a dict, must include 'support' and 'degree' keys.
156
+
157
+ x_eps: float
158
+ Noise parameter for feature generation
159
+
160
+ y_eps: float
161
+ Noise parameter for outcome generation
162
+
163
+ t_eps: float
164
+ Noise parameter for treatment generation
165
+
166
+ """
100
167
def __init__ (self ,
101
168
n = 1000 ,
102
169
d_t = 1 ,
@@ -114,7 +181,8 @@ def __init__(self,
114
181
y_of_t = None ,
115
182
x_eps = 1 ,
116
183
y_eps = 1 ,
117
- t_eps = 1
184
+ t_eps = 1 ,
185
+ random_state = None
118
186
):
119
187
self .n = n
120
188
self .d_t = d_t
@@ -132,15 +200,15 @@ def __init__(self,
132
200
else : # else must be dict
133
201
if nuisance_Y is None :
134
202
nuisance_Y = {'support' : self .d_x , 'degree' : 1 }
135
- nuisance_Y [ 'k' ] = self . d_x
203
+ assert isinstance ( nuisance_Y , dict ), f"nuisance_Y must be a callable or dict, but got { type ( nuisance_Y ) } "
136
204
self .nuisance_Y , self .nuisance_Y_coefs = self .gen_nuisance (** nuisance_Y )
137
205
138
206
if callable (nuisance_T ):
139
207
self .nuisance_T = nuisance_T
140
208
else : # else must be dict
141
209
if nuisance_T is None :
142
210
nuisance_T = {'support' : self .d_x , 'degree' : 1 }
143
- nuisance_T [ 'k' ] = self . d_x
211
+ assert isinstance ( nuisance_T , dict ), f"nuisance_T must be a callable or dict, but got { type ( nuisance_T ) } "
144
212
self .nuisance_T , self .nuisance_T_coefs = self .gen_nuisance (** nuisance_T )
145
213
146
214
if self .d_z :
@@ -149,7 +217,9 @@ def __init__(self,
149
217
else : # else must be dict
150
218
if nuisance_TZ is None :
151
219
nuisance_TZ = {'support' : self .d_z , 'degree' : 1 }
152
- nuisance_TZ ['k' ] = self .d_z
220
+ assert isinstance (
221
+ nuisance_TZ , dict ), f"nuisance_TZ must be a callable or dict, but got { type (nuisance_TZ )} "
222
+ nuisance_TZ = {** nuisance_TZ , 'k' : self .d_z }
153
223
self .nuisance_TZ , self .nuisance_TZ_coefs = self .gen_nuisance (** nuisance_TZ )
154
224
else :
155
225
self .nuisance_TZ = lambda x : 0
@@ -159,14 +229,15 @@ def __init__(self,
159
229
else : # else must be dict
160
230
if theta is None :
161
231
theta = {'support' : self .d_x , 'degree' : 1 , 'bounds' : [1 , 2 ], 'intercept' : True }
162
- theta [ 'k' ] = self . d_x
232
+ assert isinstance ( theta , dict ), f"theta must be a callable or dict, but got { type ( theta ) } "
163
233
self .theta , self .theta_coefs = self .gen_nuisance (** theta )
164
234
165
235
if callable (y_of_t ):
166
236
self .y_of_t = y_of_t
167
237
else : # else must be dict
168
238
if y_of_t is None :
169
239
y_of_t = {'support' : self .d_t , 'degree' : 1 , 'bounds' : [1 , 1 ]}
240
+ assert isinstance (y_of_t , dict ), f"y_of_t must be a callable or dict, but got { type (y_of_t )} "
170
241
y_of_t ['k' ] = self .d_t
171
242
self .y_of_t , self .y_of_t_coefs = self .gen_nuisance (** y_of_t )
172
243
@@ -199,9 +270,6 @@ def gen_T(self):
199
270
def gen_Z (self ):
200
271
if self .d_z :
201
272
if self .discrete_instrument :
202
- # prob_Z = expit(np.random.normal(size=(self.n, self.d_z)))
203
- # self.Z = np.random.binomial(1, prob_Z, size=(self.n, 1))
204
- # self.Z = np.random.binomial(1, prob_Z)
205
273
self .Z = np .random .binomial (1 , 0.5 , size = (self .n , self .d_z ))
206
274
return self .Z
207
275
@@ -224,7 +292,6 @@ def gen_nuisance(self, k=None, support=1, bounds=[-1, 1], degree=1, intercept=Fa
224
292
mask [supports ] = 1
225
293
coefs = coefs * mask
226
294
227
- # orders = np.random.randint(1, degree, k) if degree!=1 else np.ones(shape=(k,))
228
295
orders = np .ones (shape = (k ,)) * degree # enforce all to be the degree for now
229
296
230
297
if intercept :
0 commit comments