Skip to content

Commit 57c66e4

Browse files
committed
Adding in additional placeholders and functions required for functional, array and dictionary transforms to operate while waiting for PR #5436
Signed-off-by: Ben Murray <[email protected]>
1 parent 7a41b3f commit 57c66e4

File tree

2 files changed

+76
-10
lines changed

2 files changed

+76
-10
lines changed

monai/transforms/meta_matrix.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,24 @@
1818
from monai.config import NdarrayOrTensor
1919

2020
from monai.transforms.utils import _create_rotate, _create_shear, _create_scale, _create_translate, _create_rotate_90, \
21-
_create_flip
21+
_create_flip, get_backend_from_tensor_like, get_device_from_tensor_like
2222

2323
from monai.utils import TransformBackends
2424

2525

26-
# this will conflict with PR Replacement Apply and Resample #5436
26+
def is_matrix_shaped(data):
27+
28+
return (
29+
len(data.shape) == 2 and data.shape[0] in (3, 4) and data.shape[1] in (3, 4) and data.shape[0] == data.shape[1]
30+
)
31+
32+
33+
def is_grid_shaped(data):
34+
35+
return len(data.shape) == 3 and data.shape[0] == 3 or len(data.shape) == 4 and data.shape[0] == 4
36+
37+
38+
# placeholder that will conflict with PR Replacement Apply and Resample #5436
2739
class MetaMatrix:
2840

2941
def __init__(self):
@@ -101,6 +113,23 @@ def translate(self, offsets: Union[Sequence[float], float], **extra_args):
101113
return MetaMatrix(matrix, extra_args)
102114

103115

116+
# this will conflict with PR Replacement Apply and Resample #5436
117+
def ensure_tensor(data: NdarrayOrTensor):
118+
if isinstance(data, torch.Tensor):
119+
return data
120+
121+
return torch.as_tensor(data)
122+
123+
124+
# this will conflict with PR Replacement Apply and Resample #5436
125+
def apply_align_corners(matrix, spatial_size, factory):
126+
inflated_spatial_size = tuple(s + 1 for s in spatial_size)
127+
scale_factors = tuple(s / i for s, i in zip(spatial_size, inflated_spatial_size))
128+
scale_mat = factory.scale(scale_factors)
129+
return matmul(scale_mat, matrix)
130+
131+
132+
# this will conflict with PR Replacement Apply and Resample #5436
104133
class Matrix:
105134
def __init__(self, matrix: NdarrayOrTensor):
106135
self.data = ensure_tensor(matrix)
@@ -116,6 +145,7 @@ def __init__(self, matrix: NdarrayOrTensor):
116145
# return other.__matmul__(self.data)
117146

118147

148+
# this will conflict with PR Replacement Apply and Resample #5436
119149
class Grid:
120150
def __init__(self, grid):
121151
self.data = ensure_tensor(grid)
@@ -124,6 +154,7 @@ def __init__(self, grid):
124154
# raise NotImplementedError()
125155

126156

157+
# this will conflict with PR Replacement Apply and Resample #5436
127158
class MetaMatrix:
128159
def __init__(self, matrix: Union[NdarrayOrTensor, Matrix, Grid], metadata: Optional[dict] = None):
129160

@@ -156,6 +187,7 @@ def __rmatmul__(self, other):
156187
return MetaMatrix(other_ @ self.matrix)
157188

158189

190+
# this will conflict with PR Replacement Apply and Resample #5436
159191
def matmul(
160192
left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor]
161193
):
@@ -209,6 +241,7 @@ def matmul(
209241
return result
210242

211243

244+
# this will conflict with PR Replacement Apply and Resample #5436
212245
def matmul_matrix_grid(left: NdarrayOrTensor, right: NdarrayOrTensor):
213246
if not is_matrix_shaped(left):
214247
raise ValueError(f"'left' should be a 2D or 3D homogenous matrix but has shape {left.shape}")
@@ -227,6 +260,7 @@ def matmul_matrix_grid(left: NdarrayOrTensor, right: NdarrayOrTensor):
227260
return result
228261

229262

263+
# this will conflict with PR Replacement Apply and Resample #5436
230264
def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
231265
if not is_grid_shaped(left):
232266
raise ValueError(
@@ -248,6 +282,7 @@ def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
248282
return matmul_matrix_grid(inv_matrix, left)
249283

250284

285+
# this will conflict with PR Replacement Apply and Resample #5436
251286
def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor):
252287
if not is_grid_shaped(left):
253288
raise ValueError(
@@ -270,13 +305,6 @@ def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor):
270305
return result
271306

272307

308+
# this will conflict with PR Replacement Apply and Resample #5436
273309
def matmul_matrix_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
274310
return left @ right
275-
276-
277-
# this will conflict with PR Replacement Apply and Resample #5436
278-
def apply_align_corners(matrix, spatial_size, factory):
279-
inflated_spatial_size = tuple(s + 1 for s in spatial_size)
280-
scale_factors = tuple(s / i for s, i in zip(spatial_size, inflated_spatial_size))
281-
scale_mat = factory.scale(scale_factors)
282-
return matmul(scale_mat, matrix)

monai/transforms/utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,44 @@ def squarepulse(sig, duty: float = 0.5):
18501850
return y
18511851

18521852

1853+
# this will conflict with PR Replacement Apply and Resample #5436
1854+
def get_device_from_tensor_like(data: NdarrayOrTensor):
1855+
"""
1856+
This function returns the device of `data`, which must either be a numpy ndarray or a
1857+
pytorch Tensor.
1858+
Args:
1859+
data: the ndarray/tensor to return the device of
1860+
Returns:
1861+
None if `data` is a numpy array, or the device of the pytorch Tensor
1862+
"""
1863+
if isinstance(data, np.ndarray):
1864+
return None
1865+
elif isinstance(data, torch.Tensor):
1866+
return data.device
1867+
else:
1868+
msg = "'data' must be one of numpy ndarray or torch Tensor but is {}"
1869+
raise ValueError(msg.format(type(data)))
1870+
1871+
1872+
# this will conflict with PR Replacement Apply and Resample #5436
1873+
def get_backend_from_tensor_like(data: NdarrayOrTensor):
1874+
"""
1875+
This function returns the backend of `data`, which must either be a numpy ndarray or a
1876+
pytorch Tensor.
1877+
Args:
1878+
data: the ndarray/tensor to return the device of
1879+
Returns:
1880+
None if `data` is a numpy array, or the device of the pytorch Tensor
1881+
"""
1882+
if isinstance(data, np.ndarray):
1883+
return TransformBackends.NUMPY
1884+
elif isinstance(data, torch.Tensor):
1885+
return TransformBackends.TORCH
1886+
else:
1887+
msg = "'data' must be one of numpy ndarray or torch Tensor but is {}"
1888+
raise ValueError(msg.format(type(data)))
1889+
1890+
18531891
def value_to_tuple_range(value):
18541892
"""
18551893
Takes a value or a tuple of values.

0 commit comments

Comments
 (0)