Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ class DomainMap {
}
virtual ~DomainMap() = default;

bool areExactMapped(IterDomain* id1, IterDomain* id2) const {
return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT);
}

const ComputeAtMap& getComputeAtMap() const {
return ca_map_;
}
Expand Down
93 changes: 70 additions & 23 deletions torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,49 @@ class DomainMap : public pointwise_utils::DomainMap {
domain_map.findReferenceFor(grouped_inputs_outputs[1]) != nullptr;
}

int getPosMappedTo(TensorView* tv, IterDomain* id) const {
int getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const {
// Find the root id mapped to `root_dim`
const auto& root_dom = tv->getRootDomain();
IterDomain* mapped_id = nullptr;
for (auto i : c10::irange(root_dom.size())) {
if (ca_map_.idGraph().permissiveNodes().permissiveAreMapped(
root_dom[i], root_dim)) {
mapped_id = root_dom[i];
break;
}
}
TORCH_INTERNAL_ASSERT(
mapped_id != nullptr,
"Can not find ID mapped to ",
root_dim,
" in tensor ",
tv);
// Project the root id to leaf id
while (!mapped_id->uses().empty()) {
TORCH_INTERNAL_ASSERT(mapped_id->uses().size() == 1);
auto expr = mapped_id->uses()[0];
if (expr->isA<Split>()) {
mapped_id = expr->as<Split>()->inner();
} else {
auto merge = expr->as<Merge>();
TORCH_INTERNAL_ASSERT(
mapped_id == merge->inner(),
"Can not find ID mapped to ",
root_dim,
" in tensor ",
tv);
mapped_id = merge->out();
}
}
// Find the position of the leaf id
const auto& dom = tv->domain()->domain();
for (auto i : c10::irange(dom.size())) {
if (areExactMapped(id, tv->axis(i))) {
if (dom[i] == mapped_id) {
return i;
}
}
TORCH_INTERNAL_ASSERT(
false, "Can not find ID mapped to ", id, " in tensor ", tv);
false, "Can not find ID mapped to ", root_dim, " in tensor ", tv);
}

// Group inputs and outputs of a fusion by its inner most domain. For example
Expand Down Expand Up @@ -240,22 +274,37 @@ void maybeBuildVirtualInnerDims(
// both virtual innermost dim.
// 2. The satisfied one did not merge in anything. For example,
// T0[I0{1024*1024}, I1{2}]
// If this is the case, we need to split the large inner-most
// dimension in order to satisfy the small inner-most dimension.
int64_t large_dim;
int64_t split_factor;
bool split_inner_most;
if (merged_size1 < params.tile_size1) {
if (params.dims_merged_with_2.empty()) {
// case 2
return;
split_inner_most = true;
large_dim = inner_most2;
split_factor = params.tile_size2;
} else {
// case 1
split_inner_most = false;
large_dim = params.dims_merged_with_2.back();
auto prev_merged_size2 = merged_size2 / shape_in_ref1[large_dim];
split_factor = ceilDiv(params.tile_size2, prev_merged_size2);
}
large_dim = params.dims_merged_with_2.back();
split_factor = ceilDiv(params.tile_size1, merged_size1);
} else {
if (params.dims_merged_with_1.empty()) {
// case 2
return;
split_inner_most = true;
large_dim = inner_most1;
split_factor = params.tile_size1;
} else {
// case 1
split_inner_most = false;
large_dim = params.dims_merged_with_1.back();
auto prev_merged_size1 = merged_size1 / shape_in_ref1[large_dim];
split_factor = ceilDiv(params.tile_size1, prev_merged_size1);
}
large_dim = params.dims_merged_with_1.back();
split_factor = ceilDiv(params.tile_size2, merged_size2);
}
params.split_before_tiling.push_back({large_dim, split_factor});
// adjust all dims to after-split
Expand All @@ -271,12 +320,16 @@ void maybeBuildVirtualInnerDims(
}
// Give the split-out dim to the unsatisfied one, so that both are satisfied.
if (merged_size1 < params.tile_size1) {
params.dims_merged_with_2.pop_back();
params.dims_merged_with_2.push_back(large_dim + 1);
if (!split_inner_most) {
params.dims_merged_with_2.pop_back();
params.dims_merged_with_2.push_back(large_dim + 1);
}
params.dims_merged_with_1.push_back(large_dim);
} else {
params.dims_merged_with_1.pop_back();
params.dims_merged_with_1.push_back(large_dim + 1);
if (!split_inner_most) {
params.dims_merged_with_1.pop_back();
params.dims_merged_with_1.push_back(large_dim + 1);
}
params.dims_merged_with_2.push_back(large_dim);
}
}
Expand Down Expand Up @@ -369,12 +422,6 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
if (n_elems < device_multiprocessor_count * kMaxTileSize * kMaxTileSize) {
params->tile_size1 = 8;
params->tile_size2 = 8;
// TODO: I was trying the following but I got silent wrong result
// params->tile_size1 = 8;
// params->tile_size2 = 4;
// This should not happen, because correctness should be irrelevant to
// schedulers. We don't have to use tile size (8, 4), but we need to fix our
// bug in codegen.
}

// Expand inner-most dims to virtual inner-most dims so that the inner-most
Expand All @@ -383,9 +430,9 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
auto inner_most_id2 = scheduler_utils::innerMostRootDim(reference2);

auto inner_most_pos1_in_ref1 =
domain_map.getPosMappedTo(reference1, inner_most_id1);
domain_map.getInnerLeafDim(reference1, inner_most_id1);
auto inner_most_pos2_in_ref1 =
domain_map.getPosMappedTo(reference1, inner_most_id2);
domain_map.getInnerLeafDim(reference1, inner_most_id2);

// See note [Supporting small transpose dimensions]
maybeBuildVirtualInnerDims(
Expand Down Expand Up @@ -643,9 +690,9 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {

// merge with inner most dims to get virtual inner most dims
size_t inner_most_pos1_in_ref1 =
domain_map.getPosMappedTo(reference1, inner_most_id1);
domain_map.getInnerLeafDim(reference1, inner_most_id1);
size_t inner_most_pos2_in_ref1 =
domain_map.getPosMappedTo(reference1, inner_most_id2);
domain_map.getInnerLeafDim(reference1, inner_most_id2);
if (merged1.has_value()) {
if (inner_most_pos1_in_ref1 < *merged1) {
reference1->reorder(
Expand Down
31 changes: 31 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,37 @@ TEST_F(NVFuserTest, FusionScheduleTransposeSmallInnerSize3_CUDA) {
testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
}

// x->sin->transpose->cos->y
TEST_F(NVFuserTest, FusionScheduleTranspose2DSmallInnerSize_CUDA) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel1 run in 3.00237 ms, achieved: 715.263 GB/s
kernel2 run in 2.54362 ms, achieved: 844.264 GB/s

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this PR:

kernel1 run in 28.1047 ms, achieved: 76.4101 GB/s
kernel2 run in 33.0813 ms, achieved: 64.9152 GB/s

std::array<std::vector<int64_t>, 2> shapes{
std::vector<int64_t>{1024 * 1024 * 128, 2},
std::vector<int64_t>{2, 1024 * 1024 * 128}};
for (const auto& shape : shapes) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(2);
fusion.addInput(tv0);
auto tv1 = sin(tv0);
auto tv2 = transpose(tv1, 0, 1);
auto tv3 = cos(tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::randn(shape, options);

auto lparams = scheduleTranspose(&fusion, {input});

FusionExecutor fe;
fe.compileFusion(&fusion, {input}, lparams);
auto outputs = fe.runFusion({input}, lparams);

auto tv_ref = input.sin().transpose(0, 1).cos();

testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
}
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)