12
12
AT_INT_ARRAY_REF ,
13
13
AT_SCALAR ,
14
14
AT_TENSOR ,
15
+ AT_TENSOR_LIST ,
15
16
BOOL ,
16
17
CppTestFileGen ,
17
18
DOUBLE ,
28
29
THREE_TENSOR_TUPLE ,
29
30
TWO_TENSOR_TUPLE ,
30
31
)
32
+
31
33
from torchgen .api import cpp
32
34
from torchgen .api .types import CppSignatureGroup
33
35
@@ -75,6 +77,8 @@ class ValueRef:
75
77
76
78
ValueRefList = Union [ValueRef , List [ValueRef ]]
77
79
80
+ InableCppType = frozenset ([AT_TENSOR , AT_TENSOR_LIST ])
81
+
78
82
79
83
class ComputeGraphGen :
80
84
def __init__ (self , op_reg_name : str , f : NativeFunction , suite_def : TestSuite ):
@@ -114,7 +118,7 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
114
118
name = f"{ arg .name } _ref" ,
115
119
src_cpp_name = arg .name ,
116
120
src_cpp_type = cpp_type ,
117
- is_in = (cpp_type == AT_TENSOR ),
121
+ is_in = (cpp_type in InableCppType ),
118
122
requires_prepack = requires_prepack ,
119
123
supports_prepack = supports_prepack ,
120
124
)
@@ -244,6 +248,25 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
244
248
ret_str += f"{ self .graph } { self .dot } add_scalar<int64_t>"
245
249
ret_str += f"({ ref .src_cpp_name } .value());\n "
246
250
return ret_str
251
+ elif ref .src_cpp_type == AT_TENSOR_LIST :
252
+ assert ref .is_in , "AT_TENSOR_LIST must be an input"
253
+ # This logic is a bit convoluted. We need to create an IOValueRef for
254
+ # each tensor, to facilitate staging. On the other hand, we will
255
+ # use the .value tensor to create a ValueList, which will be passed
256
+ # to the corresponding ops.
257
+ ret_str = f"std::vector<IOValueRef> { ref .name } _io_value_refs;\n "
258
+ ret_str += f"std::vector<ValueRef> { ref .name } _value_refs;\n "
259
+ ret_str += f"for (int i=0; i < { ref .src_cpp_name } .size(); i++) {{\n "
260
+ ret_str += f" { cpp_type } io_value_ref = { self .graph } { self .dot } add_input_tensor(\n "
261
+ ret_str += f" { ref .src_cpp_name } [i].sizes().vec(),\n "
262
+ ret_str += (
263
+ f" from_at_scalartype({ ref .src_cpp_name } [i].scalar_type())); \n "
264
+ )
265
+ ret_str += f" { ref .name } _value_refs.emplace_back(io_value_ref.value);\n "
266
+ ret_str += f" { ref .name } _io_value_refs.emplace_back(io_value_ref);\n "
267
+ ret_str += "}\n "
268
+ ret_str += f"ValueRef { ref .name } = { self .graph } { self .dot } add_value_list(std::move({ ref .name } _value_refs));\n "
269
+ return ret_str
247
270
248
271
ret_str = f"{ cpp_type } { ref .name } = { self .graph } { self .dot } "
249
272
if ref .src_cpp_type == AT_TENSOR and not prepack :
@@ -288,11 +311,16 @@ def create_op_call(self) -> str:
288
311
289
312
for aten_arg in self .args :
290
313
ref = self .refs [aten_arg .name ]
291
- op_create_code += (
292
- f"{ ref .name } .value, "
293
- if (ref .is_in and not self .prepack_ref (ref )) or ref .is_out
294
- else f"{ ref .name } , "
295
- )
314
+ if ref .src_cpp_type == AT_TENSOR_LIST :
315
+ # Special case. Underlying tensors are input tensors, but the
316
+ # container itself is just a normal value.
317
+ op_create_code += f"{ ref .name } , "
318
+ else :
319
+ op_create_code += (
320
+ f"{ ref .name } .value, "
321
+ if (ref .is_in and not self .prepack_ref (ref )) or ref .is_out
322
+ else f"{ ref .name } , "
323
+ )
296
324
297
325
op_create_code += "out_ref});\n "
298
326
return op_create_code
@@ -311,22 +339,46 @@ def set_output(self, ref: ValueRefList) -> str:
311
339
312
340
def virtual_resize (self , ref : ValueRefList ) -> str :
313
341
assert isinstance (ref , ValueRef )
314
- assert ref .src_cpp_type == AT_TENSOR and ref .is_in
342
+ assert ref .src_cpp_type in InableCppType and ref .is_in
315
343
if self .prepack_ref (ref ):
316
344
return ""
317
- ret_str = f"{ self .graph } { self .dot } get_tensor({ ref .name } .value)"
318
- ret_str += f"->virtual_resize({ ref .src_cpp_name } .sizes().vec());\n "
345
+
346
+ if ref .src_cpp_type == AT_TENSOR :
347
+ ret_str = f"{ self .graph } { self .dot } get_tensor({ ref .name } .value)"
348
+ ret_str += f"->virtual_resize({ ref .src_cpp_name } .sizes().vec());\n "
349
+ elif ref .src_cpp_type == AT_TENSOR_LIST :
350
+ ret_str = ""
351
+ ret_str += f"for (int i=0; i < { ref .name } _io_value_refs.size(); i++) {{\n "
352
+ ret_str += f" { self .graph } { self .dot } get_tensor({ ref .name } _io_value_refs[i].value)"
353
+ ret_str += f"->virtual_resize({ ref .src_cpp_name } [i].sizes().vec());\n "
354
+ ret_str += "}\n "
355
+ else :
356
+ raise AssertionError (f"{ ref .src_cpp_type } not expected" )
357
+
319
358
return ret_str
320
359
321
360
def copy_into_staging (self , ref : ValueRefList ) -> str :
322
361
assert isinstance (ref , ValueRef )
323
- assert ref .src_cpp_type == AT_TENSOR and ref .is_in
362
+ assert ref .src_cpp_type in InableCppType and ref .is_in
363
+
324
364
if self .prepack_ref (ref ):
325
365
return ""
326
- ret_str = f"{ self .graph } { self .dot } copy_into_staging("
327
- ret_str += f"{ ref .name } .staging, "
328
- ret_str += f"{ ref .src_cpp_name } .const_data_ptr(), "
329
- ret_str += f"{ ref .src_cpp_name } .numel());\n "
366
+
367
+ if ref .src_cpp_type == AT_TENSOR :
368
+ ret_str = f"{ self .graph } { self .dot } copy_into_staging("
369
+ ret_str += f"{ ref .name } .staging, "
370
+ ret_str += f"{ ref .src_cpp_name } .const_data_ptr(), "
371
+ ret_str += f"{ ref .src_cpp_name } .numel());\n "
372
+ elif ref .src_cpp_type == AT_TENSOR_LIST :
373
+ ret_str = ""
374
+ ret_str += f"for (int i=0; i < { ref .name } _io_value_refs.size(); i++) {{\n "
375
+ ret_str += f" { self .graph } { self .dot } copy_into_staging("
376
+ ret_str += f"{ ref .name } _io_value_refs[i].staging, "
377
+ ret_str += f"{ ref .src_cpp_name } [i].const_data_ptr(), "
378
+ ret_str += f"{ ref .src_cpp_name } [i].numel());\n "
379
+ ret_str += "}\n "
380
+ else :
381
+ raise AssertionError (f"{ ref .src_cpp_type } not expected" )
330
382
return ret_str
331
383
332
384
def declare_vk_out_for (self , ref : Union [ValueRef , List [ValueRef ]]) -> str :
@@ -547,8 +599,10 @@ def gen_parameterization(self) -> str:
547
599
if (!is_close && t1.numel() < 500) {
548
600
std::cout << "reference: " << std::endl;
549
601
print(t1, 150);
602
+ std::cout << std::endl;
550
603
std::cout << "vulkan: " << std::endl;
551
604
print(t2, 150);
605
+ std::cout << std::endl;
552
606
}
553
607
return is_close;
554
608
}
0 commit comments