Merged

Commits (64)
bb3940f
Move MaxProducerPosUpdater into InlinePropagator::tearDown
zasdfgbnm Jul 14, 2022
644b81d
cleanup
zasdfgbnm Jul 15, 2022
d552315
add null scheduler and matching for "statically null"
shmsong Jul 15, 2022
0ea589f
Merge branch 'devel' of github.com:csarofeen/pytorch into ip-td
zasdfgbnm Jul 16, 2022
07fe5b7
Merge branch 'ip-td' into transpose-schedule
zasdfgbnm Jul 17, 2022
09a9843
save
zasdfgbnm Jul 17, 2022
887d9fa
draft that compiles
zasdfgbnm Jul 21, 2022
225a63e
save
zasdfgbnm Jul 21, 2022
9f533e9
save
zasdfgbnm Jul 21, 2022
6c52fe7
More test
zasdfgbnm Jul 21, 2022
e5b513a
cleanup tests
zasdfgbnm Jul 21, 2022
f993007
save
zasdfgbnm Jul 21, 2022
96bdcb5
fix
zasdfgbnm Jul 21, 2022
0dc891c
Merge branch 'devel' of github.com:csarofeen/pytorch into transpose-s…
zasdfgbnm Jul 21, 2022
858d6ce
more test
zasdfgbnm Jul 21, 2022
7df6e2f
new
zasdfgbnm Jul 22, 2022
9b0f02b
fix
zasdfgbnm Jul 22, 2022
419d4be
Merge branch 'devel' of github.com:csarofeen/pytorch into transpose-s…
zasdfgbnm Jul 22, 2022
0c9cdd4
fix
zasdfgbnm Jul 22, 2022
7d28da1
save
zasdfgbnm Jul 22, 2022
498ab05
lint
zasdfgbnm Jul 22, 2022
b1ed2a7
cleanup
zasdfgbnm Jul 22, 2022
a6e2215
writings
zasdfgbnm Jul 22, 2022
3d94ecc
save
zasdfgbnm Jul 22, 2022
536817a
save
zasdfgbnm Jul 22, 2022
eb97cbd
save
zasdfgbnm Jul 22, 2022
000840c
Merge branch 'devel' of github.com:csarofeen/pytorch into transpose-s…
zasdfgbnm Jul 25, 2022
275fa94
fix conflicts
zasdfgbnm Jul 25, 2022
51346fe
cleanups
zasdfgbnm Jul 25, 2022
48ecc4b
save
zasdfgbnm Jul 25, 2022
3ce7360
// TODO: support symbolic tile size
zasdfgbnm Jul 25, 2022
9bbf0cf
save
zasdfgbnm Jul 25, 2022
52e4d63
fix
zasdfgbnm Jul 26, 2022
595b5bf
fix
zasdfgbnm Jul 26, 2022
4fc2223
inline-propagator most inlined
zasdfgbnm Jul 26, 2022
2e8ce81
cleanup
zasdfgbnm Jul 26, 2022
ccfe5db
Merge branch 'devel' of github.com:csarofeen/pytorch into transpose-s…
zasdfgbnm Jul 27, 2022
276f9d5
save
zasdfgbnm Jul 27, 2022
ff54dc5
save
zasdfgbnm Jul 27, 2022
7d91989
add cache
zasdfgbnm Jul 27, 2022
84f25f4
reject trivial reduction and view in canScheduleCompileTime
zasdfgbnm Jul 27, 2022
47d516a
reorder all_heuristics
zasdfgbnm Jul 27, 2022
aa2752d
pushing some failing tests
zasdfgbnm Jul 28, 2022
d0a026b
fix reference tensor finding
zasdfgbnm Jul 28, 2022
9205565
make broadcasting test work
zasdfgbnm Jul 28, 2022
cacd1f7
cleanup
zasdfgbnm Jul 28, 2022
fc65d23
Merge branch 'devel' of github.com:csarofeen/pytorch into transpose-s…
zasdfgbnm Jul 29, 2022
cd01c0f
revert
zasdfgbnm Jul 29, 2022
9e5d394
clean
zasdfgbnm Jul 29, 2022
821d027
enable view without testing
zasdfgbnm Jul 29, 2022
1b425d2
merge all dims
zasdfgbnm Jul 29, 2022
eab982d
disable FusionScheduleTransposeBroadcast_CUDA
zasdfgbnm Jul 29, 2022
9dd5aac
cleanup & simplify things
zasdfgbnm Jul 31, 2022
d7f0ea4
Merge branch 'devel' of github.com:csarofeen/pytorch into transpose-s…
zasdfgbnm Aug 1, 2022
c19bbf1
Merge branch 'transpose-schedule' of github.com:csarofeen/pytorch int…
zasdfgbnm Aug 1, 2022
a5861a0
skip FusionScheduleTransposeBroadcast_CUDA
zasdfgbnm Aug 1, 2022
65598b2
war for transpose split support
shmsong Aug 2, 2022
9bea3bf
fix
zasdfgbnm Aug 4, 2022
55c8298
FusionScheduleTransposeComplexDAG1_CUDA
zasdfgbnm Aug 4, 2022
4dfd569
Merge branch 'devel' of github.com:csarofeen/pytorch into transpose-s…
zasdfgbnm Aug 11, 2022
6881cc6
manual test
zasdfgbnm Aug 11, 2022
4539019
save
zasdfgbnm Aug 11, 2022
9545036
save
zasdfgbnm Aug 11, 2022
6c369cf
save
zasdfgbnm Aug 11, 2022
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -718,6 +718,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/root_domain_map.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp",
1 change: 1 addition & 0 deletions test/cpp/jit/CMakeLists.txt
@@ -101,6 +101,7 @@ if(USE_CUDA)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_view.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_rng.cu)
endif()

4 changes: 3 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h
@@ -2,6 +2,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/normalization.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/transpose.h>

namespace torch {
namespace jit {
@@ -12,7 +13,8 @@ enum class TORCH_CUDA_CU_API ScheduleHeuristic {
None,
PointWise,
Reduction,
Persistent
Persistent,
Transpose
};
}
} // namespace fuser
10 changes: 10 additions & 0 deletions torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h
@@ -28,6 +28,7 @@ enum class CompileTimeEntryType {
DOMAIN_MAP,
REFERENCE_TENSORS,
VECTORIZABLE_INPUTS_AND_OUTPUTS,
INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS,
UNROLLABLE_INPUTS_AND_OUTPUTS,
REDUCTION_TVS,
PERSISTENT_BUFFER_INFO,
@@ -62,6 +63,15 @@ class VectorizableInputsAndOutputs {
CompileTimeEntryType::VECTORIZABLE_INPUTS_AND_OUTPUTS;
};

//! Entry type definition class for `INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS`,
//! stores the fusion's inputs and outputs grouped by inner most dimension.
class InputsOutputsInnerDimGroups {
public:
using DataType = std::vector<std::vector<TensorView*>>;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS;
};

//! Entry type definition class for `UNROLLABLE_INPUTS_AND_OUTPUTS`,
//! stores the unrollable TensorViews on a fusion's inputs and outputs.
class UnrollableInputsAndOutputs {
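To illustrate what the new INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS entry caches, here is a minimal sketch of grouping fusion inputs and outputs by the exact-mapped concrete ID of their innermost dimension. It only uses calls that appear elsewhere in this diff (getMaybeRFactorDomain, getConcreteMappedID, ir_utils::filterByType); the function name groupByInnerDim and the grouping details (e.g. how broadcasts are handled) are assumptions for illustration, not the code added by this PR.

// Illustrative sketch only: group tensors by the concrete ID of their
// innermost root/rfactor dimension, the shape of data cached under
// INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS.
std::vector<std::vector<TensorView*>> groupByInnerDim(
    Fusion* fusion,
    const ComputeAtMap& ca_map) {
  std::unordered_map<IterDomain*, std::vector<TensorView*>> groups;
  auto collect = [&](TensorView* tv) {
    auto dom = tv->getMaybeRFactorDomain();
    if (dom.empty()) {
      return; // zero-dim tensor: nothing to group by
    }
    // Tensors whose innermost dimensions are exact-mapped share a group.
    auto inner_id =
        ca_map.getConcreteMappedID(dom.back(), IdMappingMode::EXACT);
    groups[inner_id].push_back(tv);
  };
  for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
    collect(tv);
  }
  for (auto tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
    collect(tv);
  }
  std::vector<std::vector<TensorView*>> result;
  for (auto& kv : groups) {
    result.push_back(std::move(kv.second));
  }
  return result;
}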
24 changes: 1 addition & 23 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -6,6 +6,7 @@
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h>
@@ -57,29 +58,6 @@ class DomainMap : public pointwise_utils::DomainMap {
return domain_map.findReferenceTensorView() != nullptr;
}

// Determine if output TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input.
bool isValidReference(TensorView* output_tv) const {
if (output_tv->isFusionInput()) {
return false;
}
for (auto input_tv :
ir_utils::filterByType<TensorView>(fusion_->inputs())) {
if (input_tv->uses().empty()) {
continue;
}

if (fusion_->getOutputAlias(output_tv) == input_tv) {
continue;
}

if (!areAllInputIdsMappedToOutput(input_tv, output_tv)) {
return false;
}
}
return true;
}

private:
bool hasMinimumSize(TensorView* tv, int num_axes) const {
TORCH_INTERNAL_ASSERT(tv != nullptr);
46 changes: 24 additions & 22 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp
@@ -6,19 +6,9 @@ namespace fuser {
namespace cuda {
namespace pointwise_utils {

DomainMap::DomainMap(Fusion* fusion)
: fusion_(fusion), ca_map_(ComputeAtMap(fusion)) {
view_tvs_ = scheduler_utils::getViewTVs(fusion);
}

bool DomainMap::areExactMapped(IterDomain* id1, IterDomain* id2) {
return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT);
}

// Determine if all IterDomains in input are mapped to output
bool DomainMap::areAllInputIdsMappedToOutput(
TensorView* input_tv,
TensorView* output_tv) const {
// Determine if all IterDomains in input are mapped to the given tensor
bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv)
const {
// Get concrete IDs for input root or rfactor domain
std::unordered_set<IterDomain*> in_concrete_ids;
for (auto in_id : input_tv->getMaybeRFactorDomain()) {
@@ -30,11 +20,9 @@ bool DomainMap::areAllInputIdsMappedToOutput(

// Erase all input concrete IDs mapped to the output domain
// Ignore unresolved broadcast dimensions
for (auto out_id : output_tv->getMaybeRFactorDomain()) {
if (!out_id->isBroadcast()) {
if (!eraseIfMapped(in_concrete_ids, out_id)) {
eraseIfInputMappedThroughViewToOutput(in_concrete_ids, out_id);
}
for (auto id : tv->getMaybeRFactorDomain()) {
if (!eraseIfMapped(in_concrete_ids, id)) {
eraseIfInputMappedThroughViewTo(in_concrete_ids, id);
}
}
return in_concrete_ids.empty();
@@ -45,7 +33,7 @@ bool DomainMap::eraseIfMapped(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const {
auto out_concrete_id =
ca_map_.getConcreteMappedID(out_id, IdMappingMode::EXACT);
ca_map_.getConcreteMappedID(out_id, IdMappingMode::PERMISSIVE);
auto in_concrete_id_iter = in_concrete_ids.find(out_concrete_id);
bool found_match = in_concrete_id_iter != in_concrete_ids.end();
if (found_match) {
@@ -58,12 +46,12 @@ bool DomainMap::eraseIfMapped(
// Currently this function only allow having one view on the path from input to
// output. If there are multiple views, then likely the pointwise scheduler will
// reject the fusion because we can not correctly find a reference tensor.
void DomainMap::eraseIfInputMappedThroughViewToOutput(
void DomainMap::eraseIfInputMappedThroughViewTo(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const {
IterDomain* id) const {
for (auto view : view_tvs_) {
// Find any ID in view rfactor domain that is mapped to output ID
auto view_rfactor_id = anyMapped(view->getRFactorDomain(), out_id);
auto view_rfactor_id = anyMapped(view->getRFactorDomain(), id);
if (view_rfactor_id == nullptr) {
continue;
}
@@ -94,6 +82,20 @@ IterDomain* DomainMap::anyMapped(
return nullptr;
}

// Determine if output TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input.
bool DomainMap::isValidReference(TensorView* tv) const {
for (auto input_tv : ir_utils::filterByType<TensorView>(fusion_->inputs())) {
if (input_tv->uses().empty()) {
continue;
}
if (!areAllInputIdsMappedTo(input_tv, tv)) {
return false;
}
}
return true;
}

} // namespace pointwise_utils
} // namespace cuda
} // namespace fuser
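For context on the relocated DomainMap::isValidReference, a minimal sketch of how a scheduler might use it to pick a reference TensorView. This mirrors what findReferenceTensorView in pointwise.cpp is used for, but the code below is an assumption for illustration, not taken from this PR.

// Sketch: pick the first fusion output that maps to all input iterDomains.
TensorView* findReference(Fusion* fusion) {
  pointwise_utils::DomainMap domain_map(fusion);
  for (auto output_tv :
       ir_utils::filterByType<TensorView>(fusion->outputs())) {
    // Outputs that are also fusion inputs cannot serve as a reference.
    if (output_tv->isFusionInput()) {
      continue;
    }
    if (domain_map.isValidReference(output_tv)) {
      return output_tv;
    }
  }
  return nullptr; // no valid reference: the scheduler should reject the fusion
}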
22 changes: 15 additions & 7 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h
@@ -15,29 +15,37 @@ namespace pointwise_utils {
// that maps to all IterDomains in the fusion.
class DomainMap {
public:
DomainMap(Fusion* fusion);
DomainMap(Fusion* fusion) : fusion_(fusion), ca_map_(fusion) {
view_tvs_ = scheduler_utils::getViewTVs(fusion);
}
virtual ~DomainMap() = default;

bool areExactMapped(IterDomain* id1, IterDomain* id2);
bool areExactMapped(IterDomain* id1, IterDomain* id2) const {
return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT);
}

const ComputeAtMap& getComputeAtMap() const {
return ca_map_;
}

// Determine if a TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input.
bool isValidReference(TensorView* tv) const;

protected:
// Determine if all iterDomains are mapped between input and output tvs
bool areAllInputIdsMappedToOutput(TensorView* input_tv, TensorView* output_tv)
// Determine if all IterDomains are mapped between input and the given tvs
bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv)
const;

// Erase input concrete ID if it is mapped to output ID
bool eraseIfMapped(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;

// Check if in_id is mapped to out_id through any view rfactor domain
void eraseIfInputMappedThroughViewToOutput(
// Check if in_id is mapped to id through any view rfactor domain
void eraseIfInputMappedThroughViewTo(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;
IterDomain* id) const;

// Find any id in domain that maps with target id
IterDomain* anyMapped(
86 changes: 86 additions & 0 deletions torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -9,6 +9,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/debug_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/transpose.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>

#include <limits>
@@ -1244,10 +1245,75 @@ class PersistentKernelScheduler : public SchedulerEntry {
}
};

class TransposeScheduler : public SchedulerEntry {
public:
explicit TransposeScheduler(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr)
: SchedulerEntry(ScheduleHeuristic::Transpose) {
computeHeuristics(fusion, runtime_info, data_cache);
}

static bool canScheduleCompileTime(Fusion* fusion) {
// Not enabling this yet. Needs more validation.
return false;
#if 0
if (!hasAtLeastTwoValidGroups(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose,
"cannot find two mismatching inner most dimensions");
return false;
}

// TODO: add support for trivial reduction
auto reduction_ops =
ir_utils::getReductionOps(fusion, false /* ignore_trivial */);

if (!reduction_ops.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose, "no support for reduction ops");
return false;
}

if (hasNonUniqueBcast(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose,
"Broadcasting dimension might be broadcasting to multiple sizes.");
return false;
}

return true;
#endif
}

static bool canScheduleRunTime(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr) {
return true;
[inline review comment on this line]
Might want to evaluate a bit in some size-dependent scenarios. Maybe when the inner dimensions are small.
But this function is hot path so the simpler the better.
}

void schedule(Fusion* fusion) override {
FUSER_PERF_SCOPE("Schedule Transpose Fusion");
scheduleTranspose(fusion, transposeParams());
}

private:
void computeHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr) {
params_ = getTransposeHeuristics(fusion, runtime_info, data_cache);
TORCH_INTERNAL_ASSERT(params_ != nullptr);
}
};

// Schedule Table
const std::vector<ScheduleHeuristic>& all_heuristics() {
static const std::vector<ScheduleHeuristic> hlist = {
ScheduleHeuristic::Reduction,
ScheduleHeuristic::Transpose,
ScheduleHeuristic::PointWise,
ScheduleHeuristic::Persistent};
return hlist;
@@ -1294,6 +1360,9 @@ bool SchedulerEntry::canSchedule(
case ScheduleHeuristic::Persistent:
return checkCanSchedule<PersistentKernelScheduler>(
fusion, runtime_info, data_cache);
case ScheduleHeuristic::Transpose:
return checkCanSchedule<TransposeScheduler>(
fusion, runtime_info, data_cache);
default:
TORCH_INTERNAL_ASSERT(false, "unreachable");
return false;
@@ -1320,6 +1389,10 @@ std::unique_ptr<SchedulerEntry> SchedulerEntry::makeEntry(
scheduler_entry = std::make_unique<PersistentKernelScheduler>(
fusion, runtime_info, data_cache);
break;
case ScheduleHeuristic::Transpose:
scheduler_entry = std::make_unique<TransposeScheduler>(
fusion, runtime_info, data_cache);
break;
default:
TORCH_INTERNAL_ASSERT(false, "unreachable");
}
@@ -1353,6 +1426,8 @@ std::string toString(ScheduleHeuristic sh) {
return "reduction";
case ScheduleHeuristic::Persistent:
return "persistent";
case ScheduleHeuristic::Transpose:
return "transpose";
default:
TORCH_INTERNAL_ASSERT(false, "undefined schedule");
}
@@ -1405,6 +1480,10 @@ HeuristicSummary::HeuristicSummary(
getPersistentHeuristics(fusion, runtime_info, this);
PersistentKernelScheduler::canScheduleRunTime(fusion, runtime_info, this);
break;
case ScheduleHeuristic::Transpose:
getTransposeHeuristics(fusion, runtime_info, this);
TransposeScheduler::canScheduleRunTime(fusion, runtime_info, this);
break;
default:
TORCH_INTERNAL_ASSERT(false, "unknown heuristic");
}
@@ -1451,6 +1530,11 @@ void HeuristicSummary::validate() const {
entry_type_map_.count(EntryType::SCOPE_PERSISTENT_FACTOR_INFO));
break;
}
case ScheduleHeuristic::Transpose: {
TORCH_INTERNAL_ASSERT(entry_type_map_.count(
EntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS));
break;
}
default:
TORCH_INTERNAL_ASSERT(false, "unknown heuristic");
}
@@ -1490,6 +1574,8 @@ template class HeuristicSummaryEntry<HeuristicCompileTime::DomainMap>;
template class HeuristicSummaryEntry<HeuristicCompileTime::ReferenceTensors>;
template class HeuristicSummaryEntry<
HeuristicCompileTime::VectorizableInputsAndOutputs>;
template class HeuristicSummaryEntry<
HeuristicCompileTime::InputsOutputsInnerDimGroups>;
template class HeuristicSummaryEntry<
HeuristicCompileTime::UnrollableInputsAndOutputs>;
template class HeuristicSummaryEntry<HeuristicCompileTime::ReductionTVs>;
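Regarding the inline review comment above about size-dependent checks in canScheduleRunTime, here is a hypothetical sketch of what such a check could look like. Both helpers (getInnerDimGroups, innerDimExtent) and the threshold are invented for illustration; nothing like this is in the PR, which currently returns true unconditionally.

// Hypothetical only: reject the transpose scheduler at runtime when the
// innermost dimensions are very small, as the review comment suggests.
// getInnerDimGroups and innerDimExtent are assumed helpers, not PR code.
static bool canScheduleRunTime(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicSummary* data_cache = nullptr) {
  constexpr int64_t kMinInnerExtent = 32; // illustrative threshold only
  for (const auto& group : getInnerDimGroups(fusion, data_cache)) {
    if (innerDimExtent(group, runtime_info) < kMinInnerExtent) {
      return false; // tiny inner dims: leave it to the pointwise scheduler
    }
  }
  return true; // keep this path cheap; it runs on every fusion execution
}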
7 changes: 7 additions & 0 deletions torch/csrc/jit/codegen/cuda/scheduler/registry.h
@@ -187,6 +187,13 @@ class TORCH_CUDA_CU_API SchedulerEntry {
return *pparams;
}

const TransposeParams& transposeParams() const {
auto tparams = std::dynamic_pointer_cast<TransposeParams>(params_);
TORCH_INTERNAL_ASSERT(
tparams != nullptr, "Heuristic parameter is not a transpose parameter");
return *tparams;
}

void updateLaunchConstraint(const LaunchParams& launch_params) {
params_->lparams = launch_params;
}
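Putting the registry changes together, a rough sketch of how a caller would exercise the new entries. The static signatures follow the uses visible in this diff (all_heuristics, SchedulerEntry::canSchedule, SchedulerEntry::makeEntry); the surrounding driver code is assumed, not part of the PR.

// Sketch: try heuristics in the order of all_heuristics(); Transpose is now
// considered after Reduction and before PointWise.
std::unique_ptr<SchedulerEntry> tryMakeEntry(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info) {
  for (auto heuristic : all_heuristics()) {
    if (SchedulerEntry::canSchedule(heuristic, fusion, runtime_info)) {
      // For ScheduleHeuristic::Transpose this builds a TransposeScheduler,
      // whose schedule() uses the new transposeParams() accessor to drive
      // scheduleTranspose(fusion, transposeParams()).
      return SchedulerEntry::makeEntry(heuristic, fusion, runtime_info);
    }
  }
  return nullptr; // no scheduler accepted the fusion
}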