Status: Open
Labels: bug (Something isn't working)
Description
Describe the Bug
Minimal Steps/Code to Reproduce the Bug
I used the `origin/22.12-dev` branch at commit `7afa66e` and ran the following code:
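```python
from apex.normalization import InstanceNorm3dNVFuser
import torch

# With affine=False there is no weight/bias, so (per the traceback) self.dummy
# is passed in place of weight, bias, running_mean and running_var.
layer = InstanceNorm3dNVFuser(num_features=16, affine=False).to("cuda:0")
layer(torch.randn((1, 16, 8, 8, 8)).to("cuda:0"))
```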
The following error occurred:
```
File /usr/local/lib/python3.8/dist-packages/apex/normalization/instance_norm.py:138, in _InstanceNormNVFuser.forward(self, input)
    133     out = InstanceNormNVFuserFunction.apply(
    134         input, self.weight if self.weight is not None else self.dummy,
    135         self.bias if self.bias is not None else self.dummy, self.running_mean, self.running_var,
    136         self.training or not self.track_running_stats, self.momentum, self.eps)
    137 else:
--> 138     out = InstanceNormNVFuserFunction.apply(
    139         input, self.weight if self.weight is not None else self.dummy,
    140         self.bias if self.bias is not None else self.dummy, self.dummy, self.dummy,
    141         self.training or not self.track_running_stats, self.momentum, self.eps)
    142 return out

File /usr/local/lib/python3.8/dist-packages/apex/normalization/instance_norm.py:25, in InstanceNormNVFuserFunction.forward(ctx, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps)
     23 _input = input
     24 assert _input.is_contiguous()
---> 25 result = instance_norm_nvfuser_cuda.forward(_input, weight, bias, running_mean, running_var,
     26                                             use_input_stats, momentum, eps, channels_last)
     27 if len(result) == 3:
     28     out, mean, invstd = result

RuntimeError: !mismatch INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/torch/csrc/jit/codegen/cuda/executor_utils.cpp":388, please report a bug to PyTorch. Found one or more invalid arguments: Argument is scalar type, but kernel parameter is not
Argument is scalar type, but kernel parameter is not
```
In addition, no error occurs when `affine=True`.
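Because the `affine=True` path works, one possible stopgap (an untested assumption, not a confirmed fix and not verified against apex) would be to build the layer with `affine=True` and freeze its weight at 1 and bias at 0. Using the normalization formula documented for `torch.nn.InstanceNorm3d`, with learnable per-channel scale $\gamma$ and shift $\beta$, the affine step then reduces to the identity, so the output should match `affine=False`:

```latex
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta,
\qquad
\gamma = 1,\ \beta = 0
\;\Longrightarrow\;
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}
```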
Expected Behavior
It should work the same way `torch.nn.InstanceNorm3d` does.
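For reference, here is a minimal pure-Python sketch of what `torch.nn.InstanceNorm3d(affine=False)` computes in training mode: each (sample, channel) slice is normalized over its D×H×W values using the biased variance and the default `eps=1e-5`. The nested-list layout and the name `instance_norm_reference` are hypothetical, chosen only for illustration; this is not the apex or PyTorch implementation.

```python
import math
from typing import List

# Hypothetical layout for this sketch: x[n][c] is the flattened list of the
# D*H*W spatial values for sample n, channel c (a real tensor is (N, C, D, H, W)).
Batch = List[List[List[float]]]


def instance_norm_reference(x: Batch, eps: float = 1e-5) -> Batch:
    """Normalize every (sample, channel) slice independently, without affine."""
    out: Batch = []
    for sample in x:
        out_sample = []
        for channel in sample:
            n = len(channel)
            mean = sum(channel) / n
            # Biased variance (divide by n), matching InstanceNorm's documented estimator.
            var = sum((v - mean) ** 2 for v in channel) / n
            inv_std = 1.0 / math.sqrt(var + eps)
            out_sample.append([(v - mean) * inv_std for v in channel])
        out.append(out_sample)
    return out


if __name__ == "__main__":
    # One sample with two channels of four spatial values each.
    y = instance_norm_reference([[[1.0, 2.0, 3.0, 4.0], [10.0, 10.0, 10.0, 10.0]]])
    print(y)
```

Each output channel has roughly zero mean and unit variance, and a constant channel maps to all zeros. A fixed `InstanceNorm3dNVFuser(affine=False)` would be expected to agree with `torch.nn.InstanceNorm3d(16, affine=False)` on the same input up to floating-point tolerance.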
Environment