@@ -271,10 +271,8 @@ def _logp(
271
271
272
272
273
273
@_logp .register (Elemwise )
274
- def logp_elemwise (op , * args , ** kwargs ):
275
- if hasattr (op , "scalar_op" ):
276
- return _logp (op .scalar_op , * args , ** kwargs )
277
- raise NotImplementedError
274
+ def elemwise_logp (op , * args , ** kwargs ):
275
+ return _logp (op .scalar_op , * args , ** kwargs )
278
276
279
277
280
278
# TODO: Implement DimShuffle logp?
@@ -287,14 +285,17 @@ def logp_elemwise(op, *args, **kwargs):
287
285
# raise NotImplementedError
288
286
289
287
290
- def find_rv_branch ( inputs ):
291
- """
292
- Helper function to find which input branch(es) contain unregistered random variables
293
- """
294
- rv_branch = []
295
- no_rv_branch = []
288
+ @ _logp . register ( Add )
289
+ @ _logp . register ( Mul )
290
+ def linear_logp ( op , var , rvs_to_values , * linear_inputs , ** kwargs ):
291
+
292
+ if len ( linear_inputs ) != 2 :
293
+ raise ValueError ( f"Expected 2 inputs but got: { len ( linear_inputs ) } " )
296
294
297
- for inp in inputs :
295
+ # Find base_rv and constant inputs
296
+ base_rv = []
297
+ constant = []
298
+ for inp in linear_inputs :
298
299
res_ancestors = list (walk_model ((inp ,), walk_past_rvs = True ))
299
300
# unregistered variables do not contain a value_var tag
300
301
res_unregistered_ancestors = [
@@ -305,94 +306,47 @@ def find_rv_branch(inputs):
305
306
and not getattr (v .tag , "value_var" , False )
306
307
]
307
308
if res_unregistered_ancestors :
308
- rv_branch .append (inp )
309
+ base_rv .append (inp )
309
310
else :
310
- no_rv_branch .append (inp )
311
-
312
- return rv_branch , no_rv_branch
313
-
314
-
315
- @_logp .register (Add )
316
- def add_logp (op , var , rvs_to_values , * add_inputs , ** kwargs ):
317
-
318
- if len (add_inputs ) != 2 :
319
- raise ValueError (f"Expected 2 inputs but got: { len (add_inputs )} " )
320
-
321
- base_rv , loc = find_rv_branch (add_inputs )
311
+ constant .append (inp )
322
312
323
313
if len (base_rv ) != 1 :
324
314
raise NotImplementedError (
325
- f"Logp of addition requires one branch with an unregistered RandomVariable but got { len (base_rv )} "
315
+ f"Logp of linear transform requires one branch with an unregistered RandomVariable but got { len (base_rv )} "
326
316
)
327
317
328
- var_value = rvs_to_values .get (var , var )
329
- loc = loc [0 ]
330
318
base_rv = base_rv [0 ]
331
- base_value = base_rv .type ()
332
-
333
- logp_base_rv = logpt (base_rv , {base_rv : base_value }, ** kwargs )
334
- fgraph = FunctionGraph (
335
- [i for i in graph_inputs ((logp_base_rv ,)) if not isinstance (i , Constant )],
336
- [logp_base_rv ],
337
- clone = False ,
338
- )
339
- fgraph .replace (base_value , var_value - loc , import_missing = True )
340
- logp_add_rv = fgraph .outputs [0 ]
341
-
342
- # Replace rvs in graph
343
- # TODO: This shouldn't be here
344
- (logp_add_rv ,), _ = rvs_to_value_vars (
345
- (logp_add_rv ,),
346
- apply_transforms = True , # Change this
347
- initial_replacements = None ,
348
- )
349
-
350
- logp_add_rv .name = f"__logp_{ var .name } "
351
-
352
- return logp_add_rv
353
-
354
-
355
- @_logp .register (Mul )
356
- def mul_logp (op , var , rvs_to_values , * mul_inputs , ** kwargs ):
357
-
358
- if len (mul_inputs ) != 2 :
359
- raise ValueError (f"Expected 2 inputs but got: { len (mul_inputs )} " )
360
-
361
- base_rv , scale = find_rv_branch (mul_inputs )
362
-
363
- if len (base_rv ) != 1 :
364
- raise NotImplementedError (
365
- f"Logp of product requires one branch with an unregistered RandomVariable but got { len (base_rv )} "
366
- )
367
-
319
+ constant = constant [0 ]
368
320
var_value = rvs_to_values .get (var , var )
369
- scale = scale [0 ]
370
- base_rv = base_rv [0 ]
371
- base_value = base_rv .type ()
372
321
322
+ # Get logp of base_rv
323
+ base_value = base_rv .type ()
373
324
logp_base_rv = logpt (base_rv , {base_rv : base_value }, ** kwargs )
374
325
fgraph = FunctionGraph (
375
326
[i for i in graph_inputs ((logp_base_rv ,)) if not isinstance (i , Constant )],
376
- [logp_base_rv ],
327
+ outputs = [logp_base_rv ],
377
328
clone = False ,
378
329
)
379
330
380
- # TODO: This is not correct for discrete variables
381
- # TODO: Undefined behavior for scale = 0
382
- fgraph .replace (base_value , var_value / scale , import_missing = True )
383
- logp_mul_rv = fgraph .outputs [0 ] - at .log (at .abs_ (scale ))
331
+ # Transform base_rv and apply jacobian correction (for continuous rvs)
332
+ if isinstance (op , Add ):
333
+ fgraph .replace (base_value , var_value - constant , import_missing = True )
334
+ logp_linear_rv = fgraph .outputs [0 ]
335
+ elif isinstance (op , Mul ):
336
+ fgraph .replace (base_value , var_value / constant , import_missing = True )
337
+ logp_linear_rv = fgraph .outputs [0 ]
338
+ if "float" in base_rv .dtype :
339
+ logp_linear_rv -= at .log (at .abs_ (constant ))
384
340
385
341
# Replace rvs in graph
386
- # TODO: This shouldn't be here
387
- (logp_mul_rv ,), _ = rvs_to_value_vars (
388
- (logp_mul_rv ,),
389
- apply_transforms = True , # Change this
342
+ (logp_linear_rv ,), _ = rvs_to_value_vars (
343
+ (logp_linear_rv ,),
344
+ apply_transforms = kwargs .get ("transformed" , True ),
390
345
initial_replacements = None ,
391
346
)
392
347
393
- logp_mul_rv .name = f"__logp_{ var .name } "
394
-
395
- return logp_mul_rv
348
+ logp_linear_rv .name = f"__logp_{ var .name } "
349
+ return logp_linear_rv
396
350
397
351
398
352
def convert_indices (indices , entry ):
0 commit comments