Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/vulkan/partitioner/supported_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def __contains__(self, op):
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.ones.default,
exir_ops.edge.aten.ones_like.default,
exir_ops.edge.aten.zeros.default,
exir_ops.edge.aten.zeros_like.default,
]


Expand Down
18 changes: 14 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/Full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,26 @@ void add_full_node(
}

void full(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_full_node(graph, args[0], args[1], args[6]);
return add_full_node(graph, args[0], args[1], args[args.size() - 1]);
}

void full_like(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_full_node(graph, args[0], args[1], args[7]);
void zeros(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_full_node(
graph, args[0], graph.add_scalar<int64_t>(0), args[args.size() - 1]);
}

void ones(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_full_node(
graph, args[0], graph.add_scalar<int64_t>(1), args[args.size() - 1]);
}

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.full.default, full);
VK_REGISTER_OP(aten.full_like.default, full_like);
VK_REGISTER_OP(aten.full_like.default, full);
VK_REGISTER_OP(aten.zeros.default, zeros);
VK_REGISTER_OP(aten.zeros_like.default, zeros);
VK_REGISTER_OP(aten.ones.default, ones);
VK_REGISTER_OP(aten.ones_like.default, ones);
}

} // namespace vkcompute
19 changes: 19 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,25 @@ def get_full_inputs():
return test_suite


@register_test_suite(
[
"aten.zeros.default",
"aten.zeros_like.default",
"aten.ones.default",
"aten.ones_like.default",
]
)
def get_ones_inputs():
test_suite = VkTestSuite(
[
([S1, S2]),
([M, M1, M2]),
([L, M, M1, M2]),
]
)
return test_suite


@register_test_suite(["aten.select.int", "aten.select_copy.int"])
def get_select_int_inputs():
test_suite = VkTestSuite(
Expand Down
52 changes: 52 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,20 @@ def __init__(self):
def forward(self, x):
return torch.full(x.shape, 42.0)

class ZerosModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.zeros(x.shape)

class OnesModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.ones(x.shape)

sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)

self.lower_module_and_test_output(
Expand All @@ -971,6 +985,18 @@ def forward(self, x):
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

self.lower_module_and_test_output(
ZerosModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

self.lower_module_and_test_output(
OnesModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def test_vulkan_backend_full_like(self):
class FullLikeModule(torch.nn.Module):
def __init__(self):
Expand All @@ -979,6 +1005,20 @@ def __init__(self):
def forward(self, x):
return torch.full_like(x, 42.0)

class ZerosLikeModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.zeros_like(x)

class OnesLikeModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.ones_like(x)

sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)

self.lower_module_and_test_output(
Expand All @@ -987,6 +1027,18 @@ def forward(self, x):
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

self.lower_module_and_test_output(
ZerosLikeModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

self.lower_module_and_test_output(
OnesLikeModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def test_vulkan_backend_reshape(self):
class ReshapeModule(torch.nn.Module):
def __init__(self):
Expand Down