diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 026f1db9273..aa3cca5f384 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -540,6 +540,7 @@ def register_view_op(features: OpFeatures): exir_ops.edge.aten.ones.default, exir_ops.edge.aten.ones_like.default, exir_ops.edge.aten.upsample_nearest2d.vec, + exir_ops.edge.aten.upsample_bilinear2d.vec, exir_ops.edge.aten.zeros.default, exir_ops.edge.aten.zeros_like.default, exir_ops.edge.et_vk.grid_priors.default, diff --git a/backends/vulkan/runtime/graph/ops/glsl/upsample_2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/upsample_2d.glsl new file mode 100644 index 00000000000..85b63ad20ba --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/upsample_2d.glsl @@ -0,0 +1,71 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "ivec3", "in_limits")} +${layout_declare_ubo(B, "vec2", "recip_scales")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int align_corners = 0; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_limits))) { + return; + } + + ivec2 max_in_xy = in_limits.xy - 1; + vec2 scaled_xy; + + if (align_corners == 1) { + scaled_xy = pos.xy * recip_scales; + } else { + scaled_xy = (pos.xy + 0.5) * recip_scales - 0.5; + } + + $if MODE == "nearest": + const ivec2 ipos = clamp(ivec2(round(scaled_xy)), ivec2(0), max_in_xy); + VEC4_T out_tex = texelFetch(t_in, ivec3(ipos, pos.z), 0); + $elif MODE == "bilinear": + vec2 upper_xy = ceil(scaled_xy); + vec2 lower_xy = floor(scaled_xy); + + // Clamp coordinates to valid input range + upper_xy = clamp(upper_xy, ivec2(0), max_in_xy); + lower_xy = clamp(lower_xy, ivec2(0), max_in_xy); + + // Calculate interpolation weights + vec2 interp_weights = (scaled_xy - lower_xy); + + // Sample the four nearest texels + VEC4_T sample00 = texelFetch(t_in, ivec3(lower_xy.x, lower_xy.y, pos.z), 0); + VEC4_T sample10 = texelFetch(t_in, ivec3(upper_xy.x, lower_xy.y, pos.z), 0); + VEC4_T sample01 = texelFetch(t_in, ivec3(lower_xy.x, upper_xy.y, pos.z), 0); + VEC4_T sample11 = texelFetch(t_in, ivec3(upper_xy.x, upper_xy.y, pos.z), 0); + + // Perform bilinear interpolation + VEC4_T out_tex = mix( + mix(sample00, sample10, interp_weights.x), + mix(sample01, sample11, interp_weights.x), + interp_weights.y + ); + + imageStore(t_out, pos, out_tex); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/upsample_2d.yaml similarity index 83% rename from backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.yaml rename to backends/vulkan/runtime/graph/ops/glsl/upsample_2d.yaml index 89b873c02c3..3bd1c282e13 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/upsample_2d.yaml @@ -4,15 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -upsample_nearest2d: +upsample_2d: parameter_names_with_default_values: - NDIM: 3 DTYPE: float - PACKING: C_packed STORAGE: texture3d + MODE: nearest generate_variant_forall: DTYPE: - VALUE: half - VALUE: float shader_variants: - NAME: upsample_nearest2d + - NAME: upsample_bilinear2d + MODE: bilinear diff --git a/backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.glsl deleted file mode 100644 index 8ab455a55a0..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.glsl +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#include "broadcasting_utils.h" -#include "indexing_utils.h" - -#define PRECISION ${PRECISION} - -#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} - -layout(std430) buffer; - -${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)} -${layout_declare_ubo(2, "ivec3", "out_limits")} -${layout_declare_ubo(3, "ivec2", "input_size")} -${layout_declare_ubo(4, "vec2", "rev_scales")} - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, out_limits))) { - return; - } - - const ivec2 ipos = clamp(ivec2(pos.xy * rev_scales), ivec2(0), input_size); - - VEC4_T in_texel = texelFetch(t_in, ivec3(ipos, pos.z), 0); - imageStore(t_out, pos, in_texel); -} diff --git a/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp b/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp index 73f8055c284..79777b3f9ac 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp @@ -16,6 +16,8 @@ namespace vkcompute { +enum class UpsampleMode : int { NEAREST, BILINEAR }; + void resize_upsample_nearest2d_node( ComputeGraph* graph, const std::vector& args, @@ -39,19 +41,12 @@ void resize_upsample_nearest2d_node( out->virtual_resize(out_sizes); } -// ExecuTorch-Vulkan framework to add node -// Args: -// in: will be converted from NCHW input tensor to 3D ARGB representation in -// openGL (via ExecuTorch) output_sizes: optional 2D array of targetting -// output size of H and W dimensions. >= input sizes; - -// will be computed if only given the scale_factors. -// scale_factors: optional 2D array of scale factors for H and W dimensions. -// Will be computed if only given the output_sizes. void add_upsample_nearest2d_node( ComputeGraph& graph, + const UpsampleMode mode, const ValueRef in, const ValueRef output_sizes, + const ValueRef align_corners, const ValueRef scale_factors, const ValueRef out) { if (graph.val_is_none(output_sizes) && graph.val_is_none(scale_factors)) { @@ -63,36 +58,61 @@ void add_upsample_nearest2d_node( "Invalid input, must provide ONLY one of output_sizes or scale_factors"); } - vTensorPtr t_in = graph.get_tensor(in); - utils::uvec3 input_sizes = t_in->logical_limits(); + int align_corners_val = 0; + if (is_valid(align_corners) && graph.get_bool(align_corners)) { + align_corners_val = 1; + } + + utils::uvec3 in_limits = graph.logical_limits_of(in); + utils::uvec3 out_limits = graph.logical_limits_of(out); + + uint32_t out_width = out_limits[0u]; + uint32_t out_height = out_limits[1u]; - utils::ivec2 input_size = { - utils::safe_downcast(input_sizes[0]), - utils::safe_downcast(input_sizes[1])}; - utils::vec2 rev_scales = { - utils::safe_downcast(1.0), utils::safe_downcast(1.0)}; + float scale_factor_x = float(in_limits[0u]) / float(out_width); + float scale_factor_y = float(in_limits[1u]) / float(out_height); + + float recip_scale_factor_x = 1.0f / scale_factor_x; + float recip_scale_factor_y = 1.0f / scale_factor_y; - // Reverse scale factors that pre-computed before GLSL. if (!graph.val_is_none(output_sizes)) { - auto output_size_ref = graph.get_int_list(output_sizes); - rev_scales = { - utils::safe_downcast( - (float)input_size[0] / output_size_ref->at(1)), - utils::safe_downcast( - (float)input_size[1] / output_size_ref->at(0))}; + IntListPtr output_size_ref = graph.get_int_list(output_sizes); + out_width = output_size_ref->at(1); + out_height = output_size_ref->at(0); + + VK_CHECK_COND(out_width == out_limits[0u]); + VK_CHECK_COND(out_height == out_limits[1u]); + + } else { + DoubleListPtr scales = graph.get_double_list(scale_factors); + scale_factor_x = scales->at(1); + scale_factor_y = scales->at(0); + VK_CHECK_COND(in_limits[0u] * scale_factor_x == out_width); + VK_CHECK_COND(in_limits[1u] * scale_factor_y == out_height); + } + + if (align_corners_val == 1) { + recip_scale_factor_x = float(in_limits[0u] - 1) / float(out_width - 1); + recip_scale_factor_y = float(in_limits[1u] - 1) / float(out_height - 1); } else { - auto scales = graph.get_double_list(scale_factors); - rev_scales = { - utils::safe_downcast(1.0 / scales->at(1)), - utils::safe_downcast(1.0 / scales->at(0))}; + recip_scale_factor_x = float(in_limits[0u]) / float(out_width); + recip_scale_factor_y = float(in_limits[1u]) / float(out_height); } - vTensorPtr t_out = graph.get_tensor(out); + utils::vec2 recip_scales = {recip_scale_factor_x, recip_scale_factor_y}; - std::string kernel_name("upsample_nearest2d"); + std::string kernel_name; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + switch (mode) { + case UpsampleMode::NEAREST: + kernel_name = "upsample_nearest2d"; + break; + case UpsampleMode::BILINEAR: + kernel_name = "upsample_bilinear2d"; + break; + } + add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -103,21 +123,44 @@ void add_upsample_nearest2d_node( {{out, vkapi::MemoryAccessType::WRITE}, {in, vkapi::MemoryAccessType::READ}}, // Shader params buffers - {t_out->logical_limits_ubo(), - graph.create_params_buffer(input_size), - graph.create_params_buffer(rev_scales)}, + {graph.logical_limits_ubo(out), + graph.logical_limits_ubo(in), + graph.create_params_buffer(recip_scales)}, // Specialization Constants - {}, + {align_corners_val}, resize_upsample_nearest2d_node, {output_sizes, scale_factors})); } -void upsample(ComputeGraph& graph, const std::vector& args) { - return add_upsample_nearest2d_node(graph, args[0], args[1], args[2], args[3]); +void upsample_nearest2d( + ComputeGraph& graph, + const std::vector& args) { + return add_upsample_nearest2d_node( + graph, + UpsampleMode::NEAREST, + args[0], + args[1], + kDummyValueRef, + args[2], + args[3]); +} + +void upsample_bilinear2d( + ComputeGraph& graph, + const std::vector& args) { + return add_upsample_nearest2d_node( + graph, + UpsampleMode::BILINEAR, + args[0], + args[1], + args[2], + args[3], + args[4]); } REGISTER_OPERATORS { - VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample); + VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample_nearest2d); + VK_REGISTER_OP(aten.upsample_bilinear2d.vec, upsample_bilinear2d); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index a1b03db27c9..f97b2c51370 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -430,21 +430,34 @@ def get_native_layer_norm_inputs(): return test_suite -@register_test_suite("aten.upsample_nearest2d.vec") def get_upsample_inputs(): - test_suite = VkTestSuite( - [ - # (input tensor shape, output 2D image size (H, W), output scaling factors) - ((2, 2, 2, 2), None, [1, 1]), - ((1, 1, 2, 2), None, [2, 2]), - ((1, 1, 2, 2), None, [2, 4]), - ((1, 1, 2, 2), None, [4, 2]), - ((1, 1, 2, 2), [2, 2], None), - ((1, 1, 2, 2), [2, 4], None), - ((1, 1, 2, 2), [3, 2], None), - ] - ) - return test_suite + inputs_list = [ + # (input tensor shape, output 2D image size (H, W), output scaling factors) + ((2, 2, 2, 2), None, [1, 1]), + ((1, 1, 2, 2), None, [2, 2]), + ((1, 1, 2, 2), None, [2, 4]), + ((1, 1, 2, 2), None, [4, 2]), + ((1, 1, 2, 2), [2, 2], None), + ((1, 1, 2, 2), [2, 4], None), + ((1, 1, 2, 2), [3, 2], None), + ] + return inputs_list + + +@register_test_suite("aten.upsample_nearest2d.vec") +def get_upsample_nearest2d_inputs(): + inputs_list = get_upsample_inputs() + return VkTestSuite(inputs_list) + + +@register_test_suite("aten.upsample_bilinear2d.vec") +def get_upsample_bilinear2d_inputs(): + base_inputs_list = get_upsample_inputs() + inputs_list = [] + for input_case in base_inputs_list: + inputs_list.append((input_case[0], input_case[1], False, input_case[2])) + inputs_list.append((input_case[0], input_case[1], True, input_case[2])) + return VkTestSuite(inputs_list) @register_test_suite(["aten.full.default", "aten.full_like.default"])