Skip to content

Commit e96aacf

Browse files
authored
Enable Transpose operation (#1882)
1 parent 425dce2 commit e96aacf

File tree

9 files changed

+384
-16
lines changed

9 files changed

+384
-16
lines changed

aten/src/ATen/core/interned_strings.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,16 @@ namespace c10 {
4949
_(prim, oneDNNFusionGuard) \
5050
_(prim, FunctionalGraph) \
5151
_(prim, add_optional) \
52-
_(prim, view_copy) \
52+
_(prim, expand_copy) \
53+
_(prim, expand_as_copy) \
54+
_(prim, flatten_copy) \
55+
_(prim, permute_copy) \
5356
_(prim, reshape_copy) \
5457
_(prim, squeeze_copy) \
58+
_(prim, t_copy) \
59+
_(prim, transpose_copy) \
5560
_(prim, unsqueeze_copy) \
56-
_(prim, flatten_copy) \
57-
_(prim, expand_copy) \
58-
_(prim, expand_as_copy) \
61+
_(prim, view_copy) \
5962
_(prim, DifferentiableGraph) \
6063
_(prim, TensorExprGroup) \
6164
_(prim, TensorExprDynamicGroup) \

test/test_jit_cuda_fuser.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4457,6 +4457,122 @@ def t(x, w):
44574457
self.assertEqual(jit_o, o)
44584458
self.assertGraphContainsExactly(t_jit.graph_for(x, w), FUSION_GUARD, 2, consider_subgraphs=True)
44594459

4460+
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                 "Requires fusion optimization pass to be effective")
def test_view_before_permute(self):
    """Exercise a view followed by a permute over a table of shape pairs.

    For every (input shape, view shape) pair, one randomly chosen axis of the
    view shape is swapped with the innermost axis, and the resulting
    permutation is handed to ``_view_before_permute_helper``.
    """
    # Each entry is (shape of the input tensor, shape it is viewed as).
    view_examples = (
        ([1, 19, 1, 12, 7, 1, 99], [1, 19, 1, 3, 2772]),
        ([3, 17, 80, 1], [51, 1, 2, 4, 10]),
        ([3, 17, 80, 1, 9], [51, 1, 2, 4, 10, 9]),
        ([2, 3, 4, 5], [1, 6, 1, 2, 2, 5]),
        ([22, 22, 2], [22, 11, 1, 1, 4]),
        ([37, 9, 7, 6, 10], [333, 2, 2, 3, 35]),
        ([8, 1, 1, 8, 1, 8], [8, 2, 4, 1, 8]),
        ([1, 333, 1], [1, 37, 9]),
        ([1, 333], [1, 1, 1, 111, 1, 3]),
        ([1, 27454, 1, 2], [1, 7844, 1, 7]),
        ([1, 7844, 1, 7], [1, 27454, 2]),
    )

    def _getTransposeAxes(sizes):
        # Start from the identity ordering, then swap the innermost axis with
        # one randomly picked axis whose extent is greater than 1. Size-1
        # axes and the innermost axis itself are never candidates.
        innermost = len(sizes) - 1
        axes = list(range(len(sizes)))
        candidates = [(pos, extent) for pos, extent in enumerate(sizes)
                      if extent > 1 and pos < innermost]
        picked, _ = candidates[random.randint(0, len(candidates) - 1)]
        axes[picked] = innermost
        axes[innermost] = picked
        return axes

    def _transposeSize(sizes, dims):
        # Extent found at each output position once `dims` is applied.
        permuted = []
        for old_pos in dims:
            permuted.append(sizes[old_pos])
        return permuted

    for before_view_size, after_view_size in view_examples:
        axes = _getTransposeAxes(after_view_size)
        self._view_before_permute_helper(
            before_view_size,
            after_view_size,
            _transposeSize(after_view_size, axes),
            axes)
4499+
4500+
def _view_before_permute_helper(self, input_shape, view_shape, output_shape, dims):
    """Script a view -> permute -> add -> relu function and pass it to ``_run_helper``.

    ``x`` is created with ``input_shape`` and ``y`` with ``output_shape``;
    both are random CUDA tensors.
    """
    def t(x, y, view_shape : List[int], dims : List[int]):
        # View first, then permute, then combine with y.
        viewed = x.view(view_shape)
        permuted = torch.permute(viewed, dims)
        return torch.relu(torch.add(permuted, y))

    scripted = torch.jit.script(t)
    x = torch.randn(*input_shape, device="cuda")
    y = torch.randn(*output_shape, device="cuda")
    self._run_helper(scripted, t, x, y, view_shape, dims)
4512+
4513+
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                 "Requires fusion optimization pass to be effective")
def test_permute(self):
    """Call ``_permute_helper`` with every permutation of 2-D through 4-D shapes."""
    max_dims = 4
    rank = 2
    while rank <= max_dims:
        # Extents are 2, 3, 4, ... so that every axis has a distinct size.
        shape = list(range(2, rank + 2))
        for ordering in itertools.permutations(range(rank)):
            self._permute_helper(shape, ordering)
        rank += 1
4522+
4523+
def _permute_helper(self, shape, dims):
    """Script a permute -> add -> relu function and pass it to ``_run_helper``.

    Both operands are random CUDA tensors of ``shape`` and are permuted with
    the same ``dims``.
    """
    def t(x, y, dims : List[int]):
        # Permute both operands identically before combining them.
        lhs = torch.permute(x, dims)
        rhs = torch.permute(y, dims)
        return torch.relu(torch.add(lhs, rhs))

    scripted = torch.jit.script(t)
    x = torch.randn(*shape, device="cuda")
    y = torch.randn(*shape, device="cuda")
    self._run_helper(scripted, t, x, y, dims)
4535+
4536+
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                 "Requires fusion optimization pass to be effective")
def test_transpose(self):
    """Call ``_transpose_helper`` for every axis pair of 2-D through 4-D shapes."""
    max_dims = 4
    for rank in range(2, max_dims + 1):
        # Extents are 2, 3, 4, ... so that every axis has a distinct size.
        shape = list(range(2, rank + 2))
        # Every (higher, lower) axis pair with lower < higher, in the same
        # order as a nested loop over higher then lower.
        axis_pairs = [(higher, lower)
                      for higher in range(1, rank)
                      for lower in range(higher)]
        for higher, lower in axis_pairs:
            self._transpose_helper(shape, higher, lower)
4546+
4547+
def _transpose_helper(self, shape, dim0, dim1):
    """Script a transpose -> add -> gelu function and pass it to ``_run_helper``.

    Both operands are random CUDA tensors of ``shape`` and have the same two
    axes swapped.
    """
    def t(x, y, dim0 : int, dim1 : int):
        # Swap the same pair of axes on both operands before combining them.
        lhs = torch.transpose(x, dim0, dim1)
        rhs = torch.transpose(y, dim0, dim1)
        return torch.nn.functional.gelu(torch.add(lhs, rhs))

    scripted = torch.jit.script(t)
    x = torch.randn(*shape, device="cuda")
    y = torch.randn(*shape, device="cuda")
    self._run_helper(scripted, t, x, y, dim0, dim1)
4559+
4560+
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                 "Requires fusion optimization pass to be effective")
def test_transpose_default(self):
    """Script a ``torch.t`` -> add -> gelu function on 3x5 CUDA tensors and pass it to ``_run_helper``."""
    def t(x, y):
        # torch.t takes no axis arguments; apply it to both operands.
        lhs = torch.t(x)
        rhs = torch.t(y)
        return torch.nn.functional.gelu(torch.add(lhs, rhs))

    scripted = torch.jit.script(t)
    x = torch.randn(3, 5, device="cuda")
    y = torch.randn(3, 5, device="cuda")
    self._run_helper(scripted, t, x, y)
4575+
44604576
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
44614577
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
44624578
"Requires fusion optimization pass to be effective")

torch/csrc/jit/codegen/cuda/graph_fuser.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2176,7 +2176,10 @@ void decomposeLinearOps(Block* block) {
21762176
void replaceAliasOpsWithCopy(std::shared_ptr<Graph>& graph, Block* block) {
21772177
static std::unordered_map<Symbol, Symbol> alias_to_copy_mapping(
21782178
{{aten::expand, prim::expand_copy},
2179-
{aten::expand_as, prim::expand_as_copy}});
2179+
{aten::expand_as, prim::expand_as_copy},
2180+
{aten::permute, prim::permute_copy},
2181+
{aten::transpose, prim::transpose_copy},
2182+
{aten::t, prim::t_copy}});
21802183
// TODO: revert disabled aten::view
21812184
// ({{aten::view, prim::view_copy},
21822185
// {aten::reshape, prim::reshape_copy},
@@ -2228,7 +2231,10 @@ void replaceAliasOpsWithCopy(std::shared_ptr<Graph>& graph, Block* block) {
22282231
void revertAliasCopyOps(std::shared_ptr<Graph>& graph, Block* block) {
22292232
static std::unordered_map<Symbol, Symbol> copy_to_alias_mapping(
22302233
{{prim::expand_copy, aten::expand},
2231-
{prim::expand_as_copy, aten::expand_as}});
2234+
{prim::expand_as_copy, aten::expand_as},
2235+
{prim::permute_copy, aten::permute},
2236+
{prim::transpose_copy, aten::transpose},
2237+
{prim::t_copy, aten::t}});
22322238
// TODO: revert disabled aten::view
22332239
// ({{prim::view_copy, aten::view},
22342240
// {prim::flatten_copy, aten::flatten},

torch/csrc/jit/codegen/cuda/interface.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,62 @@ RegisterOperators reg_add_optional({
657657
aliasAnalysisFromSchema()),
658658
});
659659

660+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
661+
RegisterOperators reg_permute_copy({
662+
Operator(
663+
"prim::permute_copy(Tensor(a) self, int[] dims) -> Tensor",
664+
[](const Node* node) -> Operation {
665+
return [node](Stack& stack) {
666+
TORCH_CHECK(
667+
node->s(attr::name) == "CudaFusionGroup",
668+
"permute_copy is only used by nvfuser to identify non-mutating ",
669+
"alias ops, should be restored after fusion pass!");
670+
IValue self, dims;
671+
pop(stack, self, dims);
672+
push(stack, at::native::view(self.toTensor(), dims.toIntVector()));
673+
};
674+
},
675+
aliasAnalysisFromSchema()),
676+
});
677+
678+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
679+
RegisterOperators reg_transpose_copy({
680+
Operator(
681+
"prim::transpose_copy.int(Tensor(a) self, int dim0, int dim1) -> Tensor",
682+
[](const Node* node) -> Operation {
683+
return [node](Stack& stack) {
684+
TORCH_CHECK(
685+
node->s(attr::name) == "CudaFusionGroup",
686+
"transpose_copy is only used by nvfuser to identify non-mutating ",
687+
"alias ops, should be restored after fusion pass!");
688+
IValue self, dim0, dim1;
689+
pop(stack, self, dim0, dim1);
690+
push(
691+
stack,
692+
at::transpose(self.toTensor(), dim0.toInt(), dim1.toInt()));
693+
};
694+
},
695+
aliasAnalysisFromSchema()),
696+
});
697+
698+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
699+
RegisterOperators reg_t_copy({
700+
Operator(
701+
"prim::t_copy(Tensor(a) self) -> Tensor",
702+
[](const Node* node) -> Operation {
703+
return [node](Stack& stack) {
704+
TORCH_CHECK(
705+
node->s(attr::name) == "CudaFusionGroup",
706+
"t_copy is only used by nvfuser to identify non-mutating ",
707+
"alias ops, should be restored after fusion pass!");
708+
IValue self;
709+
pop(stack, self);
710+
push(stack, at::t(self.toTensor()));
711+
};
712+
},
713+
aliasAnalysisFromSchema()),
714+
});
715+
660716
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
661717
RegisterOperators reg_view_copy({
662718
Operator(

torch/csrc/jit/codegen/cuda/manager.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,16 @@ namespace {
6262
// in the fallback path.
6363
void enableAliasCopyNodes(const std::shared_ptr<Graph>& graph, Block* block) {
6464
static std::unordered_set<Symbol> alias_copy_op(
65-
{prim::view_copy,
66-
prim::reshape_copy,
67-
prim::expand_copy,
65+
{prim::expand_copy,
6866
prim::expand_as_copy,
67+
prim::flatten_copy,
68+
prim::permute_copy,
69+
prim::reshape_copy,
6970
prim::squeeze_copy,
70-
prim::unsqueeze_copy});
71+
prim::t_copy,
72+
prim::transpose_copy,
73+
prim::unsqueeze_copy,
74+
prim::view_copy});
7175

7276
for (Node* n : block->nodes()) {
7377
for (Block* b : n->blocks()) {

torch/csrc/jit/codegen/cuda/ops/alias.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ TensorView* applyViewTransforms(
3636
TensorView* orig_tv,
3737
TensorView* post_reduce_tv,
3838
const AnalyzeViewResult& view_analysis) {
39+
TORCH_INTERNAL_ASSERT(orig_tv != nullptr, "Input is invalid.");
40+
TORCH_INTERNAL_ASSERT(post_reduce_tv != nullptr, "Input is invalid.");
3941
TORCH_INTERNAL_ASSERT(
4042
!post_reduce_tv->hasComputeAt(),
4143
"Cannot modify rfactor domain after compute at has been set.");
@@ -58,6 +60,7 @@ TensorView* applyViewTransforms(
5860
} // namespace
5961

6062
TensorView* view(TensorView* x, DataType dtype) {
63+
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
6164
if (x->getDataType() == dtype) {
6265
return x;
6366
}
@@ -77,6 +80,7 @@ TensorView* view(
7780
TensorView* x,
7881
const std::vector<int64_t>& original_sizes,
7982
const std::vector<int64_t>& new_sizes) {
83+
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
8084
TORCH_INTERNAL_ASSERT(
8185
TensorDomain::noReductions(x->getMaybeRFactorDomain()).size() ==
8286
original_sizes.size());
@@ -107,6 +111,7 @@ TensorView* view(
107111
}
108112

109113
TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) {
114+
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
110115
auto inp_domain = TensorDomain::noReductions(x->getMaybeRFactorDomain());
111116
if (start_dim < 0) {
112117
start_dim += inp_domain.size();
@@ -136,6 +141,7 @@ TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) {
136141
}
137142

138143
TensorView* squeeze(TensorView* x, const std::vector<int64_t>& sizes) {
144+
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
139145
const auto ndims = static_cast<int>(x->domain()->noReductions().size());
140146

141147
TORCH_INTERNAL_ASSERT(
@@ -159,6 +165,7 @@ TensorView* squeeze(TensorView* x, const std::vector<int64_t>& sizes) {
159165
}
160166

161167
TensorView* squeeze(TensorView* x, const std::vector<int64_t>& sizes, int dim) {
168+
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
162169
const auto ndims = static_cast<int>(x->domain()->noReductions().size());
163170

164171
TORCH_INTERNAL_ASSERT(
@@ -187,6 +194,7 @@ TensorView* squeeze(TensorView* x, const std::vector<int64_t>& sizes, int dim) {
187194
}
188195

189196
TensorView* unsqueeze(TensorView* x, int dim) {
197+
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
190198
const auto ndims = static_cast<int>(x->domain()->noReductions().size());
191199

192200
if (dim < 0) {
@@ -206,14 +214,28 @@ TensorView* unsqueeze(TensorView* x, int dim) {
206214
}
207215

208216
TensorView* permute(TensorView* x, const std::vector<int64_t>& new2old) {
217+
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
209218
auto inp_domain = TensorDomain::noReductions(x->getMaybeRFactorDomain());
210219
std::vector<IterDomain*> out_domain(inp_domain.size());
211220

221+
TORCH_CHECK(
222+
inp_domain.size() == new2old.size(),
223+
"The number of dimensions in the tensor input does not match the length",
224+
" of the desired ordering of dimensions i.e. input.dim() = ",
225+
inp_domain.size(),
226+
" is not equal to len(dims) = ",
227+
new2old.size());
228+
229+
// Return scalar tensors immediately
230+
if (inp_domain.size() == 0) {
231+
return set(x);
232+
}
233+
212234
auto normalized_new2old =
213235
ir_utils::normalizeNew2Old(new2old, inp_domain.size());
214236

215237
for (const auto i : c10::irange(out_domain.size())) {
216-
auto in_id = inp_domain[new2old[i]];
238+
auto in_id = inp_domain[normalized_new2old[i]];
217239
out_domain[i] = in_id->cloneWithoutRFactor();
218240
}
219241

@@ -226,6 +248,7 @@ TensorView* permute(TensorView* x, const std::vector<int64_t>& new2old) {
226248
}
227249

228250
TensorView* transpose(TensorView* x, int64_t dim0, int64_t dim1) {
251+
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
229252
const auto ndims = static_cast<int>(x->domain()->noReductions().size());
230253

231254
if (dim0 < 0) {
@@ -256,6 +279,7 @@ TensorView* transpose(TensorView* x, int64_t dim0, int64_t dim1) {
256279
}
257280

258281
TensorView* transpose(TensorView* x) {
282+
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
259283
const auto ndims = static_cast<int>(x->domain()->noReductions().size());
260284

261285
TORCH_CHECK(

0 commit comments

Comments
 (0)