diff --git a/src/targets/gpu/jit/pad.cpp b/src/targets/gpu/jit/pad.cpp new file mode 100644 index 00000000000..962fdd1cb3e --- /dev/null +++ b/src/targets/gpu/jit/pad.cpp @@ -0,0 +1,100 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +using namespace migraphx::gpu::gen; // NOLINT + +static const char* const pointwise_kernel = R"__migraphx__( +#include +#include +#include +#include + +namespace migraphx { + +extern "C" { +__global__ void pad_kernel(void* input_p, void* output_p) +{ + auto offsets = index_ints<${offsets}>{}; + auto idx = make_index(); + make_tensors()(input_p, output_p)([&](auto input, auto output) { + pad(idx, offsets, input, output, ${pad_val}); + }); +} + +} + +} // namespace migraphx + +)__migraphx__"; + +struct pad_compiler : compiler +{ + std::vector names() const { return {"pad"}; } + + operation compile_op(context& ctx, const std::vector& inputs, const value& v) const + { + hip_compile_options options; + options.inputs = inputs; + options.output = inputs.back(); + options.virtual_inputs = reduce_dims(inputs); + options.kernel_name = "pad_kernel"; + options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements())); + + auto pad_val = v.get("value", 0.f); + auto pad_val_string = to_string(pad_val); + if(float_equal(pad_val, std::numeric_limits::lowest())) + pad_val_string = "lowest{}"; + if(float_equal(pad_val, std::numeric_limits::max())) + pad_val_string = "highest{}"; + + auto padding = v.at("pads").to_vector(); + auto input_lens = inputs.front().lens(); + std::vector offsets(input_lens.size()); + std::copy(padding.begin(), padding.begin() + offsets.size(), offsets.begin()); + + auto src = interpolate_string( + pointwise_kernel, + {{"pad_val", to_string(pad_val_string)}, {"offsets", to_string_range(offsets)}}); + return compile_hip_code_object(src, options); + } + + compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const + { + return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); + } +}; +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp new file mode 100644 index 00000000000..d4dcb49dfad --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp @@ -0,0 +1,63 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_KERNELS_PAD_HPP +#define MIGRAPHX_GUARD_KERNELS_PAD_HPP + +#include +#include +#include +#include + +namespace migraphx { + +template +__device__ void pad(const index& idx, + const Offsets& offsets, + const Input& input, + Output& output, + const PadVal& pad_val) +{ + auto output_shape = output.get_shape(); + idx.global_stride(output_shape.elements(), [&](auto i) { + // 1. get current multi-index for output + // 2. get the size of the input to determine input boundaries + // 3. compute the corresponding multi-index for input by accounting for offsets + // 4. if current multi-index is within offsets or input's new multi-index is out of bounds, + // use pad value instead of input's value + auto multi = output_shape.multi(i); + auto input_bounds = input.get_shape().lens; + auto input_idx = multi - offsets; + auto range_multi = range(multi.size()); + + if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) { + return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j]; + })) + output[multi] = pad_val; + else + output[multi] = input[input_idx]; + }); +} + +} // namespace migraphx +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ranges.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ranges.hpp new file mode 100644 index 00000000000..af32a723b5a --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/ranges.hpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_KERNELS_RANGES_HPP +#define MIGRAPHX_GUARD_KERNELS_RANGES_HPP + +#include + +namespace migraphx { + +template +struct iterator_range +{ + Iterator start; + Iterator last; + + constexpr Iterator begin() const { return start; } + + constexpr Iterator end() const { return last; } +}; + +constexpr iterator_range range(diff_int start, diff_int last) +{ + return {{start, {}}, {last, {}}}; +} +constexpr iterator_range range(diff_int last) { return range(0, last); } + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_RANGES_HPP diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index a85eb96dc9f..f0109d3f3c7 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -104,7 +104,6 @@ struct miopen_apply add_extend_op("lrn"); add_extend_op("multinomial"); add_extend_op("nonzero"); - add_extend_op("pad"); add_extend_op("pooling"); add_extend_op("prefix_scan_sum"); add_extend_op("reverse"); diff --git a/test/ref_ops_test.cpp b/test/ref_ops_test.cpp index 5d89c94323b..6de1a52fa49 100644 --- a/test/ref_ops_test.cpp +++ b/test/ref_ops_test.cpp @@ -3884,6 +3884,21 @@ TEST_CASE(pad_test) EXPECT(migraphx::verify_range(results_vector, gold)); } +TEST_CASE(pad_test_asym) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s, {1, 2, 3, 4}}); + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 1, 1}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(9); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1, 2, 0, 3, 4, 0, 0, 0, 0}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + TEST_CASE(pad_test_highest_half) { migraphx::program p; diff --git a/test/verify/test_pad_large.cpp b/test/verify/test_pad_large.cpp new file mode 100644 index 00000000000..01dce2d783d --- /dev/null +++ b/test/verify/test_pad_large.cpp @@ -0,0 +1,42 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_pad_large : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {586, 3, 224, 224}}; + std::vector pads0 = {0, 0, 1, 1, 0, 0, 1, 1}; + auto l0 = mm->add_parameter("x", s0); + mm->add_instruction(migraphx::make_op("pad", {{"pads", pads0}}), l0); + return p; + } +};