Skip to content

Commit 1d9858c

Browse files
authored
Minor cleanup (#1992)
1 parent f262d9c commit 1d9858c

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

torch/csrc/jit/codegen/cuda/lower_shift.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,20 @@ void HaloInfo::build(TensorDomain* td) {
440440
} else {
441441
setHaloWidth(merge->out(), 0);
442442
}
443-
} else if (expr->getExprType().value() == ExprType::Swizzle2D) {
443+
} else if (auto swizzle = dynamic_cast<Swizzle2D*>(expr)) {
444444
// Assume no halo on swizzled domain for now.
445+
TORCH_INTERNAL_ASSERT(
446+
getExtent(swizzle->inX()) == nullptr,
447+
"Halo is not supported with swizzle. Halo-extended ID: ",
448+
swizzle->inX()->toString(),
449+
" used in ",
450+
swizzle->toString());
451+
TORCH_INTERNAL_ASSERT(
452+
getExtent(swizzle->inY()) == nullptr,
453+
"Halo is not supported with swizzle. Halo-extended ID: ",
454+
swizzle->inY()->toString(),
455+
" used in ",
456+
swizzle->toString());
445457
for (auto id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
446458
setHaloWidth(id, 0);
447459
}

torch/csrc/jit/codegen/cuda/transform_iter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ void ReplayTransformations::handle(Swizzle2D* swizzle_2d) {
137137
auto id_in_y = swizzle_2d->inY();
138138

139139
// Make sure we have a corresponding entry in our map pointing to the ID we're
140-
// going to replay the split on
140+
// going to replay the swizzle on
141141
auto it_x = id_map_.find(id_in_x);
142142
auto it_y = id_map_.find(id_in_y);
143143

@@ -162,7 +162,7 @@ void ReplayTransformations::handle(Swizzle2D* swizzle_2d) {
162162
auto outs = std::make_pair(mapped_x, mapped_y);
163163

164164
if (replay_swizzle_) {
165-
// Replay the split onto mapped
165+
// Replay the swizzle onto mapped
166166
outs = IterDomain::swizzle(swizzle_2d->swizzleType(), mapped_x, mapped_y);
167167

168168
// Remove mapped from the leaf IDs

0 commit comments

Comments
 (0)