
Cyclic ComputeAt #200

@naoyam

Description

🐛 Bug

A cycle of computeAt tensors can occur in the following case:

(Copied from testGPU_FusionComputeAtNonterminatingOutput in https://github.com/naoyam/pytorch/blob/fix-compute-at-cycle/test/cpp/jit/test_gpu.cpp)

  TensorView* tv0 = makeDummyTensor(1);
  fusion.addInput(tv0);

  // Common intermediate tensor
  auto tv1 = add(tv0, new Float(0));
  // tv1 -> tv2
  auto tv2 = add(tv1, new Float(0));
  // tv1 -> tv3 -> tv4
  auto tv3 = add(tv1, new Float(0));
  auto tv4 = add(tv3, new Float(0));

  // The order of adding outputs matters. If tv3 is added before tv4,
  // it should be fine. However, if tv4 is added before tv3, there
  // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created
  // first, and then tv4->tv3 is created at the final phase of
  // computeAt (ComputeAt::setupOutputs).
  if (true) {
    // A cycle of tv3 <-> tv4 will be created.
    fusion.addOutput(tv2);
    fusion.addOutput(tv4);
    fusion.addOutput(tv3);
  } else {
    // This should work fine.
    fusion.addOutput(tv2);
    fusion.addOutput(tv3);
    fusion.addOutput(tv4);
  }

  tv0->computeAt(tv2, -1);

After the computeAt, tv3 is set to be computed at tv4, and tv4 is set to be computed at tv3. Cycles like this are not supposed to happen.

Code generation still works because IrFix breaks such cycles. However, computeAt should never create cycles in the first place, so that IrFix can be removed.
