Skip to content

Commit 82226e5

Browse files
jjsjann123 authored and csarofeen committed
python test fixes (#52)
fix python test failures: 1. put Fusion inside CudaKernel to facilitate runtime arg check; 2. relax rank check for broadcast support in integration; 3. add shape propagation for newly added operations: [addcmul, lerp]; 4. add a utility function to create a FusionGuard from a CudaKernel directly.
1 parent 2f909f2 commit 82226e5

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

torch/csrc/jit/codegen/cuda/kernel.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ struct ExtractSizeStride {
4141
const at::Tensor& val,
4242
c10::optional<at::IntArrayRef> broadcasted_size = c10::nullopt) {
4343
if (broadcasted_size) {
44+
// [Note - broadcast support in integration]
4445
// PyTorch follows numpy broadcasting rule.
4546
// (https://numpy.org/doc/stable/user/basics.broadcasting.html)
4647
//
@@ -171,16 +172,22 @@ bool validateKernelArgTensor(
171172
msg << "Argument is a tensor, but the parameter is not.";
172173
return false;
173174
}
175+
174176
// Check the rank of the tensors.
175177
size_t arg_dim = arg.dim();
178+
// Note: This requires current Fusion to be active.
176179
size_t param_dim = TensorDomain::noReductions(
177180
static_cast<const TensorView*>(param)->getRootDomain())
178181
.size();
179-
if (arg_dim != param_dim) {
182+
// see [Note - broadcast support in integration]
183+
// Because of broadcasting support handled in integration, we relax the rank
184+
// check as necessary.
185+
if (arg_dim > param_dim) {
180186
msg << "Argument tensor's rank is " << arg_dim << ", but the parameter is "
181187
<< param_dim;
182188
return false;
183189
}
190+
184191
if (arg.device().index() != device_index) {
185192
msg << "Argument is on device that is not compiled for";
186193
return false;
@@ -256,12 +263,15 @@ void validateKernelArgs(
256263
const CudaKernel& entry,
257264
const at::ArrayRef<IValue>& inputs,
258265
const std::vector<at::Tensor>& outputs) {
266+
// This is necessary as we were traversing the fusion graph later in the check
267+
FusionGuard fg(&entry);
259268
// Check inputs
260269
TORCH_INTERNAL_ASSERT(
261-
inputs.size() == entry.inputs.size(), "Wrong number of kernel inputs.");
270+
inputs.size() == entry.fusion_->inputs().size(),
271+
"Wrong number of kernel inputs.");
262272
for (size_t i = 0; i < inputs.size(); ++i) {
263273
const IValue& arg = inputs[i];
264-
const Val* const param = entry.inputs[i];
274+
const Val* const param = entry.fusion_->inputs()[i];
265275
std::stringstream msg;
266276
TORCH_INTERNAL_ASSERT(
267277
validateKernelArg(arg, param, entry.device_, msg),
@@ -272,15 +282,15 @@ void validateKernelArgs(
272282
}
273283

274284
TORCH_INTERNAL_ASSERT(
275-
entry.outputs.size() != 0,
285+
entry.fusion_->outputs().size() != 0,
276286
"Kernel should have at least one output tensor.");
277287

278288
TORCH_INTERNAL_ASSERT(
279-
outputs.size() == entry.outputs.size(),
289+
outputs.size() == entry.fusion_->outputs().size(),
280290
"Wrong number of kernel outputs.");
281291
for (size_t i = 0; i < outputs.size(); ++i) {
282292
const at::Tensor& arg = outputs[i];
283-
const Val* const param = entry.outputs[i];
293+
const Val* const param = entry.fusion_->outputs()[i];
284294
std::stringstream msg;
285295
TORCH_INTERNAL_ASSERT(
286296
validateKernelArgTensor(arg, param, entry.device_, msg),
@@ -546,7 +556,7 @@ void runTestKernel(
546556
input.toTensor().device().index() == entry->device_,
547557
"input to kernel on device that is not compiled for");
548558
TORCH_INTERNAL_ASSERT(
549-
!entry->outputs.empty(),
559+
!entry->fusion_->outputs().empty(),
550560
"No output found for this kernel, aborting.");
551561
if (has_reduction) {
552562
kernel_args.push(input.toTensor());

0 commit comments

Comments
 (0)