|
23 | 23 | from executorch.backends.arm.tosa_specification import TosaSpecification
|
24 | 24 |
|
25 | 25 |
|
| 26 | +# Similarly to Conv2d, the TOSA spec requires that following is exactly divisible: |
| 27 | +# `(input + 2 * pad - kernel_size) / stride` |
| 28 | +# PyTorch however, does not require this, so as needed, we must adjust the padding. |
| 29 | +def adjust_pad_if_needed( |
| 30 | + input_size: int, kernel_size: int, stride: int, pad: int |
| 31 | +) -> int: |
| 32 | + if pad == 0: |
| 33 | + return pad |
| 34 | + |
| 35 | + mod_remainder = (input_size + 2 * pad - kernel_size) % stride |
| 36 | + |
| 37 | + # No need to adjust |
| 38 | + if mod_remainder == 0: |
| 39 | + return pad |
| 40 | + |
| 41 | + return pad - mod_remainder |
| 42 | + |
| 43 | + |
26 | 44 | @register_node_visitor
|
27 | 45 | class MaxPool2dVisitor_0_80(NodeVisitor):
|
28 | 46 | target = "aten.max_pool2d.default"
|
@@ -61,6 +79,20 @@ def define_node(
|
61 | 79 | except IndexError:
|
62 | 80 | pad_size_list = [0, 0, 0, 0]
|
63 | 81 |
|
| 82 | + # Adjust the padding as necessary |
| 83 | + pad_size_list[1] = adjust_pad_if_needed( |
| 84 | + input_tensor.shape[2], |
| 85 | + kernel_size[0], |
| 86 | + stride[0], |
| 87 | + pad_size_list[1], |
| 88 | + ) |
| 89 | + pad_size_list[3] = adjust_pad_if_needed( |
| 90 | + input_tensor.shape[3], |
| 91 | + kernel_size[1], |
| 92 | + stride[1], |
| 93 | + pad_size_list[3], |
| 94 | + ) |
| 95 | + |
64 | 96 | accumulator_type = output.dtype
|
65 | 97 |
|
66 | 98 | # Initilize zero point to zero.
|
@@ -131,6 +163,20 @@ def define_node(
|
131 | 163 | except IndexError:
|
132 | 164 | pad_size_list = [0, 0, 0, 0]
|
133 | 165 |
|
| 166 | + # Adjust the padding as necessary |
| 167 | + pad_size_list[1] = adjust_pad_if_needed( |
| 168 | + input_tensor.shape[2], |
| 169 | + kernel_size[0], |
| 170 | + stride[0], |
| 171 | + pad_size_list[1], |
| 172 | + ) |
| 173 | + pad_size_list[3] = adjust_pad_if_needed( |
| 174 | + input_tensor.shape[3], |
| 175 | + kernel_size[1], |
| 176 | + stride[1], |
| 177 | + pad_size_list[3], |
| 178 | + ) |
| 179 | + |
134 | 180 | attr = ts.TosaSerializerAttribute()
|
135 | 181 | attr.MaxPool2dAttribute(
|
136 | 182 | kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1
|
|
0 commit comments