Skip to content

Commit b1e5ba8

Browse files
Yujie Huifacebook-github-bot
authored andcommitted
Creation ops (#3877)
Summary: Pull Request resolved: #3877 Add more creation ops (ones/ones_like/zeros/zeros_like). Some are needed in OCR full model. Reuse full's implementation. Register them in Full.cpp Reviewed By: copyrightly, SS-JIA Differential Revision: D58247380 fbshipit-source-id: a31249396850a3c8426cda74bd6b6c9595bd2484
1 parent 98f0d82 commit b1e5ba8

File tree

4 files changed

+89
-4
lines changed

4 files changed

+89
-4
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ def __contains__(self, op):
118118
exir_ops.edge.aten.clone.default,
119119
exir_ops.edge.aten.full.default,
120120
exir_ops.edge.aten.full_like.default,
121+
exir_ops.edge.aten.ones.default,
122+
exir_ops.edge.aten.ones_like.default,
123+
exir_ops.edge.aten.zeros.default,
124+
exir_ops.edge.aten.zeros_like.default,
121125
]
122126

123127

backends/vulkan/runtime/graph/ops/impl/Full.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,26 @@ void add_full_node(
6464
}
6565

6666
void full(ComputeGraph& graph, const std::vector<ValueRef>& args) {
67-
return add_full_node(graph, args[0], args[1], args[6]);
67+
return add_full_node(graph, args[0], args[1], args[args.size() - 1]);
6868
}
6969

70-
void full_like(ComputeGraph& graph, const std::vector<ValueRef>& args) {
71-
return add_full_node(graph, args[0], args[1], args[7]);
70+
void zeros(ComputeGraph& graph, const std::vector<ValueRef>& args) {
71+
return add_full_node(
72+
graph, args[0], graph.add_scalar<int64_t>(0), args[args.size() - 1]);
73+
}
74+
75+
void ones(ComputeGraph& graph, const std::vector<ValueRef>& args) {
76+
return add_full_node(
77+
graph, args[0], graph.add_scalar<int64_t>(1), args[args.size() - 1]);
7278
}
7379

7480
REGISTER_OPERATORS {
7581
VK_REGISTER_OP(aten.full.default, full);
76-
VK_REGISTER_OP(aten.full_like.default, full_like);
82+
VK_REGISTER_OP(aten.full_like.default, full);
83+
VK_REGISTER_OP(aten.zeros.default, zeros);
84+
VK_REGISTER_OP(aten.zeros_like.default, zeros);
85+
VK_REGISTER_OP(aten.ones.default, ones);
86+
VK_REGISTER_OP(aten.ones_like.default, ones);
7787
}
7888

7989
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,25 @@ def get_full_inputs():
331331
return test_suite
332332

333333

334+
@register_test_suite(
335+
[
336+
"aten.zeros.default",
337+
"aten.zeros_like.default",
338+
"aten.ones.default",
339+
"aten.ones_like.default",
340+
]
341+
)
342+
def get_ones_inputs():
343+
test_suite = VkTestSuite(
344+
[
345+
([S1, S2]),
346+
([M, M1, M2]),
347+
([L, M, M1, M2]),
348+
]
349+
)
350+
return test_suite
351+
352+
334353
@register_test_suite(["aten.select.int", "aten.select_copy.int"])
335354
def get_select_int_inputs():
336355
test_suite = VkTestSuite(

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,20 @@ def __init__(self):
963963
def forward(self, x):
964964
return torch.full(x.shape, 42.0)
965965

966+
class ZerosModule(torch.nn.Module):
967+
def __init__(self):
968+
super().__init__()
969+
970+
def forward(self, x):
971+
return torch.zeros(x.shape)
972+
973+
class OnesModule(torch.nn.Module):
974+
def __init__(self):
975+
super().__init__()
976+
977+
def forward(self, x):
978+
return torch.ones(x.shape)
979+
966980
sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)
967981

968982
self.lower_module_and_test_output(
@@ -971,6 +985,18 @@ def forward(self, x):
971985
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
972986
)
973987

988+
self.lower_module_and_test_output(
989+
ZerosModule(),
990+
sample_inputs,
991+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
992+
)
993+
994+
self.lower_module_and_test_output(
995+
OnesModule(),
996+
sample_inputs,
997+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
998+
)
999+
9741000
def test_vulkan_backend_full_like(self):
9751001
class FullLikeModule(torch.nn.Module):
9761002
def __init__(self):
@@ -979,6 +1005,20 @@ def __init__(self):
9791005
def forward(self, x):
9801006
return torch.full_like(x, 42.0)
9811007

1008+
class ZerosLikeModule(torch.nn.Module):
1009+
def __init__(self):
1010+
super().__init__()
1011+
1012+
def forward(self, x):
1013+
return torch.zeros_like(x)
1014+
1015+
class OnesLikeModule(torch.nn.Module):
1016+
def __init__(self):
1017+
super().__init__()
1018+
1019+
def forward(self, x):
1020+
return torch.ones_like(x)
1021+
9821022
sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)
9831023

9841024
self.lower_module_and_test_output(
@@ -987,6 +1027,18 @@ def forward(self, x):
9871027
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
9881028
)
9891029

1030+
self.lower_module_and_test_output(
1031+
ZerosLikeModule(),
1032+
sample_inputs,
1033+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
1034+
)
1035+
1036+
self.lower_module_and_test_output(
1037+
OnesLikeModule(),
1038+
sample_inputs,
1039+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
1040+
)
1041+
9901042
def test_vulkan_backend_reshape(self):
9911043
class ReshapeModule(torch.nn.Module):
9921044
def __init__(self):

0 commit comments

Comments
 (0)