Skip to content

Allow all shardings if exported.nr_devices is 1 in _export.py. #29690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 25, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
@@ -1472,22 +1472,12 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
num_devices = axis_context.axis_env.nreps
else:
raise NotImplementedError(type(axis_context))
if num_devices != exported.nr_devices:
# In some special cases we allow running with a different number of devices
# than the function was exported for.
err_msg = ""
if exported.nr_devices != 1:
err_msg = "the function was exported for more than 1 device."
elif (_check_module(submodule, disabled_checks=(), shardy_enabled=shardy_enabled)
or any(s is not None and not s.is_replicated()
for s in exported.in_shardings_hlo + exported.out_shardings_hlo)):
err_msg = "the function contains non-replicated sharding annotations."
if err_msg:
raise ValueError(
if num_devices != exported.nr_devices and exported.nr_devices != 1:
raise ValueError(
f"Function {exported.fun_name} was exported for "
f"{exported.nr_devices} devices and is called in a context with "
f"{num_devices} devices. This is disallowed because: {err_msg}"
)
f"{num_devices} devices, which is not allowed."
)

# Apply in_shardings
if shardy_enabled:
63 changes: 55 additions & 8 deletions tests/export_test.py
Original file line number Diff line number Diff line change
@@ -1250,7 +1250,7 @@ def f_without_shardings(x):
res_exported = exp.call(b)
self.assertAllClose(res_native, res_exported)

def test_call_with_different_no_of_devices_error_has_in_shardings(self):
def test_call_with_different_no_of_devices_in_shardings_success(self):
if jax.local_device_count() < 2:
self.skipTest("Need at least 2 devices")

@@ -1263,18 +1263,41 @@ def f_with_sharding(x):
a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape(
(jax.device_count(), 10)
)
res_native = f_with_sharding(a)
exp = get_exported(f_with_sharding)(a)
self.assertEqual(exp.nr_devices, 1)

run_devices = jax.local_devices()
run_mesh = Mesh(run_devices, "i")
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))

res_exported = exp.call(b)
self.assertAllClose(res_native, res_exported)

def test_call_with_different_no_of_devices_in_shardings_error(self):
if jax.local_device_count() < 3:
self.skipTest("Need at least 3 devices")

mesh_1 = Mesh(jax.local_devices()[:2], "i")
@functools.partial(pjit.pjit,
in_shardings=NamedSharding(mesh_1, P("i")))
def f_with_sharding(x):
return jnp.sum(x ** 2, axis=0)

a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape(
(jax.device_count(), 10)
)
exp = get_exported(f_with_sharding)(a)
self.assertEqual(exp.nr_devices, 2)

run_devices = jax.local_devices()
run_mesh = Mesh(run_devices, "i")
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))

with self.assertRaisesRegex(
ValueError,
"Function .* was exported for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* function contains "
"non-replicated sharding annotations"):
"Function .* was exported for 2 devices and is called in a "
f"context with {jax.local_device_count()} devices"):
exp.call(b)

def test_call_with_different_no_of_devices_pmap(self):
@@ -1296,7 +1319,7 @@ def f_jax(x):
res_exported = jax.pmap(exp.call)(b)
self.assertAllClose(res_native, res_exported[0])

def test_call_with_different_no_of_devices_error_has_sharding_constraint(self):
def test_call_with_different_no_of_devices_sharding_constraint_success(self):
if jax.device_count() < 2:
self.skipTest("Need at least 2 devices")

@@ -1309,18 +1332,42 @@ def f_with_sharding(x):
a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape(
(jax.device_count(), 10)
)
res_native = f_with_sharding(a)
exp = get_exported(f_with_sharding)(a)
self.assertEqual(exp.nr_devices, 1)

run_devices = jax.local_devices()
run_mesh = Mesh(run_devices, "i")
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))

res_exported = exp.call(b)
self.assertAllClose(res_native, res_exported)

def test_call_with_different_no_of_devices_sharding_constraint_error(self):
if jax.device_count() < 3:
self.skipTest("Need at least 3 devices")

# We export for 2 devices, but call with >=3 devices.
mesh_1 = Mesh(jax.local_devices()[:2], "i")
@jax.jit
def f_with_sharding(x):
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh_1, P("i")))
return jnp.sum(x ** 2, axis=0)

a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape(
(jax.device_count(), 10)
)
exp = get_exported(f_with_sharding)(a)
self.assertEqual(exp.nr_devices, 2)

run_devices = jax.local_devices()
run_mesh = Mesh(run_devices, "i")
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))

with self.assertRaisesRegex(
ValueError,
"Function .* was exported for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* function contains "
"non-replicated sharding annotations"):
"Function .* was exported for 2 devices and is called in a "
f"context with {jax.local_device_count()} devices"):
exp.call(b)

@jtu.parameterized_filterable(