Skip to content

Commit 292ebef

Browse files
authored
Some misc swizzle changes (#2138)
1 parent 19e5af7 commit 292ebef

15 files changed

+815
-383
lines changed

build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,7 @@ libtorch_cuda_core_sources = [
751751
"torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
752752
"torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",
753753
"torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp",
754+
"torch/csrc/jit/codegen/cuda/swizzle.cpp",
754755
"torch/csrc/jit/codegen/cuda/type_inference.cpp",
755756
"torch/csrc/jit/codegen/cuda/type_promotion.cpp",
756757
"torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp",

test/cpp/jit/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ if(USE_CUDA)
102102
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu1.cpp)
103103
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu2.cpp)
104104
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu3.cpp)
105+
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_swizzle.cpp)
105106
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp)
106107
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp)
107108
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp)

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,7 @@ NVFUSER_DEFINE_BINARY_FLOAT_OP(atan2, Atan2)
948948
}
949949

950950
// Integer binary ops
951+
NVFUSER_DEFINE_BINARY_CAST_OP(cpp_div, Div)
951952
NVFUSER_DEFINE_BINARY_CAST_OP(mod, Mod)
952953
NVFUSER_DEFINE_BINARY_CAST_OP(ceilDiv, CeilDiv)
953954
NVFUSER_DEFINE_BINARY_CAST_OP(add, Add)

torch/csrc/jit/codegen/cuda/arith.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,11 @@ TORCH_CUDA_CU_API Val* div(Val* v1, Val* v2);
350350
TORCH_CUDA_CU_API TensorView* div(TensorView* v1, Val* v2);
351351
TORCH_CUDA_CU_API TensorView* div(Val* v1, TensorView* v2);
352352
TORCH_CUDA_CU_API TensorView* div(TensorView* v1, TensorView* v2);
353+
// cpp_div: similar to div, but don't promote to float
354+
TORCH_CUDA_CU_API Val* cpp_div(Val* v1, Val* v2);
355+
TORCH_CUDA_CU_API TensorView* cpp_div(TensorView* v1, Val* v2);
356+
TORCH_CUDA_CU_API TensorView* cpp_div(Val* v1, TensorView* v2);
357+
TORCH_CUDA_CU_API TensorView* cpp_div(TensorView* v1, TensorView* v2);
353358
// fmod
354359
TORCH_CUDA_CU_API Val* fmod(Val* v1, Val* v2);
355360
TORCH_CUDA_CU_API TensorView* fmod(TensorView* v1, Val* v2);

torch/csrc/jit/codegen/cuda/dynamic_type.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,18 @@ class TORCH_CUDA_CU_API IntOrDouble {
132132
}
133133
TORCH_INTERNAL_ASSERT(false);
134134
}
135+
// Bitwise xor of two dynamic scalars.
// Only defined when both operands currently hold an int64_t; xor has no
// meaning for a double payload, so any other combination is a hard error.
IntOrDouble operator^(const IntOrDouble& other) const {
  if (is_int() && other.is_int()) {
    return IntOrDouble(as<int64_t>() ^ other.as<int64_t>());
  }
  // Give the assertion a message so a failure identifies the bad operand
  // types instead of firing silently.
  TORCH_INTERNAL_ASSERT(
      false, "Bitwise xor is only defined for integer operands");
}
141+
// Bitwise xor with a plain int64_t.
// Only defined when this object currently holds an int64_t.
IntOrDouble operator^(int64_t other) const {
  if (is_int()) {
    return IntOrDouble(as<int64_t>() ^ other);
  }
  // Messaged assert for diagnosability (was a bare assert(false)).
  TORCH_INTERNAL_ASSERT(
      false, "Bitwise xor is only defined for integer operands");
}
135147

136148
#define DEFINE_COMPARE_OP(op) \
137149
bool operator op(const IntOrDouble& other) const { \

torch/csrc/jit/codegen/cuda/expr_evaluator.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ void ExpressionEvaluator::handle(UnaryOp* uop) {
146146
break;
147147
default:
148148
TORCH_CHECK(
149-
!"Unexpected operator type ",
149+
false,
150+
"Unexpected operator type ",
150151
uop->getUnaryOpType(),
151152
" in ",
152153
uop->toString());
@@ -190,8 +191,16 @@ void ExpressionEvaluator::handle(BinaryOp* bop) {
190191
case BinaryOpType::Min:
191192
known_values_[bop->out()] = min(*lhs, *rhs);
192193
break;
194+
case BinaryOpType::Xor:
195+
known_values_[bop->out()] = *lhs ^ *rhs;
196+
break;
193197
default:
194-
TORCH_CHECK(!"Unexpected operator type");
198+
TORCH_CHECK(
199+
false,
200+
"Unexpected operator type: ",
201+
bop->getBinaryOpType(),
202+
" in ",
203+
bop->toString());
195204
}
196205
}
197206
}

torch/csrc/jit/codegen/cuda/index_compute.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
2020
#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
2121
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
22+
#include <torch/csrc/jit/codegen/cuda/swizzle.h>
2223
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
2324
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
2425

@@ -552,18 +553,14 @@ void IndexCompute::handle(Swizzle2D* swizzle_2d) {
552553
// Generate integer swizzle math if the
553554
// swizzle is activated. See also
554555
// [Note on swizzle mode].
555-
556-
auto out_pair = IrBuilder::swizzle2DIntExpr(
556+
std::pair<Val*, Val*> swizzled_index = dispatchSwizzle(
557+
swizzle_2d->swizzleType(),
557558
out_x_ind,
558559
out_y_ind,
559560
getExtent(out_x_id),
560-
getExtent(out_y_id),
561-
swizzle_2d->swizzleType());
562-
563-
index_map_[in_x_id] =
564-
IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X);
565-
index_map_[in_y_id] =
566-
IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y);
561+
getExtent(out_y_id));
562+
index_map_[in_x_id] = swizzled_index.first;
563+
index_map_[in_y_id] = swizzled_index.second;
567564
}
568565
}
569566

torch/csrc/jit/codegen/cuda/ir_utils.cpp

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,8 @@ struct ReplaceValInIndexVal : public OptInDispatch {
941941

942942
void handle(Val* val) override {
943943
TORCH_INTERNAL_ASSERT(
944-
val->isA<Int>() || val->isA<NamedScalar>() || val->isA<kir::IntPair>(),
944+
val->isA<Int>() || val->isA<Bool>() || val->isA<NamedScalar>() ||
945+
val->isA<kir::IntPair>(),
945946
"Invalid Val type: ",
946947
val->toString());
947948

@@ -960,6 +961,7 @@ struct ReplaceValInIndexVal : public OptInDispatch {
960961
switch (def->etype()) {
961962
case ExprType::UnaryOp:
962963
case ExprType::BinaryOp:
964+
case ExprType::TernaryOp:
963965
case ExprType::Swizzle2DInt:
964966
case ExprType::PairSelect:
965967
handle(val->definition());
@@ -978,7 +980,10 @@ struct ReplaceValInIndexVal : public OptInDispatch {
978980
void handle(UnaryOp* uop) override {
979981
handle(uop->in());
980982
auto inp = last_visited_val_;
981-
TORCH_INTERNAL_ASSERT(uop->out()->isA<Int>());
983+
TORCH_INTERNAL_ASSERT(
984+
uop->out()->isA<Int>() || uop->out()->isA<Bool>(),
985+
"Unknown output type for expr ",
986+
uop->toInlineString());
982987
auto out = IrBuilder::create<Int>(c10::nullopt);
983988
IrBuilder::create<UnaryOp>(uop->getUnaryOpType(), out, inp);
984989
last_visited_val_ = out;
@@ -990,12 +995,32 @@ struct ReplaceValInIndexVal : public OptInDispatch {
990995
auto lhs = last_visited_val_;
991996
handle(bop->rhs());
992997
auto rhs = last_visited_val_;
993-
TORCH_INTERNAL_ASSERT(bop->out()->isA<Int>());
998+
TORCH_INTERNAL_ASSERT(
999+
bop->out()->isA<Int>() || bop->out()->isA<Bool>(),
1000+
"Unknown output type for expr ",
1001+
bop->toInlineString());
9941002
auto out = IrBuilder::create<Int>(c10::nullopt);
9951003
IrBuilder::create<BinaryOp>(bop->getBinaryOpType(), out, lhs, rhs);
9961004
last_visited_val_ = out;
9971005
}
9981006

1007+
// Clone expression after recursively replacing inputs
1008+
void handle(TernaryOp* top) override {
  // Recursively rebuild each of the three inputs, capturing the
  // replacement value produced for every one.
  handle(top->in1());
  auto replaced_in1 = last_visited_val_;
  handle(top->in2());
  auto replaced_in2 = last_visited_val_;
  handle(top->in3());
  auto replaced_in3 = last_visited_val_;
  TORCH_INTERNAL_ASSERT(
      top->out()->isA<Int>() || top->out()->isA<Bool>(),
      "Unknown output type for expr ",
      top->toInlineString());
  // Clone the op with a fresh output fed by the replaced inputs.
  auto out = IrBuilder::create<Int>(c10::nullopt);
  IrBuilder::create<TernaryOp>(
      top->getTernaryOpType(), out, replaced_in1, replaced_in2, replaced_in3);
  last_visited_val_ = out;
}
1023+
9991024
// Clone expression after recurisvely replacing inputs
10001025
void handle(kir::Swizzle2DInt* swizzle_2d) override {
10011026
handle(swizzle_2d->inX());

torch/csrc/jit/codegen/cuda/lower_validation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,8 @@ class VectorizeValidator : public OptInDispatch {
312312
}
313313

314314
void handle(Swizzle2D* swizzle) final {
315-
if (swizzle->outX() == vectorized_id_ || swizzle->inX() == vectorized_id_) {
315+
if (swizzle->outX() == vectorized_id_ || swizzle->inX() == vectorized_id_ ||
316+
swizzle->outY() == vectorized_id_ || swizzle->inY() == vectorized_id_) {
316317
// Do not (yet) allow vectorization across any swizzled id.
317318
is_valid = false;
318319
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
#include <torch/csrc/jit/codegen/cuda/swizzle.h>
2+
3+
#include <torch/csrc/jit/codegen/cuda/arith.h>
4+
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
5+
6+
namespace torch {
7+
namespace jit {
8+
namespace fuser {
9+
namespace cuda {
10+
namespace swizzles {
11+
12+
// ------------------------------------------------------------
13+
// Swizzle Definitions
14+
// for each swizzle name:
15+
// un(Swizzle Name) e.g. unZShape is the inverse of ZShape,
16+
// (unswizzle is needed for inlining and is currently not actively used.)
17+
// ------------------------------------------------------------
18+
19+
// Unit Z swizzle:
20+
// Alternate directions of Y dimension:
21+
// 1 2 3 1 2 3
22+
// 4 5 6 => 6 5 4
23+
// 7 8 9 7 8 9
24+
std::pair<Val*, Val*> ZShape(Val* x, Val* y, Val* size_y) {
25+
auto zero = x->fusion()->zeroVal();
26+
auto one = x->fusion()->oneVal();
27+
auto two = IrBuilder::create<Int>(2);
28+
return {x, where(eq(mod(x, two), zero), y, sub(sub(size_y, one), y))};
29+
}
30+
31+
// ZShape is inverse of itself
32+
std::pair<Val*, Val*> unZShape(Val* x, Val* y, Val* size_y) {
33+
return ZShape(x, y, size_y);
34+
}
35+
36+
// Block cyclic Xor swizzle: (bank conflict removal)
37+
// Apply cyclic Xor within blocks:
38+
// Example: cyclic Xor
39+
// 1 2 3 4 1 2 3 4
40+
// 5 6 7 8 6 5 8 7
41+
// 9 10 11 12 => 11 12 9 10
42+
// 13 14 15 16 16 15 14 13
43+
std::pair<Val*, Val*> Xor(Val* x, Val* y) {
44+
// Need to validate in swizzle configuration:
45+
// size_x == size_y
46+
return {x, bitwise_xor(x, y)};
47+
}
48+
49+
// Xor is inverse of itself
50+
std::pair<Val*, Val*> unXor(Val* x, Val* y) {
51+
return Xor(x, y);
52+
}
53+
54+
// Block cyclic shift swizzle: (bank conflict removal)
55+
// Apply cyclic shift within blocks:
56+
// Example: cyclic shift
57+
// 1 2 3 4 1 2 3 4
58+
// 5 6 7 8 8 5 6 7
59+
// 9 10 11 12 => 11 12 9 10
60+
// 13 14 15 16 14 15 16 13
61+
std::pair<Val*, Val*> CyclicShift(Val* x, Val* y, Val* size_x) {
62+
return {x, mod(add(x, y), size_x)};
63+
}
64+
65+
// Inverse of CyclicShift: y = (size_x + y' - x) mod size_x.
// size_x is added before subtracting x so the dividend stays non-negative.
std::pair<Val*, Val*> unCyclicShift(Val* x, Val* y, Val* size_x) {
  Val* unrotated_y = mod(sub(add(size_x, y), x), size_x);
  return {x, unrotated_y};
}
68+
69+
// Scatter swizzle:
70+
// Corresponds to the data layout out of ldmatrix intrinsic.
71+
// supported dimensions are : 8x4, 16x4, 32x4
72+
std::pair<Val*, Val*> Scatter(Val* x, Val* y, int size_x) {
73+
TORCH_CHECK(
74+
size_x == 8 || size_x == 16 || size_x == 32,
75+
"Unsupported Scatter swizzle size");
76+
Val* size_x_val = IrBuilder::create<Int>(size_x);
77+
auto four = IrBuilder::create<Int>(4);
78+
return {cpp_div(add(mul(y, size_x_val), x), four), mod(x, four)};
79+
}
80+
81+
// Inverse of the Scatter swizzle; only valid for the same fixed sizes.
std::pair<Val*, Val*> unScatter(Val* x, Val* y, int size_x) {
  TORCH_CHECK(
      size_x == 8 || size_x == 16 || size_x == 32,
      "Unsupported Scatter swizzle size");
  Val* size_x_div_4 = IrBuilder::create<Int>(size_x / 4);
  Val* four = IrBuilder::create<Int>(4);
  // Undo the re-tiling: recover the original row from y and the low part
  // of x, and the original column from the high part of x.
  Val* unswizzled_x = add(y, mul(mod(x, size_x_div_4), four));
  Val* unswizzled_y = cpp_div(x, size_x_div_4);
  return {unswizzled_x, unswizzled_y};
}
89+
90+
} // namespace swizzles
91+
92+
// Forward swizzle dispatch: routes (x, y) through the implementation
// selected by `type`. The size arguments are only consulted by the
// swizzles that need them (ZShape reads maybe_size_y; CyclicShift and
// Scatter read maybe_size_x — Scatter requires it to evaluate to a
// constant, see evaluateInt()).
std::pair<Val*, Val*> dispatchSwizzle(
    Swizzle2DType type,
    Val* x,
    Val* y,
    Val* maybe_size_x,
    Val* maybe_size_y) {
  switch (type) {
    case Swizzle2DType::ZShape:
      return swizzles::ZShape(x, y, maybe_size_y);
    case Swizzle2DType::XOR:
      return swizzles::Xor(x, y);
    case Swizzle2DType::CyclicShift:
      return swizzles::CyclicShift(x, y, maybe_size_x);
    case Swizzle2DType::Scatter:
      return swizzles::Scatter(x, y, maybe_size_x->evaluateInt());
    default:
      TORCH_INTERNAL_ASSERT(false, "Unsupported swizzle type");
  }
}
111+
112+
// Inverse swizzle dispatch: applies the un-swizzle that undoes
// dispatchSwizzle for the same `type` and size arguments.
std::pair<Val*, Val*> dispatchUnSwizzle(
    Swizzle2DType type,
    Val* x,
    Val* y,
    Val* maybe_size_x,
    Val* maybe_size_y) {
  switch (type) {
    case Swizzle2DType::ZShape:
      return swizzles::unZShape(x, y, maybe_size_y);
    case Swizzle2DType::XOR:
      return swizzles::unXor(x, y);
    case Swizzle2DType::CyclicShift:
      return swizzles::unCyclicShift(x, y, maybe_size_x);
    case Swizzle2DType::Scatter:
      return swizzles::unScatter(x, y, maybe_size_x->evaluateInt());
    default:
      TORCH_INTERNAL_ASSERT(false, "Unsupported swizzle type");
  }
}
131+
132+
} // namespace cuda
133+
} // namespace fuser
134+
} // namespace jit
135+
} // namespace torch

0 commit comments

Comments
 (0)