diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index a6452cf62e..c35ffecb04 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -213,6 +213,13 @@ def _rms_norm_bwd_kernel_sm(
     tl.store(DW + row_block_id * N + cols, dw, mask=mask)
 
 
+def get_sm_count():
+    # TODO(whc) initializing sm_count here prevents the need to query device during tracing using meta-device.
+    # But it also forces us to pick a favorite device.
+    # possibly, we could build a dict of device_id: sm_count here, and then index it during forward instead.
+    return torch.cuda.get_device_properties(0).multi_processor_count
+
+
 class TritonFusedRMSNorm(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, weight, eps):
@@ -268,7 +275,7 @@ def backward(ctx, dy):
         dx = torch.empty_like(x)
         dw = torch.empty_like(weight)
 
-        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+        sm_count = get_sm_count()
         _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
 
         max_size = 65536 // x.element_size()
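
A minimal sketch of the alternative floated in the TODO comment, not part of this PR: cache multi_processor_count for every visible CUDA device once at import time, then do a plain dict lookup in the autograd pass so no device query happens under meta-device tracing. The names _SM_COUNTS and get_sm_count_for_device are hypothetical.

import torch

# Hypothetical alternative to get_sm_count(): cache sm_count per device index
# once at import time, so later calls never query device properties.
_SM_COUNTS: dict[int, int] = (
    {
        i: torch.cuda.get_device_properties(i).multi_processor_count
        for i in range(torch.cuda.device_count())
    }
    if torch.cuda.is_available()
    else {}
)


def get_sm_count_for_device(device: torch.device) -> int:
    # Fall back to device 0 when the tensor's device carries no explicit index
    # (e.g. plain "cuda"), mirroring the PR's choice of a "favorite" device.
    index = device.index if device.index is not None else 0
    return _SM_COUNTS[index]

The call site in backward would then read sm_count = get_sm_count_for_device(weight.device) (or any real tensor's device) instead of always using device 0.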