Skip to content

Commit 9eb4c20

Browse files
authored
Change contiguity into std::vector<c10::optional<bool>> (#2569)
1 parent 3c4b3da commit 9eb4c20

37 files changed

+319
-364
lines changed

third_party/nvfuser/benchmark/bert.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -349,13 +349,13 @@ static void setupBiasDropoutAddLayernormBwd1(Fusion* fusion, DataType dtype) {
349349
TensorView* tv3 = TensorViewBuilder()
350350
.ndims(3)
351351
.dtype(dtype)
352-
.contiguity({true, true})
352+
.contiguity({true, true, c10::nullopt})
353353
.shape({-1, -1, 1})
354354
.build();
355355
TensorView* tv4 = TensorViewBuilder()
356356
.ndims(3)
357357
.dtype(dtype)
358-
.contiguity({true, true})
358+
.contiguity({true, true, c10::nullopt})
359359
.shape({-1, -1, 1})
360360
.build();
361361

@@ -457,7 +457,7 @@ static void setupBiasDropoutAddLayernormBwd2(Fusion* fusion, DataType dtype) {
457457
TensorView* tv4 = TensorViewBuilder()
458458
.ndims(3)
459459
.dtype(dtype)
460-
.contiguity({true, true})
460+
.contiguity({true, true, c10::nullopt})
461461
.shape({-1, -1, 1})
462462
.build();
463463
TensorView* tv5 = makeContigTensor(1, dtype);

third_party/nvfuser/benchmark/layer_norm_backward.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,12 +27,12 @@ static void setupLayerNorm_BWD(Fusion* fusion, DataType dtype) {
2727
auto bias = makeContigTensor(1, dtype);
2828

2929
auto mean = TensorViewBuilder()
30-
.contiguity({false})
30+
.contiguity({false, c10::nullopt})
3131
.shape({-1, 1})
3232
.dtype(DataType::Float)
3333
.build();
3434
auto rstd = TensorViewBuilder()
35-
.contiguity({false})
35+
.contiguity({false, c10::nullopt})
3636
.shape({-1, 1})
3737
.dtype(DataType::Float)
3838
.build();

third_party/nvfuser/benchmark/rms_norm_backward.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ static void setupRMSNorm_BWD(Fusion* fusion, DataType dtype) {
2727
auto input = makeContigTensor(3, dtype);
2828
auto weight = makeContigTensor(1, dtype);
2929
auto rstd = TensorViewBuilder()
30-
.contiguity({false, false})
30+
.contiguity({false, false, c10::nullopt})
3131
.shape({-1, -1, 1})
3232
.dtype(dtype)
3333
.build();

third_party/nvfuser/benchmark/scale_bias_relu.cpp

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,19 +20,17 @@ static void setupSBR(Fusion* fusion, DataType dtype) {
2020
std::vector<int64_t> bcast_shape(kNumberOfDims, 1);
2121
bcast_shape[bcast_shape.size() - 1] = -1;
2222

23-
std::vector<bool> bcast_contig(1, true);
24-
2523
auto x = makeContigTensor(kNumberOfDims, dtype);
2624

2725
auto scale = TensorViewBuilder()
28-
.contiguity(bcast_contig)
2926
.shape(bcast_shape)
27+
.contiguity(true)
3028
.dtype(dtype)
3129
.build();
3230

3331
auto bias = TensorViewBuilder()
34-
.contiguity(bcast_contig)
3532
.shape(bcast_shape)
33+
.contiguity(true)
3634
.dtype(dtype)
3735
.build();
3836

third_party/nvfuser/benchmark/timm.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,12 +16,12 @@ static void setup_vit_base_patch16_224_bcast7(Fusion* fusion, void* null) {
1616
auto t3 = TensorViewBuilder()
1717
.shape({-1, -1, 1})
1818
.dtype(DataType::Float)
19-
.contiguity({true, true})
19+
.contiguity({true, true, c10::nullopt})
2020
.build();
2121
auto t4 = TensorViewBuilder()
2222
.shape({-1, -1, 1})
2323
.dtype(DataType::Float)
24-
.contiguity({true, true})
24+
.contiguity({true, true, c10::nullopt})
2525
.build();
2626
auto t7 = makeContigTensor(3, DataType::Half);
2727

@@ -538,14 +538,14 @@ static void setup_vit_base_patch16_224_LN_BWD(Fusion* fusion, void* null) {
538538
auto t5 = TensorViewBuilder()
539539
.shape({-1, -1, 1})
540540
.dtype(DataType::Float)
541-
.contiguity({true, true})
541+
.contiguity({true, true, c10::nullopt})
542542
.build();
543543
fusion->addInput(t5);
544544

545545
auto t6 = TensorViewBuilder()
546546
.shape({-1, -1, 1})
547547
.dtype(DataType::Float)
548-
.contiguity({true, true})
548+
.contiguity({true, true, c10::nullopt})
549549
.build();
550550
fusion->addInput(t6);
551551

third_party/nvfuser/benchmark/utils.cpp

Lines changed: 2 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -145,11 +145,7 @@ TensorView* makeSymbolicTensor(size_t ndims, DataType dtype) {
145145
}
146146

147147
TensorView* makeContigTensor(size_t ndims, DataType dtype) {
148-
return TensorViewBuilder()
149-
.ndims(ndims)
150-
.dtype(dtype)
151-
.contiguity(std::vector<bool>(ndims, true))
152-
.build();
148+
return TensorViewBuilder().ndims(ndims).dtype(dtype).contiguity(true).build();
153149
}
154150

155151
TensorView* makeConcreteTensor(std::vector<int64_t> shape, DataType dtype) {
@@ -159,18 +155,7 @@ TensorView* makeConcreteTensor(std::vector<int64_t> shape, DataType dtype) {
159155
TensorView* makeContigConcreteTensor(
160156
std::vector<int64_t> shape,
161157
DataType dtype) {
162-
std::vector<bool> contiguity;
163-
for (auto s : shape) {
164-
if (s == 1) {
165-
continue;
166-
}
167-
contiguity.push_back(true);
168-
}
169-
return TensorViewBuilder()
170-
.shape(shape)
171-
.dtype(dtype)
172-
.contiguity(contiguity)
173-
.build();
158+
return TensorViewBuilder().shape(shape).dtype(dtype).contiguity(true).build();
174159
}
175160

176161
void runBenchmarkIterations(

third_party/nvfuser/csrc/contiguity.cpp

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -386,7 +386,7 @@ NonDivisibleSplitDependencies::NonDivisibleSplitDependencies(
386386
ContigIDs::ContigIDs(
387387
const std::vector<IterDomain*>& ids,
388388
const std::vector<IterDomain*>& root_domain,
389-
const std::vector<bool>& root_contiguity,
389+
const std::vector<c10::optional<bool>>& root_contiguity,
390390
const std::unordered_set<IterDomain*>& final_ids,
391391
const std::unordered_map<IterDomain*, Val*>& index_map,
392392
const std::unordered_set<Split*>& divisible_splits,
@@ -419,7 +419,7 @@ ContigIDs::ContigIDs(
419419
ContigIDs::ContigIDs(
420420
const std::vector<IterDomain*>& ids,
421421
const std::vector<IterDomain*>& root_domain,
422-
const std::vector<bool>& root_contiguity,
422+
const std::vector<c10::optional<bool>>& root_contiguity,
423423
const std::unordered_set<IterDomain*>& final_ids,
424424
const std::unordered_map<IterDomain*, Val*>& index_map,
425425
const std::unordered_set<Split*>& divisible_splits,
@@ -458,17 +458,16 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {
458458
}
459459

460460
TORCH_INTERNAL_ASSERT(
461-
TensorDomain::noBroadcasts(root_domain_).size() ==
462-
root_contiguity_.size(),
461+
root_domain_.size() == root_contiguity_.size(),
463462
"Arguments don't match ",
464-
TensorDomain::noBroadcasts(root_domain_).size(),
463+
root_domain_.size(),
465464
" != ",
466465
root_contiguity_.size());
467466

468-
int no_broadcast_i = 0;
469467
for (const auto root_domain_i : c10::irange(root_domain_.size())) {
470468
auto root_domain_id = root_domain_.at(root_domain_i)->as<IterDomain>();
471469
if (root_domain_id->isBroadcast()) {
470+
TORCH_INTERNAL_ASSERT(!root_contiguity_.at(root_domain_i).has_value());
472471
continue;
473472
}
474473
root_to_indexed_id_[root_domain_id] = root_domain_id;
@@ -479,14 +478,13 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {
479478
// rfactor root domains, which should just return "zero"
480479
// RootAxisInfo. This should be safe as no rfactor tensor should
481480
// need halo.
482-
if (root_contiguity_.at(no_broadcast_i) &&
481+
if (*root_contiguity_.at(root_domain_i) &&
483482
!halo_info_->getRootAxisInfo(root_domain_id).hasHalo() &&
484483
root_domain_id->getIterType() != IterType::GatherScatter) {
485484
contig_ids_.emplace(root_domain_id);
486485
is_contig_root_.at(root_domain_id) = true;
487486
within_contig_ids_[root_domain_id] = std::unordered_set<IterDomain*>();
488487
}
489-
no_broadcast_i++;
490488
}
491489

492490
if (!contig_ids_.empty()) {
@@ -540,10 +538,10 @@ void ContigIDs::handle(Merge* merge) {
540538
bool is_indexing_pass = !ignore_consistent_ordering_;
541539

542540
IterDomain* last_root = nullptr;
543-
int no_broadcast_i = 0;
544541
for (auto root_id_i : c10::irange(root_domain_.size())) {
545542
auto root_id = root_domain_[root_id_i];
546543
if (root_id->isBroadcast()) {
544+
TORCH_INTERNAL_ASSERT(!root_contiguity_.at(root_id_i).has_value());
547545
continue;
548546
}
549547
if (root_ids.has(root_id)) {
@@ -556,14 +554,13 @@ void ContigIDs::handle(Merge* merge) {
556554
// If we're computing predicates (ignore_consistent_ordering_==true),
557555
// then we don't have this same constraint, we can just ignore
558556
// contiguity of the roots altogether.
559-
if (!root_contiguity_.at(no_broadcast_i) && is_indexing_pass) {
557+
if (!*root_contiguity_.at(root_id_i) && is_indexing_pass) {
560558
if (!root_ids.empty()) {
561559
return;
562560
}
563561
}
564562
last_root = root_id;
565563
}
566-
no_broadcast_i++;
567564
}
568565

569566
// If there's a non_divisible split in the history of merge->out then it can't

third_party/nvfuser/csrc/contiguity.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -157,7 +157,7 @@ class ContigIDs : public OptInDispatch {
157157
ContigIDs(
158158
const std::vector<IterDomain*>& ids,
159159
const std::vector<IterDomain*>& root_domain,
160-
const std::vector<bool>& root_contiguity,
160+
const std::vector<c10::optional<bool>>& root_contiguity,
161161
const std::unordered_set<IterDomain*>& final_ids,
162162
const std::unordered_map<IterDomain*, Val*>& index_map,
163163
const std::unordered_set<Split*>& divisible_splits,
@@ -188,7 +188,7 @@ class ContigIDs : public OptInDispatch {
188188
ContigIDs(
189189
const std::vector<IterDomain*>& ids,
190190
const std::vector<IterDomain*>& root_domain,
191-
const std::vector<bool>& root_contiguity,
191+
const std::vector<c10::optional<bool>>& root_contiguity,
192192
const std::unordered_set<IterDomain*>& final_ids,
193193
const std::unordered_map<IterDomain*, Val*>& index_map,
194194
const std::unordered_set<Split*>& divisible_splits,
@@ -264,7 +264,7 @@ class ContigIDs : public OptInDispatch {
264264
//! Root domains to analyze contiguity
265265
const std::vector<IterDomain*>& root_domain_;
266266
//! Contiguity of root_domain_
267-
const std::vector<bool>& root_contiguity_;
267+
const std::vector<c10::optional<bool>>& root_contiguity_;
268268
//! Domains where indexing/predicates cannot be done with their
269269
//! consumers domains
270270
const std::unordered_set<IterDomain*>& final_ids_;

third_party/nvfuser/csrc/executor_kernel_arg.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
#include <torch/csrc/jit/ir/ir.h>
77
#include <type.h>
88
#include <array>
9+
#include <optional>
910

1011
namespace nvfuser {
1112

third_party/nvfuser/csrc/fusion_segmenter.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -780,7 +780,7 @@ TensorView* castIntermediateValueInCompleteFusion(
780780
return IrBuilder::create<TensorView>(
781781
IrBuilder::create<TensorDomain>(
782782
new_root_domain,
783-
TensorDomain::getContiguousContiguity(new_root_domain)),
783+
TensorDomain::getContiguityFilledWith(new_root_domain, true)),
784784
data_type);
785785
};
786786

0 commit comments

Comments
 (0)