Merged
67 changes: 67 additions & 0 deletions test/cpp/jit/test_gpu.cpp
@@ -4503,6 +4503,73 @@ void testGPU_FusionIsOneInt() {
   TORCH_CHECK(!z->isOneInt());
 }
 
+// This is to verify no cycle of computeAt is created. A more complex
+// variation of this pattern appears in one of the Python tests
+// (test_random_topo).
+void testGPU_FusionComputeAtNonterminatingOutput() {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeDummyTensor(1);
+  fusion.addInput(tv0);
+
+  // Common intermediate tensor
+  auto tv1 = add(tv0, new Float(1));
+  // tv1 -> tv2
+  auto tv2 = add(tv1, new Float(2));
+  // tv1 -> tv3 -> tv4
+  auto tv3 = add(tv1, new Float(3));
+  auto tv4 = add(tv3, new Float(4));
+
+  // NOTE: This should no longer occur as of PR #201.
+  // The order of adding outputs matters. If tv3 is added before tv4,
+  // it should be fine. However, if tv4 is added before tv3, there
+  // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created
+  // first, and then tv4->tv3 is created at the final phase of
+  // computeAt (ComputeAt::setupOutputs).
+  fusion.addOutput(tv2);
+  fusion.addOutput(tv4);
+  fusion.addOutput(tv3);
+
+  tv0->computeAt(tv2, -1);
+
+  TORCH_CHECK(
+      !(tv3->getComputeAtView() == tv4 && tv4->getComputeAtView() == tv3),
+      "ComputeAt cycle detected between tv3 and tv4");
+
+  const auto options =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand(100, options);
+
+  torch::jit::fuser::cuda::FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input});
+
+  auto& output_tv2 = outputs[0];
+  auto& output_tv4 = outputs[1];
+  auto& output_tv3 = outputs[2];
+
+  auto aten_t1 = input + 1;
+  auto aten_t2 = aten_t1 + 2;
+  auto aten_t3 = aten_t1 + 3;
+  auto aten_t4 = aten_t3 + 4;
+
+  TORCH_CHECK(
+      aten_t2.allclose(output_tv2),
+      "Error of: ",
+      aten_t2.sub(output_tv2).abs().max());
+  TORCH_CHECK(
+      aten_t3.allclose(output_tv3),
+      "Error of: ",
+      aten_t3.sub(output_tv3).abs().max());
+  TORCH_CHECK(
+      aten_t4.allclose(output_tv4),
+      "Error of: ",
+      aten_t4.sub(output_tv4).abs().max());
+
+  return;
+}
+
 } // namespace jit
 } // namespace torch

3 changes: 2 additions & 1 deletion test/cpp/jit/tests.h
@@ -184,7 +184,8 @@ namespace jit {
   _(GPU_FusionSymbolicReduction) \
   _(GPU_FusionUnrollWithAlloc) \
   _(GPU_FusionIsZeroInt) \
-  _(GPU_FusionIsOneInt)
+  _(GPU_FusionIsOneInt) \
+  _(GPU_FusionComputeAtNonterminatingOutput)
 #else
 #define TH_FORALL_TESTS_CUDA(_) \
   _(ArgumentSpec) \
11 changes: 10 additions & 1 deletion torch/csrc/jit/codegen/cuda/compute_at.cpp
@@ -387,12 +387,21 @@ void ComputeAt::setupOutputs() {
     return;
 
   std::vector<TensorView*> touched_output_order;
+  const auto& terminating_outputs =
+      FusionGuard::getCurFusion()->getTerminatingOutputs();
 
   for (auto out : FusionGuard::getCurFusion()->outputs()) {
     if (out->getValType() == ValType::TensorView) {
       if (tv_data.find(out->as<TensorView>()) != tv_data.end()) {
         if (tv_data[out->as<TensorView>()].touched()) {
-          touched_output_order.push_back(out->as<TensorView>());
+          // No need to adjust computeAt when an output is not
+          // a terminating output.
+          if (std::find(
+                  terminating_outputs.begin(),
+                  terminating_outputs.end(),
+                  out) != terminating_outputs.end()) {
+            touched_output_order.push_back(out->as<TensorView>());
+          }
         }
       }
     }

[Review thread on the getTerminatingOutputs() call above]

Collaborator: do we need the "full" getTerminatingOutputs() here? It looks like we just want an unordered-set version of the terminating outputs (which would also avoid the O(N^2) lookup).

Collaborator (author): That's right. We could create an unordered_set version of the function. In practice, since the number of output tensors is only a few, a plain vector would actually be faster. In any case, since the number is so small, I don't think the efficiency matters here.
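For reference, here is a minimal sketch of the unordered_set variant discussed in the thread above. It is hypothetical and not part of this PR: the name getTerminatingOutputSet is invented here, and the body simply reuses the ExprSort::getExprs overload that this PR adds to fusion.cpp.

std::unordered_set<Val*> Fusion::getTerminatingOutputSet() {
  FusionGuard fg(this);

  // Collect every Val consumed by an expression reachable from the outputs.
  std::unordered_set<Val*> used_vals;
  const auto exprs = ExprSort::getExprs(
      this, std::vector<Val*>(outputs().begin(), outputs().end()));
  for (auto expr : exprs) {
    for (auto inp : expr->inputs())
      used_vals.emplace(inp);
  }

  // An output is terminating iff no expression consumes it. Returning a set
  // would make the membership check in ComputeAt::setupOutputs O(1) instead
  // of a linear std::find.
  std::unordered_set<Val*> terminating_outputs;
  for (auto out : outputs()) {
    if (used_vals.find(out) == used_vals.end())
      terminating_outputs.emplace(out);
  }
  return terminating_outputs;
}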
30 changes: 30 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -38,6 +38,14 @@ std::vector<Expr*> ExprSort::getExprs(
   return es.exprs;
 }
 
+std::vector<Expr*> ExprSort::getExprs(
+    Fusion* fusion,
+    const std::vector<Val*>& from) {
+  ExprSort es;
+  es.traverseFrom(fusion, from, false);
+  return es.exprs;
+}
+
 void InputsOf::handle(Val* v) {
   if (FusionGuard::getCurFusion()->origin(v) == nullptr)
     inputs.emplace(v);
@@ -541,6 +549,28 @@ bool Fusion::hasGridReduction() {
   return false;
 }
 
+std::vector<Val*> Fusion::getTerminatingOutputs() {
+  FusionGuard fg(this);
+
+  std::unordered_set<Val*> used_vals;
+
+  const auto exprs = ExprSort::getExprs(
+      this, std::vector<Val*>(outputs().begin(), outputs().end()));
+
+  for (auto expr : exprs) {
+    for (auto inp : expr->inputs())
+      used_vals.emplace(inp);
+  }
+
+  std::vector<Val*> terminating_outputs;
+  for (auto out : outputs()) {
+    if (used_vals.find(out) != used_vals.end())
+      continue;
+    terminating_outputs.push_back(out);
+  }
+  return terminating_outputs;
+}
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
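As a usage sketch of getTerminatingOutputs() (hypothetical, not part of the diff; it reuses the makeDummyTensor/add helpers from the new test above): when one output feeds another, only the downstream output is terminating.

Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeDummyTensor(1);
fusion.addInput(tv0);

auto tv1 = add(tv0, new Float(1)); // consumed by tv2's definition below
auto tv2 = add(tv1, new Float(2)); // consumed by nothing

fusion.addOutput(tv1);
fusion.addOutput(tv2);

// tv1 appears as an input of the expression defining tv2, so it is filtered
// out as non-terminating; the result contains only tv2.
auto terminating = fusion.getTerminatingOutputs();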
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion.h
@@ -79,6 +79,10 @@ class ExprSort : public IterVisitor {
       bool from_outputs_only,
       bool breadth_first,
       bool respect_compute_at);
+
+  static std::vector<Expr*> getExprs(
+      Fusion* fusion,
+      const std::vector<Val*>& from);
 };
 
 class InputsOf : public IterVisitor {
@@ -236,6 +240,8 @@ class TORCH_CUDA_API Fusion final {
     return outputs_;
   }
 
+  std::vector<Val*> getTerminatingOutputs();
+
   bool hasInput(const Val* val) const;
   bool hasOutput(const Val* val) const;

48 changes: 2 additions & 46 deletions torch/csrc/jit/codegen/cuda/iter_visitor.cpp
@@ -142,8 +142,7 @@ void IterVisitor::traverse_(
   TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
 
   if (from_outputs_only) {
-    auto term_outs = IterVisitor::getTerminatingOutputs(fusion);
-    std::vector<Val*> term_val_outs(term_outs.begin(), term_outs.end());
+    auto term_val_outs = fusion->getTerminatingOutputs();
     if (!term_val_outs.empty())
       traverseFrom(
           fusion, term_val_outs, traverse_all_paths, respect_compute_at);
@@ -179,26 +178,6 @@ void IterVisitor::traverseAllPaths(

 namespace {
 
-// Expr sort will take a fusion and return a topologically sorted list of
-// expressions.
-class Exprs : public IterVisitor {
- private:
-  std::vector<Expr*> exprs;
-
-  void handle(Expr* expr) override {
-    exprs.push_back(expr);
-  }
-
- public:
-  static std::vector<Expr*> getExprs(
-      Fusion* fusion,
-      const std::vector<Val*>& from) {
-    Exprs ex;
-    ex.traverseFrom(fusion, from, false);
-    return ex.exprs;
-  }
-};
-
 // Expr sort will take a fusion and return a topologically sorted list of
 // expressions.
 class Inputs : public IterVisitor {
@@ -222,29 +201,6 @@ class Inputs : public IterVisitor {

 } // namespace
 
-std::unordered_set<Val*> IterVisitor::getTerminatingOutputs(
-    Fusion* const fusion) {
-  FusionGuard fg(fusion);
-
-  std::unordered_set<Val*> used_vals;
-
-  const auto exprs = Exprs::getExprs(
-      fusion,
-      std::vector<Val*>(fusion->outputs().begin(), fusion->outputs().end()));
-
-  for (auto expr : exprs) {
-    for (auto inp : expr->inputs())
-      used_vals.emplace(inp);
-  }
-
-  std::unordered_set<Val*> terminating_outputs;
-  for (auto out : fusion->outputs())
-    if (used_vals.find(out) == used_vals.end())
-      terminating_outputs.emplace(out);
-
-  return terminating_outputs;
-}
-
 std::unordered_set<Val*> IterVisitor::getInputsTo(
     const std::vector<Val*>& vals) {
   return Inputs::getInputs(vals);
@@ -323,7 +279,7 @@ void BackwardVisitor::traverseFrom(

   auto vals = AllVals::get(fusion, from);
 
-  auto exprs = Exprs::getExprs(fusion, from);
+  auto exprs = ExprSort::getExprs(fusion, from);
 
   {
     size_t pos = 0;
2 changes: 0 additions & 2 deletions torch/csrc/jit/codegen/cuda/iter_visitor.h
@@ -117,8 +117,6 @@ class TORCH_CUDA_API IterVisitor : public OptOutDispatch {
       bool breadth_first = false,
       bool respect_compute_at = false);
 
-  static std::unordered_set<Val*> getTerminatingOutputs(Fusion* const);
-
   static std::unordered_set<Val*> getInputsTo(const std::vector<Val*>& vals);
 };

21 changes: 0 additions & 21 deletions torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -28,26 +28,6 @@ static void IrValidate(Fusion* fusion) {
   }
 }
 
-// Remove circular computeAt references
-void IrFixComputeAt(Fusion* fusion) {
-  std::vector<Expr*> exprs = fusion->exprs(true);
-  std::set<TensorView*> visited;
-  for (auto it = exprs.rbegin(); it != exprs.rend(); it++) {
-    Expr* expr = *it;
-    if (!ir_utils::isTVOp(expr))
-      continue;
-
-    TensorView* tv = ir_utils::asTV(expr->output(0));
-    TensorView* ctv = tv->getComputeAtView();
-
-    if (ctv != nullptr && visited.find(ctv) == visited.end()) {
-      ctv->setComputeAt(tv, (int)tv->getThisComputeAtAxis());
-      tv->clearComputeAt();
-    }
-    visited.emplace(tv);
-  }
-}
-
 void IrBuildSizesMap(Fusion* fusion) {
   // Sizes of inputs/outputs -> T.size[...]
   std::unordered_map<Val*, Val*> size_map;
@@ -119,7 +99,6 @@ void IrAdjustMemoryTypes(Fusion* fusion) {
 void PrepareForLowering(Fusion* fusion) {
   FusionGuard fg(fusion);
 
-  IrFixComputeAt(fusion);
   IrValidate(fusion);
   IrBuildSizesMap(fusion);
   IrAdjustMemoryTypes(fusion);
4 changes: 0 additions & 4 deletions torch/csrc/jit/codegen/cuda/lower_validation.h
@@ -25,10 +25,6 @@ namespace fuser {

 void TORCH_CUDA_API PrepareForLowering(Fusion* fusion);
 
-// Compute at can have some circular references. Before we can call any tv
-// with tv->getComputeAtAxis(i) we need to break those circular dependencies.
-void IrFixComputeAt(Fusion* fusion);
-
 // TensorViews are all based on symbolic sizes. When we first initialize them we
 // don't know if they're inputs or outputs which would mean that they have
 // runtime shapes. Intermediate tensors (those not going to global memory) do