## 🐛 Bug
This might not be a bug at all but just my misunderstanding (which has happened too many times already), but the discussion might still be educational (to me, at least).
I found that `computeAt` propagates thread bindings from layers past the nested source back to the input. Here is a simplified description to illustrate the idea:

- Data flow: `T0 -> T2 -> T1`, where `T2` is the intermediate generated by `T1->rFactor`.
- I specify `T0->computeAt(T2)` and `T1->axis(0)->parallelize(ParallelType::BIDx)`.

The strange behavior here is:

- codegen seems to be able to implicitly propagate the `BIDx` binding on `axis(0)` all the way back to `T0`, even though it is not explicitly specified via `computeAt`;
- however, the `BIDx` binding on `axis(0)` is not taken into account when `T2` is allocated, and lowering complains about a non-constant allocation.
- Note that we are not specifying `T2->computeAt(T1)` here; if we added it, the thread binding would be explicitly propagated from `T1` back to `T2` and `T0`, which resolves the issue (see the sketch after the repro below).

Given the observed behavior, I think we probably have an assert whose conditions could be relaxed.
## To Reproduce
```cpp
void testGPU_FusionReductionJ() {
  torch::jit::fuser::cuda::CudaKernel prog;
  Fusion& fusion = *prog.fusion_;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");

  tv1->split(1, 128);
  auto tv2 = tv1->rFactor({1});
  std::cout << "fusion 1: \n" << fusion << std::endl;

  tv0->computeAt(tv2, 2);
  std::cout << "fusion 2: \n" << fusion << std::endl;

  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv0->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  // If we can figure out the thread binding for tv0 in the generated kernel
  // without specifying it here, we should be able to do the same thing for
  // the allocation of tv2 as well.
  // tv0->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  std::cout << "fusion 3: \n" << fusion << std::endl;

  GPULower gpulw(&fusion);
  std::stringstream cdg;
  gpulw.printKernel(cdg);
  std::cout << cdg.str() << std::endl;
}
```
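
For completeness, this is roughly how the explicit propagation mentioned above would look in the repro. This is only a sketch of my assumption: I have not verified the axis position passed to `computeAt`, and the rest of the test stays as written above.

```cpp
  tv0->computeAt(tv2, 2);
  // Assumed workaround (untested): also state the computeAt relationship from
  // tv2 to tv1, so the BIDx binding on axis(0) is explicitly propagated back
  // to tv2 (and tv0) and the allocation of tv2 becomes constant.
  tv2->computeAt(tv1, 1);
```

With that extra call, the thread binding that codegen already infers for `tv0` would also be visible when `tv2` is allocated, which is exactly the inconsistency described above.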