NVFuser JIT eager mode InstanceNorm3d #1309
Status: Open. eqy wants to merge 11 commits into NVIDIA:master from eqy:instance_norm_nvfuser.
Commits (11, all by eqy):

- b300b59: initial check in
- 0b8df57: add weight and bias check
- 87349fb: address comments, cleanup
- 7297b37: fix
- 53b6405: sketchy test numerics twiddling
- a0b4669: add profile, remove scalars from cache key
- f2b76c6: retab
- fd6ea01: some overdue cleanup
- 4af877f: fix device for dummy tensor
- be7f7c5: add test for multigpu instancenorm3dnvfuser
- 285e973: Update instance_norm.py
Changed files:

apex/normalization/__init__.py (@@ -1 +1,2 @@), re-exporting the new module:

```python
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm
from .instance_norm import InstanceNorm3dNVFuser
```
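For orientation, a minimal usage sketch of the new module (not part of the diff; it assumes the `instance_norm_nvfuser_cuda` extension is built and a CUDA device is available):

```python
import torch
from apex.normalization import InstanceNorm3dNVFuser

# InstanceNorm3dNVFuser expects 5D NCDHW input on a CUDA device.
m = InstanceNorm3dNVFuser(num_features=8, affine=True).cuda()
x = torch.randn(2, 8, 16, 16, 16, device="cuda", requires_grad=True)

y = m(x)            # forward runs through InstanceNormNVFuserFunction
y.sum().backward()  # backward dispatches to instance_norm_nvfuser_cuda.backward
```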
apex/normalization/instance_norm.py (new file, @@ -0,0 +1,151 @@):

```python
import importlib

import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _NormBase

# Handle to the compiled extension; resolved lazily on first use so that
# importing this module does not require the extension to be built.
global instance_norm_nvfuser_cuda
instance_norm_nvfuser_cuda = None


class InstanceNormNVFuserFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, running_mean, running_var,
                use_input_stats, momentum, eps):
        global instance_norm_nvfuser_cuda
        if instance_norm_nvfuser_cuda is None:
            instance_norm_nvfuser_cuda = importlib.import_module("instance_norm_nvfuser_cuda")

        channels_last = input.is_contiguous(memory_format=torch.channels_last) or input.is_contiguous(memory_format=torch.channels_last_3d)
        if channels_last:
            order = [0] + [i for i in range(2, len(input.shape))] + [1]
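            # Added note (not in the original diff): for a 5D NCDHW input this
            # yields order = [0, 2, 3, 4, 1], i.e. a permute to NDHWC so the
            # kernel sees an explicitly channels-last-contiguous buffer.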
            _input = input.permute(order)
        else:
            _input = input
        assert _input.is_contiguous()
        result = instance_norm_nvfuser_cuda.forward(_input, weight, bias, running_mean, running_var,
                                                    use_input_stats, momentum, eps, channels_last)
        if len(result) == 3:
            out, mean, invstd = result
        else:
            running_mean, running_var, out, mean, invstd = result
        ctx.use_input_stats = use_input_stats
        ctx.eps = eps
        ctx.channels_last = channels_last
        # saving for backward in "explicit channels-last format"
        ctx.save_for_backward(_input, weight, running_mean, running_var, mean, invstd)
        if channels_last:
            order = [0, len(_input.shape) - 1] + [i for i in range(1, len(_input.shape) - 1)]
            out = out.permute(order)
            if len(out.shape) == 4:
                assert out.is_contiguous(memory_format=torch.channels_last)
                assert input.is_contiguous(memory_format=torch.channels_last)
            elif len(out.shape) == 5:
                assert out.is_contiguous(memory_format=torch.channels_last_3d)
                assert input.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                assert False, "unhandled channels_last format variation in forward"
        return out

    @staticmethod
    def backward(ctx, grad_output):
        global instance_norm_nvfuser_cuda
        if instance_norm_nvfuser_cuda is None:
            instance_norm_nvfuser_cuda = importlib.import_module("instance_norm_nvfuser_cuda")

        if ctx.channels_last:
            order = [0] + [i for i in range(2, len(grad_output.shape))] + [1]
            grad_output = grad_output.permute(order)
        # input was saved in "explicit channels-last format"
        assert ctx.saved_tensors[0].is_contiguous()
        grad_output = grad_output.contiguous()
        saved = list(ctx.saved_tensors)
        saved.insert(1, grad_output)
        running_mean = saved[3]
        running_var = saved[4]
        mean = saved[-2]
        var = saved[-1]
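        # Added note (not in the original diff): after the insert above, saved is
        # [input, grad_output, weight, running_mean, running_var, mean, invstd],
        # which lines up positionally with the instance_norm_nvfuser_backward
        # declaration in the C++ bindings below; the four aliases above are not
        # used further in this function.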
        grad_input, grad_weight, grad_bias = instance_norm_nvfuser_cuda.backward(*saved, ctx.use_input_stats, ctx.eps, ctx.channels_last)
        if ctx.channels_last:
            order = [0, len(grad_input.shape) - 1] + [i for i in range(1, len(grad_input.shape) - 1)]
            grad_input = grad_input.permute(order)
            if len(grad_input.shape) == 4:
                assert grad_input.is_contiguous(memory_format=torch.channels_last)
            elif len(grad_input.shape) == 5:
                assert grad_input.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                assert False, "unhandled channels_last format variation in backward"
        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


class _InstanceNormNVFuser(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_InstanceNormNVFuser, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        # Placeholder tensor passed to the extension in place of absent
        # weight/bias/running-stat tensors.
        self.dummy = torch.empty([], device=device)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ('running_mean', 'running_var'):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    'Unexpected running stats buffer(s) {names} for {klass} '
                    'with track_running_stats=False. If state_dict is a '
                    'checkpoint saved before 0.4.0, this may be expected '
                    'because {klass} does not track running stats by default '
                    'since 0.4.0. Please remove these keys from state_dict. If '
                    'the running stats are actually needed, instead set '
                    'track_running_stats=True in {klass} to enable them. See '
                    'the documentation of {klass} for details.'
                    .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
                            klass=self.__class__.__name__))
                for key in running_stats_keys:
                    state_dict.pop(key)

        super(_InstanceNormNVFuser, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, input: Tensor) -> Tensor:
        assert input.is_cuda, "NVFuser InstanceNorm is CUDA only"
        self._check_input_dim(input)
        # Re-create the dummy placeholder if the module has moved devices.
        if self.dummy.device != input.device:
            self.dummy = torch.empty([], device=input.device)
        if self.running_mean is not None:
            out = InstanceNormNVFuserFunction.apply(
                input, self.weight if self.weight is not None else self.dummy,
                self.bias if self.bias is not None else self.dummy, self.running_mean, self.running_var,
                self.training or not self.track_running_stats, self.momentum, self.eps)
        else:
            out = InstanceNormNVFuserFunction.apply(
                input, self.weight if self.weight is not None else self.dummy,
                self.bias if self.bias is not None else self.dummy, self.dummy, self.dummy,
                self.training or not self.track_running_stats, self.momentum, self.eps)
        return out


class InstanceNorm3dNVFuser(_InstanceNormNVFuser):
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
```
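As a side note on the layout bookkeeping above, here is a small self-contained sketch (not part of the diff) showing that the two permutation vectors used in forward are inverses of each other, and that the round trip lands back in channels-last-3d memory format; it purely illustrates the index arithmetic:

```python
import torch

x = torch.randn(2, 4, 8, 8, 8)                            # NCDHW
to_cl = [0] + list(range(2, x.dim())) + [1]               # [0, 2, 3, 4, 1]: NCDHW -> NDHWC
y = x.permute(to_cl).contiguous()                         # "explicit channels-last" buffer
back = [0, y.dim() - 1] + list(range(1, y.dim() - 1))     # [0, 4, 1, 2, 3]: NDHWC -> NCDHW

z = y.permute(back)                                       # same values, NCDHW shape again
assert torch.equal(x, z)
assert z.is_contiguous(memory_format=torch.channels_last_3d)
```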
C++ extension bindings (new file, @@ -0,0 +1,35 @@); this declares the forward/backward entry points implemented in the CUDA sources and exposes them to Python via pybind11:

```cpp
#include <iostream>
#include <vector>

#include <torch/extension.h>

std::vector<at::Tensor> instance_norm_nvfuser_forward(
    at::Tensor input,
    at::Tensor weight,
    at::Tensor bias,
    at::Tensor run_mean,
    at::Tensor run_var,
    const bool use_input_stats,
    const float momentum,
    const float eps,
    const bool channels_last = false
);

std::vector<at::Tensor> instance_norm_nvfuser_backward(
    at::Tensor input,
    at::Tensor grad_output,
    at::Tensor weight,
    at::Tensor running_mean,
    at::Tensor running_var,
    at::Tensor save_mean,
    at::Tensor save_invstd,
    const bool use_input_stats,
    const float eps,
    // const std::vector<bool>& output_mask,
    bool channels_last = false
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &instance_norm_nvfuser_forward, "instance_norm forward (CUDA)");
  m.def("backward", &instance_norm_nvfuser_backward, "instance_norm backward (CUDA)");
}
```
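The extension itself is compiled separately; the build integration is not shown in this excerpt. As a sketch, bindings like these can be JIT-compiled with PyTorch's C++ extension loader. The source file names below are illustrative, not taken from this PR:

```python
from torch.utils.cpp_extension import load

# Hypothetical source list; the actual .cpp/.cu file names are not shown here.
instance_norm_nvfuser_cuda = load(
    name="instance_norm_nvfuser_cuda",
    sources=["instance_norm_nvfuser.cpp", "instance_norm_nvfuser_kernel.cu"],
    verbose=True,
)
```

In apex proper the module would typically be built and installed via setup.py instead, so that the lazy importlib.import_module("instance_norm_nvfuser_cuda") call in instance_norm.py can resolve it by name.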
Review discussion:

> How do you know `_input` is contiguous? It's a user-facing function; it can get input of any contiguity.

Reply: Right, some more checks need to be added. For now, contiguous input is assumed to simplify the support matrix; non-contiguous input can be supported by tweaking the kernel caching and potentially more recompilation.
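To make the reviewer's point concrete, here is a sketch (not from the PR) of an input that would trip the current assumption, using the module as exported above:

```python
import torch
from apex.normalization import InstanceNorm3dNVFuser

m = InstanceNorm3dNVFuser(num_features=4).cuda()

# A strided view: neither contiguous nor channels_last_3d-contiguous.
x = torch.randn(2, 16, 8, 8, 4, device="cuda").transpose(1, 4)  # shape (2, 4, 8, 8, 16)

try:
    m(x)  # trips `assert _input.is_contiguous()` in the autograd function
except AssertionError:
    y = m(x.contiguous())  # caller-side workaround until such checks land
```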