Test rand in a fusion with zero tensor input #1932

Merged
merged 150 commits into from
Aug 27, 2022
ef1360e
Added api for query kernel argument
jjsjann123 Jul 31, 2022
d4acec1
updating kernel arg
jjsjann123 Aug 1, 2022
d7b6d16
adding size() in kernel arg
jjsjann123 Aug 1, 2022
d3e1707
adding getPointer in TensorArgAbstract
jjsjann123 Aug 1, 2022
1e4f9f5
SchedulerRuntimeInfo to take KernelArgumentHolder
jjsjann123 Aug 1, 2022
9242048
fixing build issues
jjsjann123 Aug 1, 2022
4c01ba7
missing ;
jjsjann123 Aug 1, 2022
b094b5d
fixing more compiler errors
jjsjann123 Aug 1, 2022
e8ea2e4
remove unwanted changes; adding const qualifier to arg()
jjsjann123 Aug 1, 2022
49fdffe
fixing return for const func arg()
jjsjann123 Aug 1, 2022
8e51482
overloading arg() for constness correctness
jjsjann123 Aug 1, 2022
91359bd
multiple definition error
jjsjann123 Aug 1, 2022
d4ebb04
exposing kernel argument holder for test access
jjsjann123 Aug 1, 2022
21a8323
moving index mode computation to KernelArgumentHolder
jjsjann123 Aug 1, 2022
e962fa6
fixing name for function; adding const for getIndexMode()
jjsjann123 Aug 1, 2022
7412fd1
fixing const for static function
jjsjann123 Aug 1, 2022
2fc2606
switching scheduling from ArrayRef<at::Tensor> to KernelArgumentHolder
jjsjann123 Aug 2, 2022
0d2d1a0
fixing compile issues
jjsjann123 Aug 2, 2022
b4b3c6e
fixing more renames
jjsjann123 Aug 2, 2022
9135a52
Converting PrecomputedIntegersBase::FusionPrecomputedIntegers to use
jjsjann123 Aug 2, 2022
850ca93
fixing more KernelArgumentHolder in tests
jjsjann123 Aug 2, 2022
1b9a6b5
refactor device_index & index_mode to KernelArgumentHolder
jjsjann123 Aug 2, 2022
d3674f5
missing type
jjsjann123 Aug 2, 2022
5d40c18
switch KernelArgumentHolder creation in benchmakrs
jjsjann123 Aug 2, 2022
d99dde5
fixing return types
jjsjann123 Aug 2, 2022
f7ceabd
removing duplicated definition
jjsjann123 Aug 2, 2022
f4740e1
fixing pushing inputs to KernelArgumentHolder
jjsjann123 Aug 2, 2022
a1144a0
fixing args construction with permuted tensor
jjsjann123 Aug 2, 2022
0f899ab
updating push to duplicate entries in KernelArgumentHolder
jjsjann123 Aug 3, 2022
e1d2a9a
fixing string name in macro
jjsjann123 Aug 3, 2022
fa0525f
fixing return type in macro
jjsjann123 Aug 3, 2022
320bea9
fixing make_unique with derived
jjsjann123 Aug 3, 2022
5b01066
removing entries in RuntimeWorkSpace
jjsjann123 Aug 3, 2022
d3e9984
prototyping isCompiled query
jjsjann123 Aug 4, 2022
156dce7
fixing build issue
jjsjann123 Aug 4, 2022
5800c01
fixing typo
jjsjann123 Aug 4, 2022
68d534e
missing braces
jjsjann123 Aug 4, 2022
64ad30a
Merge remote-tracking branch 'origin/devel' into jiej_wip
jjsjann123 Aug 5, 2022
9c6ce2f
fixing permutation in the old code path
jjsjann123 Aug 5, 2022
9436104
probably broken: initial commit for compilation API
jjsjann123 Aug 6, 2022
60facbd
fixing compiler errors
jjsjann123 Aug 6, 2022
b74bc68
fixing build
jjsjann123 Aug 6, 2022
7dabbf4
fixing build
jjsjann123 Aug 6, 2022
8b0cc3c
fixing build
jjsjann123 Aug 6, 2022
3173c19
fixing build
jjsjann123 Aug 6, 2022
54477ce
fixing build
jjsjann123 Aug 6, 2022
0e34b78
fixing build
jjsjann123 Aug 6, 2022
e327773
debug print
jjsjann123 Aug 7, 2022
7adccec
debug print
jjsjann123 Aug 7, 2022
3673e38
updating KernelArgumentHolder update for segmented fusion
jjsjann123 Aug 7, 2022
4e84338
updating KernelArgumentHolder debugPrint
jjsjann123 Aug 7, 2022
c20422e
updating KernelArgumentHolder debugPrint const iterator
jjsjann123 Aug 7, 2022
52d60dd
updating KernelArgumentHolder debugPrint accessor
jjsjann123 Aug 7, 2022
e3316df
updating KernelArgumentHolder debugPrint rename print
jjsjann123 Aug 7, 2022
d44cc06
debug test
jjsjann123 Aug 7, 2022
1f45c09
shrink test size
jjsjann123 Aug 8, 2022
4f15d51
with new AIP
jjsjann123 Aug 8, 2022
fba2c72
add compute launch param in infer outputs
jjsjann123 Aug 8, 2022
ea51ad9
fixing header
jjsjann123 Aug 8, 2022
5e2e159
fixing empty function
jjsjann123 Aug 8, 2022
824dfd9
fixing empty function
jjsjann123 Aug 8, 2022
30bf361
removing dumb assert
jjsjann123 Aug 8, 2022
25f4ada
async compilation in cpp example
jjsjann123 Aug 10, 2022
d2838dc
capture this in lambda
jjsjann123 Aug 10, 2022
ea8f51d
fix mutex double lock
jjsjann123 Aug 10, 2022
401a64b
debug print
jjsjann123 Aug 10, 2022
98730c0
typo
jjsjann123 Aug 10, 2022
146e51f
debug print
jjsjann123 Aug 11, 2022
cd246d5
try to move captures in generalized lambda
jjsjann123 Aug 11, 2022
bad15e8
fixing move
jjsjann123 Aug 11, 2022
cd4b689
fixing lambda capture again
jjsjann123 Aug 11, 2022
56783aa
init list order warning/error
jjsjann123 Aug 11, 2022
243582c
compileAsync by copy
jjsjann123 Aug 11, 2022
0de1e40
compileAsync by copy
jjsjann123 Aug 11, 2022
d0ffec9
mutable lambda
jjsjann123 Aug 11, 2022
5bc9883
removing lambda capture of unique_lock
jjsjann123 Aug 11, 2022
bc7db01
lambda capture of unique_lock
jjsjann123 Aug 11, 2022
0138b76
lambda capture of unique_lock
jjsjann123 Aug 11, 2022
30e750e
fixing mutex release
jjsjann123 Aug 11, 2022
50f8799
remove print; add back while loop query compilation
jjsjann123 Aug 12, 2022
15233c2
update
jjsjann123 Aug 12, 2022
4092de8
update
jjsjann123 Aug 12, 2022
72fe6e0
remove debug print
jjsjann123 Aug 12, 2022
0ce0750
remove debug print
jjsjann123 Aug 12, 2022
effc6a0
further refactor use of pytorch values to KernelArgumentHolder in run…
jjsjann123 Aug 15, 2022
ee6fc44
fix removing type
jjsjann123 Aug 15, 2022
367bcdc
fixing various build issue
jjsjann123 Aug 15, 2022
dd177c5
more build issues
jjsjann123 Aug 15, 2022
687f294
build issues
jjsjann123 Aug 15, 2022
3c720b3
build issues
jjsjann123 Aug 15, 2022
11afe76
build issues
jjsjann123 Aug 15, 2022
b4aa6b6
build issues
jjsjann123 Aug 15, 2022
307bd15
build issues
jjsjann123 Aug 15, 2022
d9df56a
fix test for updated API
jjsjann123 Aug 15, 2022
ba5ee82
fix test for updated API
jjsjann123 Aug 15, 2022
6a663be
debug print
jjsjann123 Aug 15, 2022
7db08f4
debug print
jjsjann123 Aug 15, 2022
774a24b
fixing output puhs in runFusion
jjsjann123 Aug 15, 2022
7505db0
fixing output wiring in FusionKernelRuntime::runWithInput
jjsjann123 Aug 15, 2022
5f40eae
fixing output wiring in FusionKernelRuntime::runWithInput
jjsjann123 Aug 15, 2022
7cc751b
debug print
jjsjann123 Aug 15, 2022
16afaaa
remove empty tensor allocation for alias
jjsjann123 Aug 15, 2022
644e5a2
few more prints
jjsjann123 Aug 15, 2022
bb18cfc
try to fix
jjsjann123 Aug 15, 2022
6e4c760
try to fix
jjsjann123 Aug 15, 2022
5d7998d
update default device on kernel argument holder
jjsjann123 Aug 16, 2022
571c3d6
debug print for device
jjsjann123 Aug 16, 2022
b6849b3
default device/index for empty inputs
jjsjann123 Aug 16, 2022
5076aee
fix return of kernel argument holder
jjsjann123 Aug 16, 2022
b49db18
fix runtime output mapping with aliased outputs
jjsjann123 Aug 16, 2022
d94a1d8
fix typo
jjsjann123 Aug 16, 2022
0a9860b
refactor to keep aliased outputs in runFusion and added a reference o…
jjsjann123 Aug 17, 2022
c074324
update executor to include aliased outputs
jjsjann123 Aug 17, 2022
c8f8657
fixing build
jjsjann123 Aug 17, 2022
cd34a2e
enabling vectorized check
jjsjann123 Aug 17, 2022
ef25798
fixing build
jjsjann123 Aug 17, 2022
5fe3910
fixing build
jjsjann123 Aug 17, 2022
1a1cfd8
fixing header signature
jjsjann123 Aug 17, 2022
b07feeb
fixing tensor type check for binding fusion/kernel inputs
jjsjann123 Aug 17, 2022
a115456
fixing tensor type check and skip binding for cpu tensor
jjsjann123 Aug 17, 2022
b8be76b
remove prints
jjsjann123 Aug 17, 2022
6b97539
Merge remote-tracking branch 'origin/devel' into jiej_wip
jjsjann123 Aug 17, 2022
a47e5a1
fixing conflict resolution
jjsjann123 Aug 17, 2022
26eab71
fixing conflict resolution attempt 2
jjsjann123 Aug 17, 2022
93bde13
code cleaning
jjsjann123 Aug 17, 2022
f66a76f
clang-format
jjsjann123 Aug 17, 2022
8343073
removing prints
jjsjann123 Aug 17, 2022
073b22a
adding validate kernel inputs
jjsjann123 Aug 19, 2022
ba6ac91
fixing build
jjsjann123 Aug 19, 2022
8e32779
fixing build
jjsjann123 Aug 19, 2022
f2a925a
Merge remote-tracking branch 'origin/devel' into async_compilation
jjsjann123 Aug 19, 2022
4596bdd
code cleaning with comments
jjsjann123 Aug 20, 2022
adbdef3
build issues
jjsjann123 Aug 20, 2022
524c7e6
avoid async failure on input check
jjsjann123 Aug 20, 2022
92c2265
rename tests
jjsjann123 Aug 20, 2022
8b591ea
Merge remote-tracking branch 'origin/devel' into HEAD
jjsjann123 Aug 23, 2022
f3bc13a
address review comments
jjsjann123 Aug 24, 2022
9710da7
fixing build
jjsjann123 Aug 24, 2022
a7c1bf5
fixing build
jjsjann123 Aug 24, 2022
ad0ccbf
typo
jjsjann123 Aug 24, 2022
08f7d3b
Merge remote-tracking branch 'origin/devel' into HEAD
jjsjann123 Aug 24, 2022
309d12c
clangformat
jjsjann123 Aug 24, 2022
9f8bc95
more clean up from review comments
jjsjann123 Aug 26, 2022
26fafe6
lintrunner
jjsjann123 Aug 26, 2022
58f1daf
fixing build
jjsjann123 Aug 26, 2022
8bc1110
Add support for zero-input fusions
zasdfgbnm Aug 26, 2022
6d09cf7
Merge branch 'async_compilation' of github.com:csarofeen/pytorch into…
zasdfgbnm Aug 26, 2022
e8e9fe2
lintrunner
jjsjann123 Aug 26, 2022
1a01ea2
Merge branch 'async_compilation' of github.com:csarofeen/pytorch into…
zasdfgbnm Aug 26, 2022
37986c8
Merge branch 'devel' of github.com:csarofeen/pytorch into standalone-…
zasdfgbnm Aug 26, 2022
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/arith.h
@@ -123,7 +123,7 @@ TORCH_CUDA_CU_API WelfordResult Welford(

 // TENSOR FACTORIES
 TORCH_CUDA_CU_API TensorView* rand(
-    const std::vector<Int*>& shape,
+    const std::vector<Val*>& shape,
     DataType dtype);

 // UNARY OPERATIONS
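The signature change above widens the element type of `shape` from `Int*` to `Val*`. The following is a minimal standalone sketch of what that permits at a call site. It does not use the real nvFuser classes: `Val`, `Int`, and `NamedScalar` here are hypothetical stand-ins that only mimic a base/derived relationship, and `countDims` is an illustrative helper that is not part of the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for an IR value hierarchy (assumption: the
// derived types inherit publicly from Val; these are NOT the real classes).
struct Val {
  virtual ~Val() = default;
};
struct Int : Val {};
struct NamedScalar : Val {};

// Illustrative helper that takes the widened parameter type used by the
// new rand() declaration. It only reports how many extents were passed.
std::size_t countDims(const std::vector<Val*>& shape) {
  return shape.size();
}

// A braced list of derived pointers converts element-wise to
// std::vector<Val*>, so a call shaped like rand({size_val}, dtype)
// still compiles when size_val is an Int*.
std::size_t oneIntExtent(Int* size_val) {
  return countDims({size_val});
}

// With Val* elements, extents of different derived types can be mixed,
// which a std::vector<Int*> parameter would reject at compile time.
std::size_t mixedExtents(Int* a, NamedScalar* b) {
  return countDims({a, b});
}
```

Note that an existing `std::vector<Int*>` object does not convert implicitly to `std::vector<Val*>`, so callers in this sketch would need a braced list or an explicit element-wise copy.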
15 changes: 6 additions & 9 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
@@ -112,24 +112,21 @@ TEST_F(NVFuserTest, FusionRNGValidateWithCURand_CUDA) {
       auto fusion = fusion_ptr.get();
       FusionGuard fg(fusion);

-      TensorView* tv0 = makeSymbolicTensor(1, aten_to_data_type(dtype));

Review comment (author): Modified this test to use the nullary rand

-      fusion->addInput(tv0);
-      auto tv1 = randlike(tv0);
-      fusion->addOutput(tv1);
+      Int* size_val = IrBuilder::create<Int>();
+      fusion->addInput(size_val);
+      TensorView* tv0 = rand({size_val}, aten_to_data_type(dtype));
+      fusion->addOutput(tv0);

       FusionExecutorCache fec(std::move(fusion_ptr));

-      auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
-      at::Tensor t0 = at::zeros({size}, options);
-
       at::manual_seed(0);
-      auto cg_outputs = fec.runFusionWithInputs({t0});
+      auto cg_outputs = fec.runFusionWithInputs({size});
       auto out = cg_outputs[0];

       at::manual_seed(0);
       auto ref = generate_uniform(size, dtype);

-      testValidate(fec.fusion(), {out}, {t0}, {ref}, __LINE__, __FILE__);
+      testValidate(fec.fusion(), {out}, {size}, {ref}, __LINE__, __FILE__);
     }
   }
 }