From 1699011bbd120a983763c4c17eade251c82805d0 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Fri, 3 May 2024 13:18:55 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchtitan/models/norms.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index a6452cf62e..c35ffecb04 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -213,6 +213,13 @@ def _rms_norm_bwd_kernel_sm(
     tl.store(DW + row_block_id * N + cols, dw, mask=mask)
 
 
+def get_sm_count():
+    # TODO(whc) initializing sm_count here prevents the need to query device during tracing using meta-device.
+    # But it also forces us to pick a favorite device.
+    # possibly, we could build a dict of device_id: sm_count here, and then index it during forward instead.
+    return torch.cuda.get_device_properties(0).multi_processor_count
+
+
 class TritonFusedRMSNorm(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, weight, eps):
@@ -268,7 +275,7 @@ def backward(ctx, dy):
         dx = torch.empty_like(x)
         dw = torch.empty_like(weight)
 
-        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+        sm_count = get_sm_count()
         _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
 
         max_size = 65536 // x.element_size()
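
The TODO inside get_sm_count() sketches an alternative: build a per-device map of SM counts up front and index it later, so the backward pass never has to query a real device (which would break under meta-device tracing). Below is a minimal, hypothetical sketch of that idea; the _SM_COUNTS cache and the device_index parameter are illustrative additions and not part of this patch.

import torch

# Hypothetical module-level cache: one multiprocessor count per visible CUDA device,
# built eagerly so later lookups never touch the device (safe under meta-device tracing).
_SM_COUNTS = (
    {
        i: torch.cuda.get_device_properties(i).multi_processor_count
        for i in range(torch.cuda.device_count())
    }
    if torch.cuda.is_available()
    else {}
)


def get_sm_count(device_index=0):
    # Fall back to device 0 (the patch's "favorite device") when the index is unknown.
    return _SM_COUNTS.get(device_index, _SM_COUNTS.get(0, 0))

With a cache like this, the call site in backward() could pass something like x.device.index (falling back to 0 when the index is None), keeping the lookup free of device queries at trace time.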