Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions torch/csrc/jit/codegen/cuda/manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class CudaFusionManager {
}

void debugPrint(const TensorTypePtr& type) {
+ printf("\nsizes:");
if (auto sizes = type->symbolic_sizes().sizes()) {
// for (const auto& shape_symbol : sizes.value()) {
int rank = static_cast<int>(sizes->size());
Expand All @@ -165,24 +166,27 @@ class CudaFusionManager {
int rank = static_cast<int>(stride_properties->size());
printf("\nstride: ");
for (int i = 0; i < rank; i++) {
- if (auto val = (*stride_properties)[i]->stride_) {
- printf("%ld, ", val.value());
+ if ((*stride_properties)[i].has_value() &&
+ (*stride_properties)[i]->stride_.has_value()) {
+ printf("%ld, ", (*stride_properties)[i]->stride_.value());
} else {
printf("?, ");
}
}
printf("\nstride index: ");
for (int i = 0; i < rank; i++) {
- if (auto val = (*stride_properties)[i]->stride_index_) {
- printf("%ld, ", val.value());
+ if ((*stride_properties)[i].has_value() &&
+ (*stride_properties)[i]->stride_index_.has_value()) {
+ printf("%ld, ", (*stride_properties)[i]->stride_index_.value());
} else {
printf("?, ");
}
}
printf("\ncontiguous: ");
for (int i = 0; i < rank; i++) {
- if (auto val = (*stride_properties)[i]->contiguous_) {
- printf("%d, ", val.value());
+ if ((*stride_properties)[i].has_value() &&
+ (*stride_properties)[i]->contiguous_.has_value()) {
+ printf("%d, ", (*stride_properties)[i]->contiguous_.value());
} else {
printf("?, ");
}
Expand All @@ -196,7 +200,7 @@ class CudaFusionManager {
at::DimVector restorePermutation(at::DimVector permuted) {
int rank = static_cast<int>(permuted.size());
at::DimVector permutation(rank, -1);
- for (int i; i < rank; i++) {
+ for (int i = 0; i < rank; i++) {
permutation[permuted[i]] = i;
}
return permutation;
Expand All @@ -220,17 +224,19 @@ class CudaFusionManager {

// TODO: this does not support broadcast yet;
for (int i = 0; i < rank; i++) {
- if (auto index = (*stride_properties)[i]->stride_index_) {
- ordered_axes.insert(*index);
+ if ((*stride_properties)[i].has_value() &&
+ (*stride_properties)[i]->stride_index_.has_value()) {
+ ordered_axes.insert((*stride_properties)[i]->stride_index_.value());
}
}

int unallocated_axis = 0;
// we push from slowest to fastest
for (int i = rank - 1; i >= 0; i--) {
- if (auto index = (*stride_properties)[i]->stride_index_) {
- // pushing axis index to current entry in permute_seq;
- permute_seq.emplace_back(*index);
+ if ((*stride_properties)[i].has_value() &&
+ (*stride_properties)[i]->stride_index_.has_value()) {
+ permute_seq.emplace_back(
+ (*stride_properties)[i]->stride_index_.value());
} else {
// no designated axis for this slot, so we push an axis w/o designated
// order;
Expand Down Expand Up @@ -282,6 +288,7 @@ class CudaFusionManager {

auto strategy = getSortStrideScheme(acc_type);
// TODO: early return if permutation is no-op;

auto restore_strategy = restorePermutation(strategy);

std::vector<at::Tensor> permuted_outputs;
Expand Down