60 changes: 30 additions & 30 deletions torchax/torchax/interop.py
@@ -237,6 +237,36 @@ def j2t_autograd(fn, call_jax=call_jax):
  the PyTorch autograd framework by saving the residuals into the context object.
  """

  # NOTE(qihqi): This function cannot be inlined at the callsite,
  # because if it were, it would not hit the compilation cache for
  # call_jax; call_jax uses the function's id as the cache key.
  # It is nested inside j2t_autograd so that each wrapped pure function gets
  # a unique id, preventing cache collisions between different pure modules.
  def _jax_forward(fn, other, tree_def, tensors):
    """JAX function to compute output and vjp function.

    primals should be a tuple (args, kwargs).
    """
    import jax
    from jax.tree_util import tree_flatten, tree_unflatten

    def fn_wrapper(*tensors):
      # Reconstruct the original args and kwargs
      flat_inputs = util.merge(tensors, other)
      args, kwargs = tree_unflatten(tree_def, flat_inputs)
      return fn(*args, **kwargs)

    return jax.vjp(fn_wrapper, *tensors)

  def _jax_backward(vjp_spec, saved_tensors, grad_out):
    """JAX function to compute input gradients.

    Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
    """
    from jax.tree_util import tree_unflatten
    fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
    return fun_vjp(grad_out)

  @wraps(fn)
  def inner(*args, **kwargs):
    from jax.tree_util import tree_flatten
@@ -290,36 +320,6 @@ def backward(ctx, *grad_out):
  return inner


# NOTE(qihqi): This function cannot be inlined from the callsite
# Becuase if it does, then it won't hit the compilation cache for
# call_jax. Call jax uses functions' id as key.
def _jax_forward(fn, other, tree_def, tensors):
"""JAX function to compute output and vjp function.

primals should be a tuple (args, kwargs).
"""
import jax
from jax.tree_util import tree_flatten, tree_unflatten

def fn_wrapper(*tensors):
# Reconstruct the original args and kwargs
flat_inputs = util.merge(tensors, other)
args, kwargs = tree_unflatten(tree_def, flat_inputs)
return fn(*args, **kwargs)

return jax.vjp(fn_wrapper, *tensors)


def _jax_backward(vjp_spec, saved_tensors, grad_out):
"""JAX function to compute input gradients.

Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
"""
from jax.tree_util import tree_unflatten
fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
return fun_vjp(grad_out)


fori_loop = torch_view(jax.lax.fori_loop)
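Why the helpers above are nested inside j2t_autograd rather than defined inline at the callsite or left at module level: call_jax caches compiled functions keyed on the Python function object's id, so the jax-side helper must be the same object across repeated calls of one wrapper, yet a distinct object for every wrapped function. A minimal standalone sketch of that id-keyed caching behaviour (the cache dict, fake_call_jax, and make_wrapper below are hypothetical stand-ins, not torchax APIs):

_compile_cache = {}  # hypothetical cache; the real call_jax keeps its own

def fake_call_jax(jax_fn, *args):
  # The cache key is the identity of the jax-side callable.
  key = id(jax_fn)
  if key not in _compile_cache:
    _compile_cache[key] = object()  # stand-in for a freshly compiled executable
  return _compile_cache[key]

def make_wrapper(fn):
  # Defined once per wrapped `fn`: stable id across calls to `inner`,
  # but a distinct id for every different wrapped function.
  def _helper(*args):
    return fn(*args)

  def inner(*args):
    return fake_call_jax(_helper, *args)

  return inner

w1 = make_wrapper(lambda x: x)
assert w1(1) is w1(2)        # same helper object on every call -> cache hit
w2 = make_wrapper(lambda x: x + 1)
assert w2(0) is not w1(0)    # different wrapper -> separate cache entry
# Had _helper been created inside `inner`, every call would build a new
# function object (new id) and the cache would never be reused.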


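For context on the forward/backward pair: jax.vjp returns its pullback as a pytree (a jax.tree_util.Partial), which is why the result of _jax_forward can be flattened into tensors for autograd's save_for_backward and later rebuilt in _jax_backward with tree_unflatten. A minimal standalone sketch of that round trip, independent of torchax (the function f and the values are made up for illustration):

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

def f(x):
  return jnp.sin(x) * x

out, f_vjp = jax.vjp(f, jnp.array(1.5))

# Flatten the pullback into array leaves plus a treedef, analogous to the
# saved_tensors / vjp_spec pair the PR stores on the autograd context.
leaves, vjp_spec = tree_flatten(f_vjp)

# Later, in the backward step: rebuild the pullback and apply it to the cotangent.
restored_vjp = tree_unflatten(vjp_spec, leaves)
(grad_x,) = restored_vjp(jnp.array(1.0))  # d/dx [x * sin(x)] at x = 1.5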