Skip to content

Commit 7b330cf

Browse files
digantdesaifacebook-github-bot
authored andcommitted
Introduce dim order utils
Summary: Common utils to work across torch.memory_format and dim_order. This is quite restrictive as of now, we can open it up for more use cases as we go. Reviewed By: larryliu0820 Differential Revision: D47580149 fbshipit-source-id: de6017958a79b334e03dcce8a368cbde965078d6
1 parent b66faf2 commit 7b330cf

File tree

4 files changed

+106
-0
lines changed

4 files changed

+106
-0
lines changed

exir/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,11 @@ python_library(
255255
"//caffe2:torch",
256256
],
257257
)
258+
259+
python_library(
260+
name = "dim_order_utils",
261+
srcs = ["dim_order_utils.py"],
262+
deps = [
263+
"//caffe2:torch",
264+
],
265+
)

exir/dim_order_utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import List
8+
9+
import torch
10+
11+
"""
12+
Set of simple utilities for translating between torch.memory_format and dim_order
13+
"""
14+
15+
16+
def _get_contiguous_dim_order(ndim: int) -> List[int]:
17+
if ndim <= 0:
18+
raise AssertionError(f"Unsupported rank: {ndim}")
19+
20+
return list(range(ndim))
21+
22+
23+
def _get_channels_last_dim_order(ndim: int) -> List[int]:
24+
if ndim == 4:
25+
return [0, 2, 3, 1]
26+
27+
raise AssertionError(f"Unsupported rank: {ndim}")
28+
29+
30+
def get_memory_format(dim_order: List[int]) -> torch.memory_format:
    """Map a given dim_order to the ``torch.memory_format`` it represents.

    Args:
        dim_order: A permutation of ``[0, ndim)`` describing the physical
            layout of a tensor's dimensions.

    Returns:
        ``torch.contiguous_format`` when ``dim_order`` is the identity
        permutation, or ``torch.channels_last`` when it is the rank-4
        channels-last permutation.

    Raises:
        AssertionError: When ``dim_order`` matches no supported memory format.
    """
    rank = len(dim_order)
    if _get_contiguous_dim_order(rank) == dim_order:
        return torch.contiguous_format
    # The explicit rank check keeps non-rank-4 orders flowing to the
    # error below instead of raising inside the channels-last helper.
    if rank == 4 and _get_channels_last_dim_order(rank) == dim_order:
        return torch.channels_last

    raise AssertionError(
        f"Failed to map a given dim_order: {dim_order} to a torch.memory_format"
    )
44+
45+
46+
def get_dim_order(memory_format: torch.memory_format, ndim: int) -> List[int]:
    """Generate the dim_order for a rank-``ndim`` tensor in ``memory_format``.

    Args:
        memory_format: Either ``torch.contiguous_format`` or
            ``torch.channels_last``.
        ndim: Tensor rank to generate the order for.

    Returns:
        The dim order corresponding to ``memory_format`` at rank ``ndim``.

    Raises:
        AssertionError: For unsupported memory formats, or ranks the chosen
            format does not support.
    """
    # Dispatch table: each supported format maps to its order generator.
    generators = {
        torch.contiguous_format: _get_contiguous_dim_order,
        torch.channels_last: _get_channels_last_dim_order,
    }
    generator = generators.get(memory_format)
    if generator is None:
        raise AssertionError(
            f"Failed to generate dim_order for a given memory format: {memory_format}"
        )
    return generator(ndim)

exir/tests/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,15 @@ python_unittest(
420420
"//executorch/exir/verification:verifier",
421421
],
422422
)
423+
424+
python_unittest(
425+
name = "dim_order_utils",
426+
srcs = [
427+
"test_dim_order_utils.py",
428+
],
429+
supports_static_listing = True,
430+
deps = [
431+
"//caffe2:torch",
432+
"//executorch/exir:dim_order_utils",
433+
],
434+
)

exir/tests/test_dim_order_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
import unittest
9+
10+
import torch
11+
from executorch.exir.dim_order_utils import get_dim_order, get_memory_format
12+
13+
14+
class TestDimOrderUtils(unittest.TestCase):
    """Unit tests for the dim_order <-> torch.memory_format translations."""

    def test_get_memory_format(self) -> None:
        # Identity dim orders of every tested rank map to contiguous_format.
        for rank in range(1, 7):
            identity_order = list(range(rank))
            self.assertEqual(
                torch.contiguous_format, get_memory_format(identity_order)
            )

        # The rank-4 channels-last permutation maps to channels_last.
        self.assertEqual(torch.channels_last, get_memory_format([0, 2, 3, 1]))

    def test_get_dim_order(self) -> None:
        # contiguous_format yields the identity order at every tested rank.
        for rank in range(1, 7):
            self.assertEqual(
                list(range(rank)), get_dim_order(torch.contiguous_format, rank)
            )

        # channels_last at rank 4 yields the channels-last permutation.
        self.assertEqual([0, 2, 3, 1], get_dim_order(torch.channels_last, 4))

0 commit comments

Comments
 (0)