@@ -41,6 +41,7 @@ struct ExtractSizeStride {
       const at::Tensor& val,
       c10::optional<at::IntArrayRef> broadcasted_size = c10::nullopt) {
     if (broadcasted_size) {
+      // [Note - broadcast support in integration]
       // PyTorch follows numpy broadcasting rule.
       // (https://numpy.org/doc/stable/user/basics.broadcasting.html)
       //
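The numpy broadcasting rule cited in this hunk aligns shapes from the trailing dimension and treats missing leading dimensions as extent 1. As a rough, hypothetical illustration (not code from this file), a broadcast shape could be computed as below; the size-1 padded dimensions are what ExtractSizeStride can later map to stride 0.

// Hypothetical sketch of the numpy broadcasting rule referenced above
// (illustrative only, not part of this diff): align shapes from the right,
// treat missing dimensions as extent 1, and require equal or size-1 extents.
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> broadcastSizes(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  const size_t ndim = std::max(a.size(), b.size());
  std::vector<int64_t> out(ndim, 1);
  for (size_t i = 0; i < ndim; ++i) {
    // Read extents from the trailing dimension; absent dims count as 1.
    const int64_t ea = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t eb = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (ea != eb && ea != 1 && eb != 1) {
      throw std::runtime_error("shapes are not broadcastable");
    }
    out[ndim - 1 - i] = std::max(ea, eb);
  }
  return out;
}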
@@ -171,16 +172,22 @@ bool validateKernelArgTensor(
     msg << "Argument is a tensor, but the parameter is not.";
     return false;
   }
+
   // Check the rank of the tensors.
   size_t arg_dim = arg.dim();
+  // Note: This requires current Fusion to be active.
   size_t param_dim = TensorDomain::noReductions(
                          static_cast<const TensorView*>(param)->getRootDomain())
                          .size();
-  if (arg_dim != param_dim) {
+  // see [Note - broadcast support in integration]
+  // Because of broadcasting support handled in integration, we relax the rank
+  // check as necessary.
+  if (arg_dim > param_dim) {
     msg << "Argument tensor's rank is " << arg_dim << ", but the parameter is "
         << param_dim;
     return false;
   }
+
   if (arg.device().index() != device_index) {
     msg << "Argument is on device that is not compiled for";
     return false;
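Since broadcasting is resolved in the integration layer, an argument tensor may legitimately have a lower rank than the fusion parameter (the missing leading dimensions behave as size-1 broadcast dimensions), so only arg_dim > param_dim is rejected above. A minimal, hypothetical check mirroring that rule:

// Minimal sketch of the relaxed rank rule (hypothetical helper, not in this
// file): fewer argument dims than parameter dims is acceptable because the
// leading dims are filled in by broadcasting; more dims can never match.
#include <cstddef>
#include <iostream>

bool rankIsCompatible(size_t arg_dim, size_t param_dim) {
  return arg_dim <= param_dim;
}

int main() {
  std::cout << rankIsCompatible(1, 2) << "\n"; // 1: e.g. {4} broadcasts to {3, 4}
  std::cout << rankIsCompatible(3, 2) << "\n"; // 0: a rank-3 arg cannot match a rank-2 param
  return 0;
}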
@@ -256,12 +263,15 @@ void validateKernelArgs(
     const CudaKernel& entry,
     const at::ArrayRef<IValue>& inputs,
     const std::vector<at::Tensor>& outputs) {
+  // This is necessary as we were traversing the fusion graph later in the check
+  FusionGuard fg(&entry);
   // Check inputs
   TORCH_INTERNAL_ASSERT(
-      inputs.size() == entry.inputs.size(), "Wrong number of kernel inputs.");
+      inputs.size() == entry.fusion_->inputs().size(),
+      "Wrong number of kernel inputs.");
   for (size_t i = 0; i < inputs.size(); ++i) {
     const IValue& arg = inputs[i];
-    const Val* const param = entry.inputs[i];
+    const Val* const param = entry.fusion_->inputs()[i];
     std::stringstream msg;
     TORCH_INTERNAL_ASSERT(
         validateKernelArg(arg, param, entry.device_, msg),
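The added FusionGuard is needed because the checks now read inputs()/outputs() directly off entry.fusion_, and traversing the fusion graph expects an active fusion for the current scope. As a loose illustration of that RAII pattern only (an assumption about its behavior, not the nvFuser implementation), a scope guard that swaps a "current" pointer and restores it on exit looks like this:

// Illustrative RAII scope guard (assumption: FusionGuard behaves along these
// lines; this is not the real nvFuser code).
struct Context; // stand-in for the Fusion graph

inline Context* active_context = nullptr; // the "current" context

class ContextGuard {
 public:
  explicit ContextGuard(Context* ctx) : prev_(active_context) {
    active_context = ctx; // make ctx current for the lifetime of the guard
  }
  ~ContextGuard() {
    active_context = prev_; // restore the previous context on scope exit
  }

 private:
  Context* prev_;
};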
@@ -272,15 +282,15 @@ void validateKernelArgs(
   }

   TORCH_INTERNAL_ASSERT(
-      entry.outputs.size() != 0,
+      entry.fusion_->outputs().size() != 0,
       "Kernel should have at least one output tensor.");

   TORCH_INTERNAL_ASSERT(
-      outputs.size() == entry.outputs.size(),
+      outputs.size() == entry.fusion_->outputs().size(),
       "Wrong number of kernel outputs.");
   for (size_t i = 0; i < outputs.size(); ++i) {
     const at::Tensor& arg = outputs[i];
-    const Val* const param = entry.outputs[i];
+    const Val* const param = entry.fusion_->outputs()[i];
     std::stringstream msg;
     TORCH_INTERNAL_ASSERT(
         validateKernelArgTensor(arg, param, entry.device_, msg),
@@ -546,7 +556,7 @@ void runTestKernel(
         input.toTensor().device().index() == entry->device_,
         "input to kernel on device that is not compiled for");
     TORCH_INTERNAL_ASSERT(
-        !entry->outputs.empty(),
+        !entry->fusion_->outputs().empty(),
         "No output found for this kernel, aborting.");
     if (has_reduction) {
       kernel_args.push(input.toTensor());