Conversation

@eqy eqy commented Feb 25, 2022

Prototype version with "cuDNN conv style" caching.

Collaborator

Could you tell me why this test skips torch.bfloat16 while the kernel seems to support it (https://github.com/NVIDIA/apex/blob/5fa9b1e59b6fbaeabd5e4c5b592df1b6b52ae215/csrc/instance_norm_nvfuser_kernel.cu#L128)?

Collaborator Author

I'm not sure we need to test/evaluate bf16 by default, as that would effectively drop Volta/sm_70 support, which I believe was the original use case/target for instance norm.


Hm, currently eager mode instance_norm supports bfloat16 even on pre-Ampere GPUs, so with nvfuser we'd lose that support? We could fall back to the old implementation then, I guess.
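
A fallback along those lines could be a thin dispatch in front of the fused op; a minimal sketch, with fused_instance_norm standing in for the nvfuser-backed path (the names and the capability check are illustrative, not the actual module API):

    import torch
    import torch.nn.functional as F

    def fused_instance_norm(x, weight, bias, eps):
        # Placeholder for the nvfuser-backed path discussed in this PR.
        return F.instance_norm(x, weight=weight, bias=bias, eps=eps)

    def instance_norm_dispatch(x, weight=None, bias=None, eps=1e-5):
        # bf16 on pre-Ampere GPUs (compute capability < 8.0) falls back to the
        # eager implementation; everything else takes the fused path.
        if x.is_cuda and x.dtype == torch.bfloat16:
            major, _ = torch.cuda.get_device_capability(x.device)
            if major < 8:
                return F.instance_norm(x, weight=weight, bias=bias, eps=eps)
        return fused_instance_norm(x, weight, bias, eps)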

Collaborator Author

This could just be an artifact of how I'm prototyping things: I'm naively generating bf16 casts and expecting them to work, and I haven't confirmed that the issue is actually caused by nvfuser.

@ngimel ngimel left a comment

This looks reasonable (though of course there will need to be a few changes when moving upstream).
What is the latency of compiling the kernel for the first time?
What is the latency if the kernel is found in cache?


How do you know _input is contiguous? This is a user-facing function; it can receive input of any contiguity.

Collaborator Author

Right, more checks need to be added. For now, contiguous input is assumed to simplify the support matrix; non-contiguous input can be supported by tweaking the kernel caching, potentially at the cost of more recompilation.
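
For reference, a user-facing wrapper could guard the contiguity assumption until non-contiguous layouts are handled; a minimal sketch under that assumption (the helper name is illustrative):

    import torch

    def prepare_input(x: torch.Tensor, channels_last: bool = False) -> torch.Tensor:
        # The prototype assumes a contiguous 5D (N, C, D, H, W) input; normalize
        # the layout here instead of relying on the caller to do it.
        memory_format = (torch.channels_last_3d if channels_last
                         else torch.contiguous_format)
        if not x.is_contiguous(memory_format=memory_format):
            x = x.contiguous(memory_format=memory_format)
        return x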


Why is eps part of the key but momentum isn't? Also, can they be passed as arguments rather than constants, so that the kernel doesn't have to be recompiled?

Collaborator

Neither should need to be part of the key.
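
In other words, only properties that change the generated code would go into the key, while eps and momentum travel as runtime arguments; a rough Python illustration (the key fields shown are illustrative, not the exact ones used by this extension):

    import torch

    def fusion_cache_key(x: torch.Tensor, affine: bool, track_running_stats: bool):
        # Only things that affect code generation belong in the key:
        # dtype, rank, memory layout, which optional tensors exist, device.
        return (x.dtype, x.dim(), x.is_contiguous(), affine,
                track_running_stats, x.device.index)

    # eps and momentum become runtime scalar arguments to the compiled fusion,
    # so changing them does not create a new cache entry or trigger recompilation.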

@ngimel ngimel left a comment

Also, can there be multiple kernels per fusion key, or, since all the inputs are forced to be contiguous, is there only a single kernel, with tuning done at runtime at the level of thread/block sizes?

eqy commented Mar 10, 2022

> Also, can there be multiple kernels per fusion key, or, since all the inputs are forced to be contiguous, is there only a single kernel, with tuning done at runtime at the level of thread/block sizes?

My understanding is that there would only be one compilation per fusion key; the restriction to contiguous inputs is for initial testing purposes. The kernel is supposed to be general across shapes, but the fuser folks would know more about launch-bounds tuning.

ngimel commented Mar 10, 2022

> My understanding is that there would only be one compilation per fusion key;

I don't think that's true. At the very least, kernels could be generated with int32 indexing for smaller tensors and int64 indexing for larger ones, and there could be other situations where a different parallelization strategy is chosen depending on the sizes.
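
For context, a kernel compiled with 32-bit indexing only covers tensors whose linear indices fit in an int32, so crossing that bound alone forces a differently compiled kernel; a minimal sketch of that check (element count only, ignoring strides):

    import torch

    INT32_MAX = 2**31 - 1

    def needs_int64_indexing(t: torch.Tensor) -> bool:
        # If any linear index could exceed the int32 range, the generated kernel
        # must use 64-bit indexing, i.e. a different compiled kernel.
        return t.numel() > INT32_MAX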

eqy commented Mar 10, 2022

>> My understanding is that there would only be one compilation per fusion key;

> I don't think that's true. At the very least, kernels could be generated with int32 indexing for smaller tensors and int64 indexing for larger ones, and there could be other situations where a different parallelization strategy is chosen depending on the sizes.

Sorry, I mean "compilation step," or rather that if something is in the cache, it would be unexpected for running the fusion to trigger another compilation. But in general I also agree that more data needs to be collected for the latency of first execution and subsequent executions.

ngimel commented Mar 10, 2022

> Sorry, I mean "compilation step," or rather that if something is in the cache, it would be unexpected for running the fusion to trigger another compilation. But in general I also agree that more data needs to be collected for the latency of first execution and subsequent executions.

I've checked the code, and it is definitely possible to trigger another compilation even when something is in the cache: for example, for differently aligned input, for some very different input sizes (I don't know which exact tensor attributes, other than alignment, will trigger a recompile), or for large inputs (where int64 indexing has to be used). That is fine, and the existing cache still saves some computation, but it should be documented that even hitting the function-level cache doesn't guarantee fast turnaround.
It would be good to understand which parameter ranges will result in a kernelRuntime match in runFusionWithInputs.
And we still need the measurements for a cache miss and a full cache hit (the function-level cache is hit and no recompilation is triggered in runFusionWithInputs).

@csarofeen

You're correct, @ngimel. We can end up generating many kernels; if we want to limit the number of kernels we generate, we would need to implement coarser-grained heuristics (definitely possible, but we haven't done it yet). Cache misses will be dominated by nvrtc compile time. Cache hits are also not easy to quantify, as our finest-grained cache is keyed on input sizes, then there is a cache for heuristics, and then this top-level cache. A miss at the lowest level will not force recompilation, so it is cheaper, but misses at the higher two levels require recompilation and will be dominated by nvrtc compile time. I'll work with Eddie to measure the hits at the lower two levels.
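
A purely conceptual sketch of that layering, only to illustrate why a lowest-level miss stays cheap while the higher-level misses pay nvrtc compile time (this does not mirror the actual nvfuser classes):

    import time

    def nvrtc_compile(heuristic):           # stand-in for the expensive compile step
        time.sleep(0.5)
        return f"kernel<{heuristic}>"

    def pick_launch_params(sizes):          # stand-in for cheap runtime tuning
        return {"blocks": max(1, sizes[0]), "threads": 256}

    fusion_cache = {}                        # top level: keyed on the fusion definition

    def run(fusion_key, heuristic, sizes):
        entry = fusion_cache.setdefault(fusion_key, {"kernels": {}, "launch": {}})
        kernel = entry["kernels"].get(heuristic)
        if kernel is None:                   # heuristic-level miss: recompilation
            kernel = nvrtc_compile(heuristic)
            entry["kernels"][heuristic] = kernel
        launch = entry["launch"].get(sizes)
        if launch is None:                   # lowest-level miss: no recompilation
            launch = pick_launch_params(sizes)
            entry["launch"][sizes] = launch
        return kernel, launch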

ngimel commented Mar 10, 2022

Awesome, thanks! So as far as I understand, the lowest level will adjust launch parameters, maybe dynamic shared memory size, and things like that? For my education, can you roughly describe what will trigger recompilation? E.g., if I compiled for a tensor of size (B,C,H,W), how different does (B1,C1,H1,W1) have to be to trigger it? Or do you have documentation somewhere that I can look up?

@csarofeen

Do you mean the highest level that Eddie was working on, or the heuristics? Heuristic recompilations all depend on heuristic changes and are subject to change from one release to the next.

Practically, for InstanceNorm 3D the big switches are:
- Persistent vs. non-persistent (we're working on grid persistence now)
- Alignment for the vectorization size (1, 2, 4, 8)
- If not vectorized, the same factors for unrolling (depends on what fits well)
- Grid reduction vs. non-grid reduction

For persistent cases these decisions are returned by https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/scheduler/normalization.h#L21, which returns the reduction heuristic structure. If that structure is returned unchanged by this fun mega "==" check (https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h#L111-L140), then no recompilation has to happen.

The function that determines these parameters is quite complex, and "predictability" (obviously we know what will happen in any given instance, because it's code) may be relatively low, since things can easily change as we continue to tune the heuristics.

If this is unacceptable for eager mode (I can imagine it may be, as we don't want to recompile too much), we could generate simpler/flatter heuristics in place of the current ones, which are oriented toward maximum performance at the cost of compilation. The thing we don't understand is how quickly these heuristics would converge; in practice, on something like large transformers, the answer is really quickly, even with dynamic sequence sizes. If instead we had something like a 3D segmentation model, where inputs can change in all directions frequently, that's a lot more degrees of freedom, and this kind of approach may not be applicable. I don't know how we could easily solve the latter problem without impacting the former case (lower performance but less recompilation).

The one piece of good news is that the heuristic structure is what determines recompilation: if we reimplemented https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp#L35-L492 to be coarser, recompilation would be less frequent. So we do have control over this, since this structure was required so that we don't have to constantly change lots of code as we update the heuristics.

ngimel commented Mar 10, 2022

Thanks, that's very helpful!

eqy commented Mar 15, 2022

Some ballpark numbers for the basic compile-cache-reuse workflow on the small workload in the tests: the first execution is around 0.5-0.9s on V100, with about 30-60us of that being the actual kernel execution time and the bulk being compilation. After the first execution, the cache lookup takes around 0.9-1.5us.
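
For anyone reproducing these numbers, the first-call vs. cached-call split can be measured roughly as below; the module and input here are placeholders rather than the exact test workload:

    import time
    import torch
    import torch.nn as nn

    def timed(fn, *args):
        torch.cuda.synchronize()
        start = time.perf_counter()
        out = fn(*args)
        torch.cuda.synchronize()
        return out, time.perf_counter() - start

    norm = nn.InstanceNorm3d(8, affine=True).cuda()   # placeholder for the fused module
    x = torch.randn(2, 8, 16, 16, 16, device="cuda")

    _, first = timed(norm, x)    # includes compilation / cache population
    _, second = timed(norm, x)   # cache hit: kernel launch plus lookup overhead
    print(f"first call: {first:.4f}s, second call: {second:.6f}s")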

ngimel commented Mar 16, 2022

0.9-1.5us is great

@eqy eqy force-pushed the instance_norm_nvfuser branch from e9dc120 to ae9ceb8 Compare March 25, 2022 04:06
@eqy eqy changed the title [WIP] [DO NOT MERGE] NVFuser JIT eager mode InstanceNorm3d NVFuser JIT eager mode InstanceNorm3d Mar 25, 2022
crcrpar added a commit to crcrpar/apex that referenced this pull request Mar 25, 2022
Squashed commit of the following:

commit ae9ceb8
Author: Eddie Yan <[email protected]>
Date:   Tue Mar 15 23:48:19 2022 +0000

    add profile, remove scalars from cache key

commit 2631f3b
Author: Eddie Yan <[email protected]>
Date:   Thu Mar 10 02:57:13 2022 +0000

    sketchy test numerics twiddling

commit d91a15c
Author: Eddie Yan <[email protected]>
Date:   Thu Mar 10 01:50:38 2022 +0000

    fix

commit 2003c45
Author: Eddie Yan <[email protected]>
Date:   Thu Mar 10 01:30:08 2022 +0000

    address comments, cleanup

commit a4e052d
Author: Eddie Yan <[email protected]>
Date:   Tue Mar 1 19:28:58 2022 +0000

    add weight and bias check

commit 68053f7
Author: Eddie Yan <[email protected]>
Date:   Fri Feb 25 23:37:44 2022 +0000

    initial check in

Co-authored-by: Eddie Yan <[email protected]>
@crcrpar crcrpar modified the milestones: 22.04, 22.06 May 3, 2022
crcrpar commented May 3, 2022

crcrpar commented Jun 22, 2022

Getting closer to removing the PYTORCH_HOME env var dependency: pytorch/pytorch#78281

@crcrpar crcrpar left a comment

Final nit-picks; these are expected to be addressed once the dependent headers get shipped...

Comment on lines +10 to +12
PYTORCH_HOME = os.path.abspath(os.environ['PYTORCH_HOME']) if 'PYTORCH_HOME' in os.environ else None


Suggested change
PYTORCH_HOME = os.path.abspath(os.environ['PYTORCH_HOME']) if 'PYTORCH_HOME' in os.environ else None

Comment on lines +340 to +364
if PYTORCH_HOME is not None and os.path.exists(PYTORCH_HOME):
    print(PYTORCH_HOME)

If we're to treat this as a --cuda_ext...

Suggested change
if PYTORCH_HOME is not None and os.path.exists(PYTORCH_HOME):
    print(PYTORCH_HOME)

Comment on lines 71 to 77
for dtype, track_running_stats, channels_last, affine in itertools.product(dtypes, (False, True), (False, True), (False, True)):
    self.dtype = dtype
    self.track_running_stats = track_running_stats
    self.channels_last = channels_last
    self.affine = affine
    self.init_modules()
    self.check_same_output()

To make failure messages more informative...

Suggested change
for dtype, track_running_stats, channels_last, affine in itertools.product(dtypes, (False, True), (False, True), (False, True)):
    self.dtype = dtype
    self.track_running_stats = track_running_stats
    self.channels_last = channels_last
    self.affine = affine
    self.init_modules()
    self.check_same_output()

for dtype, track_running_stats, channels_last, affine in itertools.product(dtypes, (False, True), (False, True), (False, True)):
    with self.subTest(dtype=dtype, track_running_stats=track_running_stats, channels_last=channels_last, affine=affine):
        self.dtype = dtype
        self.track_running_stats = track_running_stats
        self.channels_last = channels_last
        self.affine = affine
        self.init_modules()
        self.check_same_output()

see: https://docs.python.org/3/library/unittest.html#unittest.TestCase.subTest

Comment on lines 242 to 252
auto result = instance_norm_backward(_input,
                                     _grad_output,
                                     _weight,
                                     _running_mean,
                                     _running_var,
                                     _save_mean,
                                     _save_invstd,
                                     use_input_stats,
                                     _eps,
                                     {true, true, true}, // TODO: is output mask useful?
                                     channels_last);

[nit-picking] Could you lint these lines? It seems spaces and tabs have been used interchangeably...

#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>

#include <aten/src/ATen/native/utils/ParamsHash.h>

Suggested change
#include <aten/src/ATen/native/utils/ParamsHash.h>
#include <ATen/native/utils/ParamsHash.h>

at::Tensor save_invstd,
const bool use_input_stats,
const float eps,
// const std::vector<bool>& output_mask,

Is this lack of output_mask worth mentioning in instance_norm.py?

CUDAExtension('instance_norm_nvfuser_cuda',
              ['csrc/instance_norm_nvfuser.cpp', 'csrc/instance_norm_nvfuser_kernel.cu'],
              extra_compile_args={"cxx": ["-O3"] + version_dependent_macros,
                                  "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + [f"-I {PYTORCH_HOME}"])},

Suggested change
"nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + [f"-I {PYTORCH_HOME}"])},
"nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros)},

@crcrpar crcrpar modified the milestones: 22.06, 22.07 Jun 23, 2022
@crcrpar crcrpar modified the milestones: 22.07, 22.08 Jul 1, 2022
@eqy eqy force-pushed the instance_norm_nvfuser branch from ae9ceb8 to 3f06e61 Compare July 25, 2022 20:33
@crcrpar crcrpar modified the milestones: 22.08, 22.09 Aug 2, 2022
@eqy eqy force-pushed the instance_norm_nvfuser branch from 3f06e61 to 4af877f Compare September 20, 2022 20:24
@crcrpar crcrpar left a comment

sgtm

y = torch.randn(2, device=device)
pred = model(x)
loss = nn.functional.mse_loss(pred, y.float())
loss.backward()

Could you add

if __name__ == '__main__':
    unittest.main()

here so that we can run this file from the command line using unittest? (As is, pytest can already run this test.)

def test_multigpu(self):
    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()

Super nit-picking: switch to super() to use the Python 3 syntax.

Suggested change
super(Model, self).__init__()
super().__init__()

@crcrpar crcrpar mentioned this pull request Feb 9, 2023
crcrpar added a commit to NVIDIA/Fuser that referenced this pull request Jun 1, 2023
Ref:
- NVIDIA/apex#1564: Python frontend based
- NVIDIA/apex#1309: eqy's original
implementation

Signed-off-by: Masaki Kozuki <[email protected]>
Co-authored-by: Jacob Hinkle <[email protected]>
crcrpar added a commit to NVIDIA/Fuser that referenced this pull request Jun 2, 2023
Ref:
- NVIDIA/apex#1564: Python frontend based
- NVIDIA/apex#1309: eqy's original
implementation

Signed-off-by: Masaki Kozuki <[email protected]>
Co-authored-by: Jacob Hinkle <[email protected]>
crcrpar added a commit to NVIDIA/Fuser that referenced this pull request Jun 5, 2023
Ref:
- NVIDIA/apex#1564: Python frontend based
- NVIDIA/apex#1309: eqy's original
implementation

Signed-off-by: Masaki Kozuki <[email protected]>
Co-authored-by: Jacob Hinkle <[email protected]>
jacobhinkle added a commit to NVIDIA/Fuser that referenced this pull request Jun 14, 2023
now that nvfuser is a separate repository, it would be reasonable to
host some layers based off of nvfuser python API.

Ref:
- NVIDIA/apex#1564: Python frontend based
- NVIDIA/apex#1309: eqy's original
implementation

cc @ptrblck

---------

Signed-off-by: Masaki Kozuki <[email protected]>
Co-authored-by: Jacob Hinkle <[email protected]>
Co-authored-by: Jacob Hinkle <[email protected]>
Co-authored-by: jjsjann123 <[email protected]>