@@ -1065,6 +1065,9 @@ def forward(self, k: torch.Tensor) -> torch.Tensor:
1065
1065
self .check_tensor_buffer_loc (1 , execution_plan .values , 0 , 1 , 48 )
1066
1066
1067
1067
def test_emit_prims (self ) -> None :
1068
+ tensor_output = torch .rand (1 , 4 )
1069
+ tensor_list_output = [torch .rand (1 , 4 ), torch .rand (1 , 4 )]
1070
+
1068
1071
class Simple (torch .nn .Module ):
1069
1072
def __init__ (self ) -> None :
1070
1073
super ().__init__ ()
@@ -1078,6 +1081,12 @@ def get_ints(self) -> Tuple[int]:
1078
1081
def get_str (self ) -> str :
1079
1082
return "foo"
1080
1083
1084
+ def get_tensor (self ) -> torch .Tensor :
1085
+ return tensor_output
1086
+
1087
+ def get_tensor_list (self ) -> List [torch .Tensor ]:
1088
+ return tensor_list_output
1089
+
1081
1090
def forward (self , x : torch .Tensor ) -> torch .Tensor :
1082
1091
return torch .nn .functional .sigmoid (self .linear (x ))
1083
1092
@@ -1090,9 +1099,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
1090
1099
getters = {}
1091
1100
getters ["get_ints" ] = model .get_ints ()
1092
1101
getters ["get_str" ] = model .get_str ()
1093
- print (getters ["get_str" ])
1102
+ getters ["get_tensor" ] = model .get_tensor ()
1103
+ getters ["get_tensor_list" ] = model .get_tensor_list ()
1104
+
1094
1105
merged_program = emit_program (exir_input , False , getters ).program
1095
- self .assertEqual (len (merged_program .execution_plan ), 3 )
1106
+
1107
+ self .assertEqual (len (merged_program .execution_plan ), 5 )
1096
1108
1097
1109
self .assertEqual (
1098
1110
merged_program .execution_plan [0 ].name ,
@@ -1106,6 +1118,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
1106
1118
merged_program .execution_plan [2 ].name ,
1107
1119
"get_str" ,
1108
1120
)
1121
+ self .assertEqual (
1122
+ merged_program .execution_plan [3 ].name ,
1123
+ "get_tensor" ,
1124
+ )
1125
+ self .assertEqual (
1126
+ merged_program .execution_plan [4 ].name ,
1127
+ "get_tensor_list" ,
1128
+ )
1129
+
1109
1130
# no instructions in a getter
1110
1131
self .assertEqual (
1111
1132
len (merged_program .execution_plan [1 ].chains [0 ].instructions ),
@@ -1141,6 +1162,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
1141
1162
merged_program .execution_plan [2 ].values [0 ].val .string_val ,
1142
1163
"foo" ,
1143
1164
)
1165
+ self .assertEqual (len (merged_program .execution_plan [3 ].outputs ), 1 )
1166
+ self .assertEqual (len (merged_program .execution_plan [4 ].outputs ), 2 )
1167
+
1168
+ merged_program = to_edge (
1169
+ export (model , inputs ), constant_methods = getters
1170
+ ).to_executorch ()
1171
+ executorch_module = _load_for_executorch_from_buffer (merged_program .buffer )
1172
+ torch .allclose (executorch_module .run_method ("get_tensor" , [])[0 ], tensor_output )
1173
+ model_output = executorch_module .run_method ("get_tensor_list" , [])
1174
+ for i in range (len (tensor_list_output )):
1175
+ torch .allclose (model_output [i ], tensor_list_output [i ])
1144
1176
1145
1177
def test_emit_debug_handle_map (self ) -> None :
1146
1178
mul_model = Mul ()
0 commit comments