NVFuser JIT eager mode InstanceNorm3d #1309
Conversation
csrc/instance_norm_nvfuser_kernel.cu
tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py
Could you tell me why this test skips torch.bfloat16 while the kernel seems to support it? https://github.com/NVIDIA/apex/blob/5fa9b1e59b6fbaeabd5e4c5b592df1b6b52ae215/csrc/instance_norm_nvfuser_kernel.cu#L128
Not sure if we need to test/evaluate bf16 by default, as it would drop Volta/sm_70 support, which I believe was the original use case/target for instance norm.
I'll look into testing this via https://pytorch.org/docs/stable/generated/torch.cuda.get_device_capability.html
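For example, the test could gate bf16 on compute capability along these lines (a sketch only; the `dtypes` list and where the gate lives are placeholders, not this PR's actual test code):

```python
import torch

# Hypothetical dtype list for the test; bfloat16 is only exercised when the
# device reports sm_80 (Ampere) or newer, so Volta/sm_70 runs are unaffected.
dtypes = [torch.float16, torch.float32]
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0):
    dtypes.append(torch.bfloat16)
```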
Hm, currently eager-mode instance_norm supports bfloat16 even on pre-Ampere GPUs, so with nvfuser we'd lose that support? We could fall back to the old implementation then, I guess.
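A minimal sketch of what that fallback could look like, assuming a small dispatch helper in the wrapper (every name here, including `nvfuser_fn`, is a placeholder rather than this PR's actual API):

```python
import torch
import torch.nn.functional as F

def _use_nvfuser_path(x: torch.Tensor) -> bool:
    if not (x.is_cuda and x.is_contiguous()):
        return False
    # Keep bf16 on pre-Ampere devices on the stock eager path.
    if x.dtype is torch.bfloat16 and torch.cuda.get_device_capability(x.device) < (8, 0):
        return False
    return True

def instance_norm_dispatch(x, running_mean, running_var, weight, bias,
                           use_input_stats, momentum, eps, nvfuser_fn):
    if _use_nvfuser_path(x):
        # `nvfuser_fn` stands in for the compiled nvfuser forward.
        return nvfuser_fn(x, running_mean, running_var, weight, bias,
                          use_input_stats, momentum, eps)
    # Fall back to the existing eager-mode implementation.
    return F.instance_norm(x, running_mean, running_var, weight, bias,
                           use_input_stats, momentum, eps)
```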
Could just be an artifact of how I'm prototyping things: naively generating bf16 casts and expecting them to work. I haven't confirmed that the problem is actually caused by nvfuser.
This looks reasonable (but of course there'll need to be a few changes when moving upstream)
What is the latency of compiling the kernel for the first time?
What is the latency if the kernel is found in cache?
apex/normalization/instance_norm.py
How do you know _input is contiguous? It's a user-facing function, so it can get input of any contiguity.
Right, some more checks need to be added. For now contiguous input is assumed to simplify the support matrix, though non-contiguous input could be supported by tweaking the kernel caching, at the cost of potentially more recompilation.
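For illustration, one possible shape of such a check in the wrapper (a sketch only, with made-up names; it simply normalizes the layout rather than extending the cache):

```python
import torch

def _prepare_input(x: torch.Tensor, channels_last: bool) -> torch.Tensor:
    # Normalize to the layout the cached kernels are built for; the alternative
    # is to fold strides/layout into the cache key and accept more recompiles.
    memory_format = torch.channels_last_3d if channels_last else torch.contiguous_format
    if not x.is_contiguous(memory_format=memory_format):
        x = x.contiguous(memory_format=memory_format)
    return x
```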
csrc/instance_norm_nvfuser_kernel.cu
Why is eps part of the key, but momentum isn't? Also, can they be passed as args and not constants, so that the kernel doesn't have to be recompiled?
Neither should need to be part of the key.
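To illustrate the intent, the key only needs to capture properties that change the generated fusion, with eps/momentum bound as runtime scalars; a sketch with made-up names (not the actual cache structure in this PR):

```python
import torch

_fusion_cache = {}

def fusion_key(x: torch.Tensor, affine: bool, track_running_stats: bool,
               channels_last: bool):
    # Only compile-time properties go into the key; eps and momentum are
    # passed to the fusion as runtime scalar inputs instead of being baked in.
    return (x.dtype, x.dim(), affine, track_running_stats, channels_last)

def get_or_build_fusion(x, affine, track_running_stats, channels_last, build_fn):
    key = fusion_key(x, affine, track_running_stats, channels_last)
    if key not in _fusion_cache:
        _fusion_cache[key] = build_fn()  # one compilation per key (at this level)
    return _fusion_cache[key]
```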
Also, can there be multiple kernels per fusion key, or, since all the inputs are forced to be contiguous, is there only a single kernel, with tuning happening at runtime at the level of thread/block sizes?
My understanding is that there would only be one compilation per fusion key; the restriction to contiguous inputs is for initial testing purposes. The kernel is supposed to be general across multiple shapes, but the fuser folks would know more about launch-bounds tuning.
I don't think that's true; at the very least there could be generated kernels with int32 indexing for smaller tensors and int64 indexing for larger ones, and there could be other situations where a different parallelization strategy is chosen depending on the sizes?
Sorry, I mean "compilation step," or rather that if something is in the cache, it would be unexpected for running the fusion to trigger another compilation. But in general I also agree that more data needs to be collected on the latency of the first execution versus subsequent executions.
I've checked the code, and it is definitely possible to trigger another compilation even if something is in the cache: for example, for differently aligned input, for some very different input sizes (I don't know which exact attributes of the tensors, other than alignment, will trigger a recompile), or for large inputs (where int64 indexing has to be used). Which is fine, and the existing cache still saves some computation, but it should be documented that even hitting the function-level cache doesn't guarantee a fast turnaround.
You're correct @ngimel. We can end up generating many kernels; if we want to limit the number of kernels we generate, we would need to implement coarser-grained heuristics (definitely possible to do, but we haven't done it yet). Cache misses will be dominated by nvrtc compile time. Cache hits are also not really easy to quantify, as our finest-grained cache is on input sizes, then there is a cache for heuristics, then this top-level cache. A miss at the lowest level will not force recompilation, so it will be cheaper, but a miss at the higher two levels will require recompilation and will be dominated by nvrtc compile time. I'll work with Eddie to measure the hits at the lower two levels.
Awesome, thanks! So as far as I understand, the lowest level will adjust launch parameters, maybe dynamic shared memory size and things like that? For my education, can you roughly describe what will trigger recompilation? E.g. if I compiled for a tensor of size (B,C,H,W), how different does (B1,C1,H1,W1) have to be to trigger it? Or do you have documentation somewhere that I can look up?
The highest level that Eddie was working on, or the heuristics? Heuristic recompilations all depend on heuristic changes and are subject to change from one release to the next. Practically, for InstanceNorm 3D, the big switches are:

The function that determines these parameters is quite complex, and "predictability" (obviously we know what will happen in any given instance, because it's code) may be relatively low, as things can easily change while we continue to tune the heuristics. If this is unacceptable for eager mode (I can imagine it may be, as we don't want to recompile too much), we could generate simpler/flatter heuristics in place of the current ones, which are oriented toward maximum performance at the cost of compilation. The thing we don't understand is how quickly these heuristics would converge; in practice, on something like large transformers the answer is really quick, even with dynamic sequence sizes. If we instead had something like a 3D segmentation model where inputs could change in all directions frequently, that's a lot more degrees of freedom, and something like this approach may not be applicable. I don't know how we could easily solve the latter problem without impacting the former case (lower performance but less recompilation).

The one piece of good news is that the heuristic structure is what determines recompilation, i.e. if we reimplemented https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp#L35-L492 to be coarser, recompilation would be less frequent. So we do have control over this, as this structure was required so we wouldn't have to constantly change lots of code as we update heuristics.
Thanks, that's very helpful!
Some ballpark numbers for the basic compile-cache-reuse workflow on the small workload in the tests: the first execution is around 0.5-0.9s on V100, with about 30-60us of that being the actual kernel execution time and the bulk of it being compilation. After the first execution, the cache lookup takes around 0.9-1.5us.
0.9-1.5us is great
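For reference, measurements like these can be reproduced with a simple wall-clock harness around the call; a minimal sketch (the stock nn.InstanceNorm3d below is just a stand-in for the nvfuser-backed module from this PR):

```python
import time
import torch
import torch.nn as nn

def timed(fn, *args):
    # Synchronize around the call so the wall-clock time includes the kernel
    # (and, on the first call, any compilation), not just the launch.
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(*args)
    torch.cuda.synchronize()
    return out, time.perf_counter() - start

model = nn.InstanceNorm3d(8, affine=True).cuda()  # stand-in module
x = torch.randn(2, 8, 16, 16, 16, device="cuda")

_, first = timed(model, x)   # first call: includes any compilation cost
_, second = timed(model, x)  # second call: cache hit / steady state
print(f"first: {first * 1e3:.2f} ms, second: {second * 1e3:.2f} ms")
```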
Force-pushed from e9dc120 to ae9ceb8.
Squashed commit of the following (all authored by Eddie Yan <[email protected]>):

- ae9ceb8 (Tue Mar 15 23:48:19 2022 +0000): add profile, remove scalars from cache key
- 2631f3b (Thu Mar 10 02:57:13 2022 +0000): sketchy test numerics twiddling
- d91a15c (Thu Mar 10 01:50:38 2022 +0000): fix
- 2003c45 (Thu Mar 10 01:30:08 2022 +0000): address comments, cleanup
- a4e052d (Tue Mar 1 19:28:58 2022 +0000): add weight and bias check
- 68053f7 (Fri Feb 25 23:37:44 2022 +0000): initial check in

Co-authored-by: Eddie Yan <[email protected]>
I want to have pytorch ship
Getting closer to remove
Final nit-pickings; expect these to be accessible once the dependent headers get shipped...
    PYTORCH_HOME = os.path.abspath(os.environ['PYTORCH_HOME']) if 'PYTORCH_HOME' in os.environ else None
    if PYTORCH_HOME is not None and os.path.exists(PYTORCH_HOME):
        print(PYTORCH_HOME)

If we're to treat this as a --cuda_ext ...
    for dtype, track_running_stats, channels_last, affine in itertools.product(dtypes, (False, True), (False, True), (False, True)):
        self.dtype = dtype
        self.track_running_stats = track_running_stats
        self.channels_last = channels_last
        self.affine = affine
        self.init_modules()
        self.check_same_output()

To make failure messages more informative...

Suggested change:

    for dtype, track_running_stats, channels_last, affine in itertools.product(dtypes, (False, True), (False, True), (False, True)):
        with self.subTest(dtype=dtype, track_running_stats=track_running_stats, channels_last=channels_last, affine=affine):
            self.dtype = dtype
            self.track_running_stats = track_running_stats
            self.channels_last = channels_last
            self.affine = affine
            self.init_modules()
            self.check_same_output()

see: https://docs.python.org/3/library/unittest.html#unittest.TestCase.subTest
    auto result = instance_norm_backward(_input,
                                         _grad_output,
                                         _weight,
                                         _running_mean,
                                         _running_var,
                                         _save_mean,
                                         _save_invstd,
                                         use_input_stats,
                                         _eps,
                                         {true, true, true}, // TODO: is output mask useful?
                                         channels_last);

[nit-picking] Could you lint these lines? It seems like you've interchangeably used spaces and tabs...
    #include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
    #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>

    #include <aten/src/ATen/native/utils/ParamsHash.h>

Suggested change:

    #include <ATen/native/utils/ParamsHash.h>
    at::Tensor save_invstd,
    const bool use_input_stats,
    const float eps,
    // const std::vector<bool>& output_mask,

Is this lack of output_mask worth mentioning in instance_norm.py?
    CUDAExtension('instance_norm_nvfuser_cuda',
                  ['csrc/instance_norm_nvfuser.cpp', 'csrc/instance_norm_nvfuser_kernel.cu'],
                  extra_compile_args={"cxx": ["-O3"] + version_dependent_macros,
                                      "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + [f"-I {PYTORCH_HOME}"])},

Suggested change:

                                      "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros)},
Force-pushed from ae9ceb8 to 3f06e61, then from 3f06e61 to 4af877f.
sgtm
    y = torch.randn(2, device=device)
    pred = model(x)
    loss = nn.functional.mse_loss(pred, y.float())
    loss.backward()

Could you add

    if __name__ == '__main__':
        unittest.main()

here so that we can run this file from the command line using unittest? (As is, pytest can run this test, though.)
    def test_multigpu(self):
        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()

Super nit-pick: switch to super() to use the Python 3 syntax.

Suggested change:

                super().__init__()
Now that nvfuser is a separate repository, it would be reasonable to host some layers based off of the nvfuser Python API.

Ref:
- NVIDIA/apex#1564: Python frontend based
- NVIDIA/apex#1309: eqy's original implementation

cc @ptrblck

Signed-off-by: Masaki Kozuki <[email protected]>
Co-authored-by: Jacob Hinkle <[email protected]>
Co-authored-by: jjsjann123 <[email protected]>
Prototype version with "cuDNN conv style" caching.