InstanceNorm3dNVFuser issue #1562

@yiheng-wang-nv

Describe the Bug

Minimal Steps/Code to Reproduce the Bug

I used the origin/22.12-dev branch at commit 7afa66e.

I ran the following code:

from apex.normalization import InstanceNorm3dNVFuser
import torch

layer = InstanceNorm3dNVFuser(num_features=16, affine=False).to("cuda:0")
layer(torch.randn((1, 16, 8, 8, 8)).to("cuda:0"))

It fails with this error:

File /usr/local/lib/python3.8/dist-packages/apex/normalization/instance_norm.py:138, in _InstanceNormNVFuser.forward(self, input)
    133     out = InstanceNormNVFuserFunction.apply(
    134             input, self.weight if self.weight is not None else self.dummy,
    135             self.bias if self.bias is not None else self.dummy, self.running_mean, self.running_var,
    136             self.training or not self.track_running_stats, self.momentum, self.eps)
    137 else:
--> 138     out = InstanceNormNVFuserFunction.apply(
    139             input, self.weight if self.weight is not None else self.dummy,
    140             self.bias if self.bias is not None else self.dummy, self.dummy, self.dummy,
    141             self.training or not self.track_running_stats, self.momentum, self.eps)
    142 return out

File /usr/local/lib/python3.8/dist-packages/apex/normalization/instance_norm.py:25, in InstanceNormNVFuserFunction.forward(ctx, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps)
     23     _input = input
     24 assert _input.is_contiguous()
---> 25 result = instance_norm_nvfuser_cuda.forward(_input, weight, bias, running_mean, running_var,
     26                                             use_input_stats, momentum, eps, channels_last)
     27 if len(result) == 3:
     28     out, mean, invstd = result

RuntimeError: !mismatch INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/torch/csrc/jit/codegen/cuda/executor_utils.cpp":388, please report a bug to PyTorch. Found one or more invalid arguments: Argument is scalar type, but kernel parameter is not
Argument is scalar type, but kernel parameter is not


In addition, no error occurs when affine=True.
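Since affine=True does not hit the error, a possible stopgap is to build the layer with affine=True and leave its parameters frozen at their initial values. The sketch below only uses torch.nn.InstanceNorm3d to show that, in stock PyTorch, a freshly constructed affine layer (weight initialized to ones, bias to zeros) gives the same output as affine=False. Whether InstanceNorm3dNVFuser initializes its parameters the same way is an assumption I have not verified.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 16, 8, 8, 8)

# Stock PyTorch layers, used here only to illustrate the equivalence.
plain = torch.nn.InstanceNorm3d(num_features=16, affine=False)
affine = torch.nn.InstanceNorm3d(num_features=16, affine=True)

# Freeze the affine parameters so they stay at their initial values
# (weight = 1, bias = 0) and the layer keeps behaving like affine=False.
for p in affine.parameters():
    p.requires_grad_(False)

same = torch.allclose(plain(x), affine(x))
print(same)
```

If the same holds for the apex layer, swapping in affine=True with frozen parameters would avoid the failing code path until the affine=False case is fixed.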

Expected Behavior

It should work correctly, as torch.nn.InstanceNorm3d does.
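For reference, here is a minimal sketch of the expected behavior using torch.nn.InstanceNorm3d on the same input shape. The helper reference_instance_norm3d is a hypothetical name introduced only for this comparison; it normalizes each (sample, channel) slice over its spatial dimensions with the biased variance and the default eps of 1e-5. The script falls back to the CPU when no GPU is available.

```python
import torch


def reference_instance_norm3d(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: instance norm without affine parameters."""
    # Statistics are computed per sample and per channel over D, H, W.
    mean = x.mean(dim=(2, 3, 4), keepdim=True)
    var = x.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


device = "cuda:0" if torch.cuda.is_available() else "cpu"
x = torch.randn(1, 16, 8, 8, 8, device=device)

# The stock PyTorch layer accepts affine=False without error.
layer = torch.nn.InstanceNorm3d(num_features=16, affine=False).to(device)
expected = layer(x)

matches = torch.allclose(expected, reference_instance_norm3d(x), atol=1e-4)
print(expected.shape, matches)
```

I would expect InstanceNorm3dNVFuser(num_features=16, affine=False) to return a tensor of the same shape with matching values, up to floating-point tolerance.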

Environment
