Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
191 commits
Select commit Hold shift + click to select a range
7325643
CI, to our fork. (#145) (#303)
jjsjann123 Aug 18, 2020
47f6a57
Fix for issue #306 and #296 (#307)
csarofeen Aug 20, 2020
1793533
removing WAR of contig flag for broadcasting (#301)
jjsjann123 Aug 20, 2020
f12ab01
LSTM cell C++ test (#310)
csarofeen Aug 21, 2020
4ab4110
Fix predicate generation, there was a broken root map. (#311)
csarofeen Aug 21, 2020
ce9ac6e
Reorder expressions in a breadth-first order (#312)
naoyam Aug 21, 2020
9766713
Runtime overhead reduction pr (#309)
jjsjann123 Aug 21, 2020
1bf4028
Split the origin (def) links between Fusion IR and Kernel IR
tlemo Aug 21, 2020
0420efc
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Aug 21, 2020
907782b
Kernel IR refactoring: part 6 (#314)
tlemo Aug 21, 2020
52daa86
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Aug 21, 2020
3cc7ab7
Debug env disable fma (#315)
jjsjann123 Aug 21, 2020
e40aaca
Kernel IR refactoring: part 6.1 (#316)
tlemo Aug 21, 2020
6685712
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Aug 21, 2020
ffd7ba3
Fix kir::Sync::Sync() registration (#317)
tlemo Aug 21, 2020
6f94724
Add an IRPrinter handler for kir::TensorView (#318)
naoyam Aug 22, 2020
3136899
Dynamic Shared Memory (#304)
rdspring1 Aug 24, 2020
930cfe0
Detect computeAt causing mismatched TensorDomain (#327)
naoyam Aug 26, 2020
b7a1060
Additional tests on computeAt with minor refactoring (#331)
naoyam Aug 27, 2020
0fbfa90
Fix Inner Dimension Reductions for FP16 to perform just as well as TI…
kevinstephano Aug 28, 2020
81d4647
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Aug 31, 2020
c68fba8
Change pointwise scheduling to not generate multiple unrolled loops. …
csarofeen Aug 31, 2020
4194f49
Move IterVisitor derived classes from fusion.h to iter_visitor.h (#339)
csarofeen Aug 31, 2020
339e629
Update fusion parser test, remove printing from common consumer tests…
csarofeen Aug 31, 2020
2c1060a
Cleanup of hasBlockBroadcast (#340)
csarofeen Aug 31, 2020
60f9ed3
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 1, 2020
65b6469
Minor cleanup
tlemo Sep 1, 2020
f8f5062
Kernel IR: minor cleanup (#351)
tlemo Sep 1, 2020
d7540b6
cache on static input size/stride pr_0 (#326)
jjsjann123 Sep 1, 2020
82248bb
oops, resolving auto merge issue (#354)
jjsjann123 Sep 1, 2020
ada5150
Fixing CUDA fuser ci flag (#355)
jjsjann123 Sep 1, 2020
4ec6d5a
Enable Global Intermediate Buffers (#325)
rdspring1 Sep 2, 2020
23f00e1
Stateful evaluation (#347)
csarofeen Sep 2, 2020
c522c1f
Simple executor changes (#348)
csarofeen Sep 2, 2020
5f988ab
Fix for an invalid downcast in the Expression Evaluator (#358)
tlemo Sep 3, 2020
f97d304
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 3, 2020
a375394
Minor comment
tlemo Sep 3, 2020
92875d7
Multiple output reduction (#337)
jjsjann123 Sep 3, 2020
151cdb4
Minor cleanup
tlemo Sep 4, 2020
de78fd8
Cache eviction pr (#343)
jjsjann123 Sep 8, 2020
255e52e
Factor out the code generation and kernel state
tlemo Sep 8, 2020
38929bf
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 8, 2020
737a273
clang-format
tlemo Sep 8, 2020
d21d78f
Remove a false-positive assertion. (#372)
naoyam Sep 9, 2020
5a08221
Kernel IR: part 7 (#371)
tlemo Sep 9, 2020
0d48138
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 9, 2020
2f0c751
revert .build_profile addition
tlemo Sep 9, 2020
6a60779
Experimental doxygen support (#350)
tlemo Sep 9, 2020
1e007e6
IrPrinter cleanup
tlemo Sep 9, 2020
639747d
Kernel IR: Misc cleanup (#373)
tlemo Sep 9, 2020
470a687
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 9, 2020
be28cca
Checkpoint
tlemo Sep 10, 2020
e1fde84
Checkpoint
tlemo Sep 11, 2020
0841abd
Checkpoint
tlemo Sep 11, 2020
21aa708
Checkpoint
tlemo Sep 11, 2020
8f2b240
Checkpoint
tlemo Sep 11, 2020
5a8254a
Checkpoint
tlemo Sep 11, 2020
d5be86a
Improved formatting of the generated code
tlemo Sep 14, 2020
0f330c3
Dtype for reduction (#361)
jjsjann123 Sep 14, 2020
6043ced
Checkpoint
tlemo Sep 14, 2020
342e133
Move predication inside block/gridReduction functions (#376)
naoyam Sep 14, 2020
2cec423
Fix a few small issues
tlemo Sep 14, 2020
8006ffb
Small fix
tlemo Sep 14, 2020
83028a4
Fix scheduling for fp16 reductions (#370)
csarofeen Sep 14, 2020
b97292e
Generated code formatting tweaks
tlemo Sep 14, 2020
a334546
Small fix
tlemo Sep 14, 2020
2252bc8
Update testGPU_FusionParser
tlemo Sep 14, 2020
b02ac93
fix genPrologue()
tlemo Sep 14, 2020
103f690
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 15, 2020
573cc0a
Integrate the codegen changes from PR #376
tlemo Sep 15, 2020
bdb42a7
clang-format
tlemo Sep 15, 2020
2757382
Separate reduction schedule and heuristics (#378)
jjsjann123 Sep 15, 2020
530d6eb
Tiled GEMM example (#377)
naoyam Sep 15, 2020
385fb96
Kernel IR: Splitting CUDA codegen from IrPrinter (#379)
tlemo Sep 15, 2020
03b6348
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 15, 2020
71289aa
Tweak codegen formatting for binary operators
tlemo Sep 15, 2020
4e327e6
Kernel IR: small codegen formatting improvements (#381)
tlemo Sep 15, 2020
5ca3b1d
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 15, 2020
ee6a20a
CUDA Fuser instrumentation (#324)
tlemo Sep 17, 2020
9c1d7bd
Support for multi-threaded tracing (#385)
tlemo Sep 17, 2020
1ba7eeb
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 17, 2020
944dad5
Add _syncthreads for Write-After-Read Race (#383)
rdspring1 Sep 18, 2020
d3c7ce4
Checkpoint
tlemo Sep 19, 2020
e92b9ee
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 19, 2020
1103d1e
small comment
tlemo Sep 19, 2020
1c67154
Get a GEMM example with all bells and whistles (#368)
csarofeen Sep 22, 2020
cb3eb20
Checkpoint
tlemo Sep 22, 2020
dd8d1bc
Checkpoint
tlemo Sep 22, 2020
4fbdaf4
Checkpoint
tlemo Sep 22, 2020
2a26639
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 22, 2020
3116d94
clang-format
tlemo Sep 22, 2020
48d89e4
A few small fixes
tlemo Sep 22, 2020
8c2654c
Passkey idiom
tlemo Sep 22, 2020
f758e4c
Preparation for switching the KIR ownership from Fusion to Kernel
tlemo Sep 22, 2020
9ad9bcb
Minor cleanup
tlemo Sep 22, 2020
e6e3ed0
comments
tlemo Sep 22, 2020
7ebf97d
clang-format
tlemo Sep 22, 2020
7d72e8a
Fixes to reduction heuristic usage and caching (#392)
csarofeen Sep 22, 2020
2aeef7d
Kernel IR: Introducing kir::IrBuilder (#395)
tlemo Sep 23, 2020
e370581
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 23, 2020
97a7355
Adding value based root map analysis for loop nest sharing (#393)
shmsong Sep 23, 2020
364882b
Merge remote-tracking branch 'origin/20_8_18_devel' into kernel_ir
tlemo Sep 24, 2020
48a7f21
Fix CudaKernelGenerator::handle(const kir::Scope&) overload
tlemo Sep 24, 2020
9093ffa
apply repo changes for github
jjsjann123 Sep 25, 2020
8076183
Merge branch '20_9_25_devel' into kernel_ir
tlemo Sep 28, 2020
74e32ae
Fix issue 399 (#401)
tlemo Sep 28, 2020
6b44a05
Merge remote-tracking branch 'origin/20_9_25_devel' into kernel_ir
tlemo Sep 28, 2020
69748a7
WIP: move kernel IR nodes to Kernel
tlemo Sep 28, 2020
1a07658
simplify Val interface
tlemo Sep 28, 2020
67917a2
Fix transform replay (#404)
naoyam Sep 30, 2020
34edf10
WIP Checkpoint
tlemo Sep 30, 2020
ea73843
WIP Checkpoint
tlemo Oct 1, 2020
9d139c1
WIP Checkpoint
tlemo Oct 1, 2020
b546fc8
WIP Checkpoint
tlemo Oct 1, 2020
4ee4160
WIP Checkpoint
tlemo Oct 1, 2020
4b81d52
PE update from BailOut to CudaFusionGuard (#398)
jjsjann123 Oct 1, 2020
a11d393
WIP Checkpoint
tlemo Oct 1, 2020
4a39f82
WIP Checkpoint
tlemo Oct 1, 2020
d5bd0bf
WIP Checkpoint
tlemo Oct 2, 2020
a750cde
WIP Checkpoint
tlemo Oct 5, 2020
dcd334a
WIP Checkpoint
tlemo Oct 6, 2020
bea0c04
Consolodate namespaces to torch::jit::fuser::cuda (#407)
csarofeen Oct 6, 2020
17d361f
WIP Checkpoint
tlemo Oct 6, 2020
85e4d95
WIP Checkpoint
tlemo Oct 6, 2020
0722db4
Updating the debugCompileFusionFromStr() method of the Fusion Executo…
kevinstephano Oct 6, 2020
3cb62e1
WIP Checkpoint
tlemo Oct 6, 2020
779968e
WIP Checkpoint
tlemo Oct 6, 2020
aabe616
WIP: Checkpoint
tlemo Oct 7, 2020
bbe7c3f
Merge remote-tracking branch 'origin/20_9_25_devel' into kernel_ir
tlemo Oct 7, 2020
519795e
Fix merge issues
tlemo Oct 8, 2020
3ffc226
formatting fix
tlemo Oct 8, 2020
5647919
refactor kir::ExpressionEvaluator
tlemo Oct 8, 2020
145d362
Value definition link
tlemo Oct 8, 2020
690d3bc
fixing a few small issues
tlemo Oct 8, 2020
2ca2983
codegeneration fixes
tlemo Oct 8, 2020
d5398fe
minor cleanup
tlemo Oct 8, 2020
d624f19
minor refactor
tlemo Oct 8, 2020
ea5748d
WIP: initReduce()
tlemo Oct 8, 2020
b14645d
WIP - reductions mostly work
tlemo Oct 9, 2020
f1520ca
WIP checkpoint
tlemo Oct 9, 2020
00bd836
moved predicates to kir::Expr
tlemo Oct 9, 2020
a6471f7
WIP Checkpoint
tlemo Oct 9, 2020
f97944c
Kernel IR printer (#415)
tlemo Oct 12, 2020
2eefd84
Small kir::IrPrinter improvements (#417)
tlemo Oct 12, 2020
0f5b1b1
Merge remote-tracking branch 'origin/20_9_25_devel' into kernel_ir
tlemo Oct 13, 2020
805784c
Merge in the new kir::IrPrinter
tlemo Oct 13, 2020
09475c8
WIP checkpoint
tlemo Oct 13, 2020
c27abdf
WIP checkpoint
tlemo Oct 13, 2020
1fc46e9
Persistent Kernel Examples + Improvements (#402)
rdspring1 Oct 13, 2020
cc78bc9
Make separate tests as separate test functions (#419)
naoyam Oct 13, 2020
6a674b9
WIP checkpoint
tlemo Oct 14, 2020
6322caf
Replace pragma with proper using statements (#420)
naoyam Oct 14, 2020
4243378
Bug fix (#422)
naoyam Oct 14, 2020
7984c02
WIP checkpoint
tlemo Oct 14, 2020
91fc942
WIP checkpoint
tlemo Oct 14, 2020
b80f0e2
WIP checkpoint
tlemo Oct 14, 2020
414b9e1
WIP checkpoint
tlemo Oct 14, 2020
182358c
minor fix
tlemo Oct 14, 2020
51d4b19
minor cleanup
tlemo Oct 14, 2020
44def35
WIP checkpoint
tlemo Oct 14, 2020
931e5b9
minor fixes
tlemo Oct 14, 2020
807fa9f
minor fix
tlemo Oct 15, 2020
8ed9418
reenable UnrollPass
tlemo Oct 15, 2020
f68c526
ThreadPredicateMap::print()
tlemo Oct 15, 2020
6dc9230
minor cleanup
tlemo Oct 15, 2020
ee90d8b
minor cleanup
tlemo Oct 15, 2020
e78b6e7
fix predication
tlemo Oct 15, 2020
73f2ffb
Fix a dangerous typo
tlemo Oct 15, 2020
258e21d
small fix
tlemo Oct 15, 2020
bc78fd5
ExpressionEvaluator::isConst
tlemo Oct 16, 2020
7bee611
Small fixes
tlemo Oct 16, 2020
36bad49
Fix bindKernelInputs
tlemo Oct 16, 2020
f1c1900
Update FusionParser_CUDA
tlemo Oct 16, 2020
b6eb6ad
Temporary fixes
tlemo Oct 16, 2020
3929002
Revert ThreadPredicateMap::print()
tlemo Oct 16, 2020
7966c0a
Small cleanup & comments
tlemo Oct 16, 2020
44bcbca
Merge remote-tracking branch 'origin/20_9_25_devel' into kernel_ir
tlemo Oct 16, 2020
6cb9639
merge lower_alias_memory
tlemo Oct 16, 2020
5cc4e12
kir::IrPrinter support for Allocate::alias()
tlemo Oct 16, 2020
de08d16
minor cleanup
tlemo Oct 16, 2020
15ff731
WIP checkpoint
tlemo Oct 16, 2020
9c1c6a3
small cleanup
tlemo Oct 16, 2020
9dc3980
clang-format
tlemo Oct 16, 2020
c553a8c
small fixes
tlemo Oct 16, 2020
4823173
Incorporating review feedback
tlemo Oct 20, 2020
f2e6bfa
small fix
tlemo Oct 20, 2020
2382b7b
clang-format
tlemo Oct 20, 2020
ee9a046
Merge remote-tracking branch 'origin/20_10_20_devel' into exp_develop
tlemo Oct 21, 2020
91904dc
Sync up with 20_10_20_devel
tlemo Oct 21, 2020
898a430
Merge branch 'exp_develop' into kernel_ir_part9
tlemo Oct 21, 2020
6af5fde
Please clang-tidy
tlemo Oct 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/iter_visitor.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_cache.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_ir.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_ir_builder.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_ir_printer.cpp
Expand Down
52 changes: 24 additions & 28 deletions test/cpp/jit/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1126,25 +1126,25 @@ TEST(NVFuserTest, FusionParser_CUDA) {
__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3) {
float T2[1];
if ((((((blockIdx.x * 1) + (1 - 1)) * 128) + threadIdx.x) < T0.size[0])) {
for(size_t i6 = 0; i6 < 1; ++i6) {
T2[i6]
= T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
* T1[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
T3[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
= T2[i6]
* T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
for(size_t ki25 = 0; ki25 < 1; ++ki25) {
T2[ki25]
= T0[((((blockIdx.x * 1) + ki25) * 128) + threadIdx.x)]
* T1[((((blockIdx.x * 1) + ki25) * 128) + threadIdx.x)];
T3[((((blockIdx.x * 1) + ki25) * 128) + threadIdx.x)]
= T2[ki25]
* T0[((((blockIdx.x * 1) + ki25) * 128) + threadIdx.x)];
}
} else {
for(size_t i6 = 0; i6 < 1; ++i6) {
if ((((((blockIdx.x * 1) + i6) * 128) + threadIdx.x) < T0.size[0])) {
T2[i6]
= T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
* T1[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
for(size_t ki25 = 0; ki25 < 1; ++ki25) {
if ((((((blockIdx.x * 1) + ki25) * 128) + threadIdx.x) < T0.size[0])) {
T2[ki25]
= T0[((((blockIdx.x * 1) + ki25) * 128) + threadIdx.x)]
* T1[((((blockIdx.x * 1) + ki25) * 128) + threadIdx.x)];
}
if ((((((blockIdx.x * 1) + i6) * 128) + threadIdx.x) < T0.size[0])) {
T3[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
= T2[i6]
* T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
if ((((((blockIdx.x * 1) + ki25) * 128) + threadIdx.x) < T0.size[0])) {
T3[((((blockIdx.x * 1) + ki25) * 128) + threadIdx.x)]
= T2[ki25]
* T0[((((blockIdx.x * 1) + ki25) * 128) + threadIdx.x)];
}
}
}
Expand Down Expand Up @@ -5700,7 +5700,7 @@ TEST(NVFuserTest, FusionSmem_CUDA) {
aten_output.allclose(outputs[0], 1e-5, 1e-5),
"Error of: ",
aten_output.sub(outputs[0]).abs().max());
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
}

TEST(NVFuserTest, FusionSmemReduce_CUDA) {
Expand Down Expand Up @@ -5750,8 +5750,7 @@ TEST(NVFuserTest, FusionSmemReduce_CUDA) {
aten_output.allclose(outputs[0], 1e-5, 1e-5),
"Error of: ",
aten_output.sub(outputs[0]).abs().max());
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(24) == 1);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
}

TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) {
Expand Down Expand Up @@ -5814,7 +5813,7 @@ TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) {
aten_output.allclose(outputs[0], 1e-5, 1e-5),
"Error of: ",
aten_output.sub(outputs[0]).abs().max());
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
}

TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) {
Expand Down Expand Up @@ -5900,7 +5899,7 @@ TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) {
aten_output.allclose(outputs[0], 1e-5, 1e-5),
"Error of: ",
aten_output.sub(outputs[0]).abs().max());
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
}

TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) {
Expand Down Expand Up @@ -6413,7 +6412,7 @@ TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) {
aten_output.allclose(outputs[0], 1e-5, 1e-5),
"Error of: ",
aten_output.sub(outputs[0]).abs().max());
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
}

TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) {
Expand Down Expand Up @@ -6471,8 +6470,7 @@ TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) {
aten_output.allclose(outputs[0], 1e-5, 1e-5),
"Error of: ",
aten_output.sub(outputs[0]).abs().max());
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(24) == 1);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
}

TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) {
Expand Down Expand Up @@ -6529,8 +6527,7 @@ TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) {
aten_output.allclose(outputs[0], 1e-5, 1e-5),
"Error of: ",
aten_output.sub(outputs[0]).abs().max());
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(22) == 1);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
}

TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) {
Expand Down Expand Up @@ -6655,8 +6652,7 @@ TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) {
aten_C.allclose(C_fuser, 1e-5, 1e-5),
"Error of: ",
aten_C.sub(C_fuser).abs().max());
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(41) == 1);
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
}

TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) {
Expand Down
1 change: 1 addition & 0 deletions tools/build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ libtorch_cuda_sources = [
"torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
"torch/csrc/jit/codegen/cuda/kernel.cpp",
"torch/csrc/jit/codegen/cuda/kernel_cache.cpp",
"torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp",
"torch/csrc/jit/codegen/cuda/kernel_ir.cpp",
"torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp",
"torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp",
Expand Down
Loading