@@ -386,7 +386,7 @@ NonDivisibleSplitDependencies::NonDivisibleSplitDependencies(
386
386
ContigIDs::ContigIDs (
387
387
const std::vector<IterDomain*>& ids,
388
388
const std::vector<IterDomain*>& root_domain,
389
- const std::vector<bool >& root_contiguity,
389
+ const std::vector<c10::optional< bool > >& root_contiguity,
390
390
const std::unordered_set<IterDomain*>& final_ids,
391
391
const std::unordered_map<IterDomain*, Val*>& index_map,
392
392
const std::unordered_set<Split*>& divisible_splits,
@@ -419,7 +419,7 @@ ContigIDs::ContigIDs(
419
419
ContigIDs::ContigIDs (
420
420
const std::vector<IterDomain*>& ids,
421
421
const std::vector<IterDomain*>& root_domain,
422
- const std::vector<bool >& root_contiguity,
422
+ const std::vector<c10::optional< bool > >& root_contiguity,
423
423
const std::unordered_set<IterDomain*>& final_ids,
424
424
const std::unordered_map<IterDomain*, Val*>& index_map,
425
425
const std::unordered_set<Split*>& divisible_splits,
@@ -458,17 +458,16 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {
458
458
}
459
459
460
460
TORCH_INTERNAL_ASSERT (
461
- TensorDomain::noBroadcasts (root_domain_).size () ==
462
- root_contiguity_.size (),
461
+ root_domain_.size () == root_contiguity_.size (),
463
462
" Arguments don't match " ,
464
- TensorDomain::noBroadcasts ( root_domain_) .size (),
463
+ root_domain_.size (),
465
464
" != " ,
466
465
root_contiguity_.size ());
467
466
468
- int no_broadcast_i = 0 ;
469
467
for (const auto root_domain_i : c10::irange (root_domain_.size ())) {
470
468
auto root_domain_id = root_domain_.at (root_domain_i)->as <IterDomain>();
471
469
if (root_domain_id->isBroadcast ()) {
470
+ TORCH_INTERNAL_ASSERT (!root_contiguity_.at (root_domain_i).has_value ());
472
471
continue ;
473
472
}
474
473
root_to_indexed_id_[root_domain_id] = root_domain_id;
@@ -479,14 +478,13 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {
479
478
// rfactor root domains, which should just return "zero"
480
479
// RootAxisInfo. This should be safe as no rfactor tensor should
481
480
// need halo.
482
- if (root_contiguity_.at (no_broadcast_i ) &&
481
+ if (* root_contiguity_.at (root_domain_i ) &&
483
482
!halo_info_->getRootAxisInfo (root_domain_id).hasHalo () &&
484
483
root_domain_id->getIterType () != IterType::GatherScatter) {
485
484
contig_ids_.emplace (root_domain_id);
486
485
is_contig_root_.at (root_domain_id) = true ;
487
486
within_contig_ids_[root_domain_id] = std::unordered_set<IterDomain*>();
488
487
}
489
- no_broadcast_i++;
490
488
}
491
489
492
490
if (!contig_ids_.empty ()) {
@@ -540,10 +538,10 @@ void ContigIDs::handle(Merge* merge) {
540
538
bool is_indexing_pass = !ignore_consistent_ordering_;
541
539
542
540
IterDomain* last_root = nullptr ;
543
- int no_broadcast_i = 0 ;
544
541
for (auto root_id_i : c10::irange (root_domain_.size ())) {
545
542
auto root_id = root_domain_[root_id_i];
546
543
if (root_id->isBroadcast ()) {
544
+ TORCH_INTERNAL_ASSERT (!root_contiguity_.at (root_id_i).has_value ());
547
545
continue ;
548
546
}
549
547
if (root_ids.has (root_id)) {
@@ -556,14 +554,13 @@ void ContigIDs::handle(Merge* merge) {
556
554
// If we're computing predicates (ignore_consistent_ordering_==true),
557
555
// then we don't have this same constraint, we can just ignore
558
556
// contiguity of the roots all together.
559
- if (!root_contiguity_.at (no_broadcast_i ) && is_indexing_pass) {
557
+ if (!* root_contiguity_.at (root_id_i ) && is_indexing_pass) {
560
558
if (!root_ids.empty ()) {
561
559
return ;
562
560
}
563
561
}
564
562
last_root = root_id;
565
563
}
566
- no_broadcast_i++;
567
564
}
568
565
569
566
// If there's a non_divisible split in the history of merge->out then it can't
0 commit comments