From f768489a9d3ab1f0462eab9adc1e848249c6b52c Mon Sep 17 00:00:00 2001 From: Jialei Chen Date: Mon, 25 Aug 2025 14:54:57 -0700 Subject: [PATCH 1/2] Refactor _jax_forward and _jax_backward functions to avoid cache collisions --- torchax/torchax/interop.py | 62 ++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index a87efe9dfe7..37e766e2302 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -237,6 +237,38 @@ def j2t_autograd(fn, call_jax=call_jax): the PyTorch autograd framework by saving the residuals into the context object. """ + # NOTE(qihqi): This function cannot be inlined from the callsite + # Because if it does, then it won't hit the compilation cache for + # call_jax. Call jax uses functions' id as key. + # It is nested inside j2t_autograd to ensure it gets a unique ID for each + # wrapped pure function, preventing cache collisions between different pure modules. + def _jax_forward(fn, other, tree_def, tensors): + """JAX function to compute output and vjp function. + + primals should be a tuple (args, kwargs). + """ + import jax + from jax.tree_util import tree_flatten, tree_unflatten + + def fn_wrapper(*tensors): + # Reconstruct the original args and kwargs + flat_inputs = util.merge(tensors, other) + args, kwargs = tree_unflatten(tree_def, flat_inputs) + return fn(*args, **kwargs) + + return jax.vjp(fn_wrapper, *tensors) + + + def _jax_backward(vjp_spec, saved_tensors, grad_out): + """JAX function to compute input gradients. + + Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. 
+ """ + from jax.tree_util import tree_unflatten + fun_vjp = tree_unflatten(vjp_spec, saved_tensors) + return fun_vjp(grad_out) + + @wraps(fn) def inner(*args, **kwargs): from jax.tree_util import tree_flatten @@ -290,36 +322,6 @@ def backward(ctx, *grad_out): return inner -# NOTE(qihqi): This function cannot be inlined from the callsite -# Becuase if it does, then it won't hit the compilation cache for -# call_jax. Call jax uses functions' id as key. -def _jax_forward(fn, other, tree_def, tensors): - """JAX function to compute output and vjp function. - - primals should be a tuple (args, kwargs). - """ - import jax - from jax.tree_util import tree_flatten, tree_unflatten - - def fn_wrapper(*tensors): - # Reconstruct the original args and kwargs - flat_inputs = util.merge(tensors, other) - args, kwargs = tree_unflatten(tree_def, flat_inputs) - return fn(*args, **kwargs) - - return jax.vjp(fn_wrapper, *tensors) - - -def _jax_backward(vjp_spec, saved_tensors, grad_out): - """JAX function to compute input gradients. - - Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. 
- """ - from jax.tree_util import tree_unflatten - fun_vjp = tree_unflatten(vjp_spec, saved_tensors) - return fun_vjp(grad_out) - - fori_loop = torch_view(jax.lax.fori_loop) From 2103e2c36ba921376168e7a6e842166417e135ae Mon Sep 17 00:00:00 2001 From: Jialei Chen Date: Mon, 25 Aug 2025 22:40:50 +0000 Subject: [PATCH 2/2] format --- torchax/torchax/interop.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index 37e766e2302..34ab79b1083 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -249,16 +249,15 @@ def _jax_forward(fn, other, tree_def, tensors): """ import jax from jax.tree_util import tree_flatten, tree_unflatten - + def fn_wrapper(*tensors): # Reconstruct the original args and kwargs flat_inputs = util.merge(tensors, other) args, kwargs = tree_unflatten(tree_def, flat_inputs) return fn(*args, **kwargs) - + return jax.vjp(fn_wrapper, *tensors) - - + def _jax_backward(vjp_spec, saved_tensors, grad_out): """JAX function to compute input gradients. @@ -268,7 +267,6 @@ def _jax_backward(vjp_spec, saved_tensors, grad_out): fun_vjp = tree_unflatten(vjp_spec, saved_tensors) return fun_vjp(grad_out) - @wraps(fn) def inner(*args, **kwargs): from jax.tree_util import tree_flatten