Skip to content

Commit 1811e8c

Browse files
csarofeen, jjsjann123, and kevinstephano
authored
Transform replay refactor (#53)
The goal of this work is to have the transformation history be specific to IterDomains instead of TensorDomains. This should make it much easier to match up IterDomains during replay, which can be complicated when taking into consideration reduction axes, rfactors, and broadcast axes.

Co-authored-by: Jie <[email protected]>
Co-authored-by: Kevin Stephano <[email protected]>
1 parent eda5cfd commit 1811e8c

39 files changed

+3282
-2715
lines changed

test/cpp/jit/test_gpu.cpp

Lines changed: 324 additions & 198 deletions
Large diffs are not rendered by default.

test/cpp/jit/tests.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ namespace jit {
135135
_(GPU_FusionScalarInputs) \
136136
_(GPU_FusionRFactorReplay) \
137137
_(GPU_FusionReduction) \
138-
_(GPU_FusionReduction2)
138+
_(GPU_FusionReduction2) \
139+
_(GPU_FusionReduction3) \
140+
_(GPU_FusionSimpleBCast)
139141
#else
140142
#define TH_FORALL_TESTS_CUDA(_) \
141143
_(ArgumentSpec) \

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 176 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@ namespace torch {
77
namespace jit {
88
namespace fuser {
99

10-
// Will return a new value of type val with the DataType dtype, if it's a
11-
// tensorview it will propagate the shape information from val.
12-
TORCH_CUDA_API Val* newValLike(const Val* val, DataType dtype) {
13-
switch (val->getValType().value()) {
14-
case (ValType::TensorView):
15-
return val->as<TensorView>()->newForOutput(dtype);
10+
namespace {
11+
// Will return a new value of type val with the DataType dtype.
12+
Val* newScalar(ValType vtype, DataType dtype) {
13+
switch (vtype) {
1614
case (ValType::NamedScalar):
1715
case (ValType::Scalar):
1816
switch (dtype) {
@@ -33,36 +31,92 @@ TORCH_CUDA_API Val* newValLike(const Val* val, DataType dtype) {
3331

3432
TORCH_CHECK(
3533
false,
36-
"Could not generate a new value of type ",
37-
val->getValType().value(),
38-
" with data type ",
39-
val->getDataType().value());
34+
"Was expecting a scalar type, but received ValType: ",
35+
vtype,
36+
" with DataType:",
37+
dtype);
38+
}
39+
40+
TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
41+
std::vector<TensorView*> tvs;
42+
for (auto val : vals)
43+
if (val->getValType() == ValType::TensorView)
44+
tvs.push_back(static_cast<TensorView*>(val));
45+
46+
TORCH_CHECK(
47+
!tvs.empty(),
48+
"Tried to create new output TensorView but received empty list.");
49+
50+
std::vector<IterDomain*> out_domain(
51+
tvs[0]->domain()->noReductions().size(), nullptr);
52+
53+
for (auto tv : tvs) {
54+
auto dom = tv->domain()->noReductions();
55+
TORCH_INTERNAL_ASSERT(
56+
dom.size() == out_domain.size(),
57+
"Invalid tensor view found while producing and output, it has ",
58+
dom.size(),
59+
" dimensions but expected ",
60+
out_domain.size());
61+
for (size_t i = 0; i < dom.size(); i++) {
62+
if (out_domain[i] != nullptr)
63+
continue;
64+
if (dom[i]->isBroadcast())
65+
continue;
66+
out_domain[i] = new IterDomain(dom[i]->start(), dom[i]->extent());
67+
}
68+
}
69+
70+
std::transform(
71+
out_domain.begin(),
72+
out_domain.end(),
73+
out_domain.begin(),
74+
[](IterDomain* dom) {
75+
if (dom == nullptr)
76+
return new IterDomain(
77+
new Int(0), new Int(1), ParallelType::Serial, false, false, true);
78+
return dom;
79+
});
80+
81+
return new TensorView(new TensorDomain(out_domain), dtype);
4082
}
4183

42-
TORCH_CUDA_API Val* newValLike(const Val* val) {
43-
return newValLike(val, val->getDataType().value());
84+
Val* newOutputVal(const std::vector<Val*>& vals) {
85+
TORCH_INTERNAL_ASSERT(
86+
!vals.empty(), "Cannot promote values if there aren't any.");
87+
88+
ValType out_vtype = vals[0]->getValType().value();
89+
DataType out_dtype = vals[0]->getDataType().value();
90+
91+
for (auto val : vals) {
92+
TORCH_CHECK(val->isVal(), "Invalid statement found during promotion.");
93+
TORCH_CHECK(
94+
val->getDataType().value() != DataType::Null,
95+
"Invalid datatype found during prmotion.");
96+
out_vtype = promote_type(out_vtype, val->getValType().value());
97+
out_dtype = promote_type(out_dtype, val->getDataType().value());
98+
}
99+
100+
if (out_vtype == ValType::TensorView)
101+
return newOutputTV(vals, out_dtype);
102+
103+
return newScalar(out_vtype, out_dtype);
44104
}
45105

46-
TORCH_CUDA_API Val* promoteNew(Val* v1, Val* v2) {
47-
// Can't promote two types if they aren't both
48-
// values with valid data types.
49-
TORCH_CHECK(v1->isVal() && v2->isVal());
106+
Val* newValLike(Val* val, DataType dtype) {
107+
TORCH_CHECK(val->isVal(), "Invalid statement provided to create new value.");
50108
TORCH_CHECK(
51-
v1->getDataType() != DataType::Null &&
52-
v2->getDataType() != DataType::Null);
109+
dtype != DataType::Null, "Invalid datatype provided for new value.");
53110

54-
ValType out_vtype =
55-
promote_type(v1->getValType().value(), v2->getValType().value());
56-
DataType out_dtype =
57-
promote_type(v1->getDataType().value(), v2->getDataType().value());
111+
ValType vtype = val->getValType().value();
58112

59-
if (out_vtype == v2->getValType().value())
60-
return newValLike(v2, out_dtype);
113+
if (vtype == ValType::TensorView)
114+
return newOutputTV({val}, dtype);
61115

62-
return newValLike(v1, out_dtype);
116+
return newScalar(vtype, dtype);
63117
}
64118

65-
Val* newConstScalar(DataType dtype, int val) {
119+
Val* newConstScalar(DataType dtype, long int val) {
66120
switch (dtype) {
67121
case (DataType::Int):
68122
return new Int(val);
@@ -77,7 +131,7 @@ Val* newConstScalar(DataType dtype, int val) {
77131
val);
78132
}
79133

80-
Val* newConstScalar(DataType dtype, float val) {
134+
Val* newConstScalar(DataType dtype, double val) {
81135
switch (dtype) {
82136
case (DataType::Float):
83137
return new Float(val);
@@ -92,6 +146,8 @@ Val* newConstScalar(DataType dtype, float val) {
92146
val);
93147
}
94148

149+
} // namespace
150+
95151
TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1) {
96152
if (v1->getDataType().value() == dtype)
97153
return v1;
@@ -118,7 +174,7 @@ TORCH_CUDA_API TensorView* castOp(DataType dtype, TensorView* v1) {
118174
// UNARY OPERATIONS
119175

120176
TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1) {
121-
Val* out = newValLike(v1);
177+
Val* out = newOutputVal({v1});
122178
new UnaryOp(type, out, v1);
123179
return out;
124180
}
@@ -177,7 +233,7 @@ TensorView* arithOpOverloads(
177233
} // namespace
178234

179235
TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) {
180-
Val* out = promoteNew(v1, v2);
236+
Val* out = newOutputVal({v1, v2});
181237
if (is_logical_op(type)) {
182238
if (out->getDataType().value() != DataType::Bool)
183239
out = newValLike(out, DataType::Bool);
@@ -322,39 +378,73 @@ TORCH_CUDA_API TensorView* andOp(TensorView* v1, TensorView* v2) {
322378

323379
// REDUCTION OPERATIONS
324380

381+
namespace {
382+
// TODO: How do we adjust this so we can reduce to a single scalar value?
383+
TensorView* newForReduction(TensorView* tv, std::vector<unsigned int> axes) {
384+
auto orig_domain = TensorDomain::noReductions(tv->getRootDomain());
385+
std::set<unsigned int> axes_set(axes.begin(), axes.end());
386+
387+
std::vector<IterDomain*> new_domain;
388+
389+
TORCH_INTERNAL_ASSERT(
390+
!axes_set.empty(),
391+
"Asked for ouput of reduction, but no reduction axis provided.");
392+
TORCH_INTERNAL_ASSERT(
393+
(*(axes_set.rbegin())) < orig_domain.size(),
394+
"Error setting up reduction, reduction axis is outside nDims. Keep in mind reductions are relative to root domains, not modified views.");
395+
396+
for (decltype(orig_domain.size()) dim = 0; dim < orig_domain.size(); dim++) {
397+
IterDomain* id = orig_domain[dim];
398+
399+
bool isReduction = false;
400+
if ((*axes_set.begin()) == dim) {
401+
isReduction = true;
402+
axes_set.erase(axes_set.begin());
403+
}
404+
405+
new_domain.push_back(new IterDomain(
406+
id->start(), id->extent(), ParallelType::Serial, isReduction));
407+
}
408+
409+
TensorDomain* td = new TensorDomain(new_domain);
410+
return new TensorView(td, tv->getDataType().value());
411+
}
412+
413+
} // namespace
414+
325415
TensorView* reductionOp(
326416
BinaryOpType reduction_op_type,
327417
const std::vector<int>& axes,
328418
Val* init,
329-
TensorView* v1) {
419+
TensorView* tv) {
330420
TORCH_CHECK(
331421
init->isConstScalar(),
332422
"Cannot create a reduction operation where the initial value is not a const scalar.");
333423

334424
TORCH_CHECK(
335-
v1->getRootDomain() == v1->domain(),
336-
"Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/reorder/computeAt.");
425+
TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()),
426+
"Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");
337427

338428
std::vector<unsigned int> uint_axes;
339429
for (int axis : axes) {
340430
if (axis < 0)
341-
axis += int(v1->nDims());
431+
axis += int(tv->nDims());
342432

343433
TORCH_CHECK(
344-
axis >= 0 && (unsigned int)axis < v1->nDims(),
434+
axis >= 0 && (unsigned int)axis < tv->nDims(),
345435
"Reduction on invalid axis, recieved: ",
346436
axis,
347437
" however tensor view only has ",
348-
v1->nDims(),
438+
tv->nDims(),
349439
" dims.");
350440

351441
uint_axes.push_back((unsigned int)axis);
352442
}
353443

354-
TensorView* out = v1->newForReduction(uint_axes);
355-
if (init->getDataType().value() != v1->getDataType().value())
356-
init = castOp(v1->getDataType().value(), init);
357-
new ReductionOp(reduction_op_type, init, out, v1);
444+
TensorView* out = newForReduction(tv, uint_axes);
445+
if (init->getDataType().value() != tv->getDataType().value())
446+
init = castOp(tv->getDataType().value(), init);
447+
new ReductionOp(reduction_op_type, init, out, tv);
358448
return out;
359449
}
360450

@@ -377,6 +467,48 @@ TORCH_CUDA_API TensorView* sum(TensorView* v1, const std::vector<int>& axes) {
377467
return reductionOp(BinaryOpType::Add, axes, init, v1);
378468
}
379469

470+
TORCH_CUDA_API TensorView* broadcast(
471+
TensorView* inp,
472+
const std::vector<bool>& is_broadcast_dim) {
473+
auto nBCastDims = is_broadcast_dim.size();
474+
// Validate is_broadcast_dim
475+
unsigned int n_broadcasts = 0;
476+
for (auto ent : is_broadcast_dim)
477+
if (ent)
478+
n_broadcasts++;
479+
TORCH_CHECK(
480+
nBCastDims - n_broadcasts == inp->nDims(),
481+
"Invalid broadcast, number of false entries in is_broadcast_dim expected to be ",
482+
inp->nDims(),
483+
" but received ",
484+
nBCastDims - n_broadcasts);
485+
486+
if (n_broadcasts == 0) {
487+
auto identity = unaryOp(UnaryOpType::Set, inp);
488+
TORCH_INTERNAL_ASSERT(
489+
identity->getValType().value() == ValType::TensorView,
490+
"Expected identity op, but didn't get a TensorView back.");
491+
return static_cast<TensorView*>(identity);
492+
}
493+
494+
std::vector<IterDomain*> out_domain;
495+
size_t iinp = 0, ibdim = 0;
496+
while (ibdim < is_broadcast_dim.size()) {
497+
if (is_broadcast_dim[ibdim]) {
498+
out_domain.push_back(new IterDomain(
499+
new Int(0), new Int(1), ParallelType::Serial, false, false, true));
500+
} else {
501+
out_domain.push_back(inp->axis(iinp));
502+
iinp++;
503+
}
504+
ibdim++;
505+
}
506+
TensorView* out_tensor =
507+
new TensorView(new TensorDomain(out_domain), inp->getDataType().value());
508+
new BroadcastOp(out_tensor, inp);
509+
return out_tensor;
510+
}
511+
380512
// COMPOUND OPERATIONS
381513

382514
// add_alpha
@@ -504,7 +636,7 @@ TORCH_CUDA_API Val* where(Val* c, Val* v1, Val* v2) {
504636
"Condition should be of DataType Bool, not ",
505637
c->getDataType().value());
506638

507-
Val* out = promoteNew(v1, v2);
639+
Val* out = newOutputVal({v1, v2});
508640
new TernaryOp(TernaryOpType::Where, out, c, v1, v2);
509641
return out;
510642
}
@@ -533,6 +665,8 @@ TORCH_CUDA_API TensorView* where(
533665
return arithOpOverloads(where, v1, v2, v3);
534666
}
535667

668+
// TERNARY OPERATIONS
669+
536670
TORCH_CUDA_API Val* threshold(Val* in, Val* thresh, Val* value) {
537671
TORCH_CHECK(
538672
in->getDataType().value() == thresh->getDataType().value() &&
@@ -544,7 +678,7 @@ TORCH_CUDA_API Val* threshold(Val* in, Val* thresh, Val* value) {
544678
value->getValType().value() == ValType::Scalar,
545679
"Thresh and Value values should be Scalars");
546680

547-
Val* out = newValLike(in);
681+
Val* out = newOutputVal({in});
548682

549683
new TernaryOp(TernaryOpType::Threshold, out, in, thresh, value);
550684
return out;
@@ -565,7 +699,7 @@ TORCH_CUDA_API Val* clamp(Val* in, Val* min_val, Val* max_val) {
565699
max_val->getValType().value() == ValType::Scalar,
566700
"Min and Max values should be Scalars");
567701

568-
Val* out = newValLike(in);
702+
Val* out = newOutputVal({in});
569703

570704
new TernaryOp(TernaryOpType::Clamp, out, in, min_val, max_val);
571705
return out;

torch/csrc/jit/codegen/cuda/arith.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ namespace torch {
1818
namespace jit {
1919
namespace fuser {
2020

21-
// Promotion logic between two values, returns a new val from resulting type
22-
// promotion.
23-
TORCH_CUDA_API Val* promoteNew(Val* v1, Val* v2);
24-
2521
// Insertion of casting op to dtype, returns new resulting val
2622
TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1);
2723
TORCH_CUDA_API TensorView* castOp(DataType dtype, TensorView* v1);
@@ -54,6 +50,17 @@ TORCH_CUDA_API TensorView* neg(TensorView* v);
5450

5551
// BINARY OPERATIONS
5652
// add
53+
/*
54+
* Broadcasts v1 based on bool vector. Size of broadcast bool vector should be
55+
* the number of dims desired in the broadcasted tensor. This vector should be
56+
* true if output dim should be a broadcasted dim, and false if it is not a
57+
* broadcasted dim. Number of false entires must match the number of input dims.
58+
*/
59+
TORCH_CUDA_API TensorView* broadcast(
60+
TensorView* inp,
61+
const std::vector<bool>& is_broadcast_dim);
62+
63+
// BINARY OPAERATIONS
5764
TORCH_CUDA_API Val* add(Val* v1, Val* v2);
5865
TORCH_CUDA_API TensorView* add(TensorView* v1, Val* v2);
5966
TORCH_CUDA_API TensorView* add(Val* v1, TensorView* v2);

0 commit comments

Comments (0)