18
18
from monai .config import NdarrayOrTensor
19
19
20
20
from monai .transforms .utils import _create_rotate , _create_shear , _create_scale , _create_translate , _create_rotate_90 , \
21
- _create_flip
21
+ _create_flip , get_backend_from_tensor_like , get_device_from_tensor_like
22
22
23
23
from monai .utils import TransformBackends
24
24
25
25
26
- # this will conflict with PR Replacement Apply and Resample #5436
26
+ def is_matrix_shaped (data ):
27
+
28
+ return (
29
+ len (data .shape ) == 2 and data .shape [0 ] in (3 , 4 ) and data .shape [1 ] in (3 , 4 ) and data .shape [0 ] == data .shape [1 ]
30
+ )
31
+
32
+
33
+ def is_grid_shaped (data ):
34
+
35
+ return len (data .shape ) == 3 and data .shape [0 ] == 3 or len (data .shape ) == 4 and data .shape [0 ] == 4
36
+
37
+
38
+ # placeholder that will conflict with PR Replacement Apply and Resample #5436
27
39
class MetaMatrix :
28
40
29
41
def __init__ (self ):
@@ -101,6 +113,23 @@ def translate(self, offsets: Union[Sequence[float], float], **extra_args):
101
113
return MetaMatrix (matrix , extra_args )
102
114
103
115
116
+ # this will conflict with PR Replacement Apply and Resample #5436
117
+ def ensure_tensor (data : NdarrayOrTensor ):
118
+ if isinstance (data , torch .Tensor ):
119
+ return data
120
+
121
+ return torch .as_tensor (data )
122
+
123
+
124
+ # this will conflict with PR Replacement Apply and Resample #5436
125
+ def apply_align_corners (matrix , spatial_size , factory ):
126
+ inflated_spatial_size = tuple (s + 1 for s in spatial_size )
127
+ scale_factors = tuple (s / i for s , i in zip (spatial_size , inflated_spatial_size ))
128
+ scale_mat = factory .scale (scale_factors )
129
+ return matmul (scale_mat , matrix )
130
+
131
+
132
+ # this will conflict with PR Replacement Apply and Resample #5436
104
133
class Matrix :
105
134
def __init__ (self , matrix : NdarrayOrTensor ):
106
135
self .data = ensure_tensor (matrix )
@@ -116,6 +145,7 @@ def __init__(self, matrix: NdarrayOrTensor):
116
145
# return other.__matmul__(self.data)
117
146
118
147
148
+ # this will conflict with PR Replacement Apply and Resample #5436
119
149
class Grid :
120
150
def __init__ (self , grid ):
121
151
self .data = ensure_tensor (grid )
@@ -124,6 +154,7 @@ def __init__(self, grid):
124
154
# raise NotImplementedError()
125
155
126
156
157
+ # this will conflict with PR Replacement Apply and Resample #5436
127
158
class MetaMatrix :
128
159
def __init__ (self , matrix : Union [NdarrayOrTensor , Matrix , Grid ], metadata : Optional [dict ] = None ):
129
160
@@ -156,6 +187,7 @@ def __rmatmul__(self, other):
156
187
return MetaMatrix (other_ @ self .matrix )
157
188
158
189
190
+ # this will conflict with PR Replacement Apply and Resample #5436
159
191
def matmul (
160
192
left : Union [MetaMatrix , Grid , Matrix , NdarrayOrTensor ], right : Union [MetaMatrix , Grid , Matrix , NdarrayOrTensor ]
161
193
):
@@ -209,6 +241,7 @@ def matmul(
209
241
return result
210
242
211
243
244
+ # this will conflict with PR Replacement Apply and Resample #5436
212
245
def matmul_matrix_grid (left : NdarrayOrTensor , right : NdarrayOrTensor ):
213
246
if not is_matrix_shaped (left ):
214
247
raise ValueError (f"'left' should be a 2D or 3D homogenous matrix but has shape { left .shape } " )
@@ -227,6 +260,7 @@ def matmul_matrix_grid(left: NdarrayOrTensor, right: NdarrayOrTensor):
227
260
return result
228
261
229
262
263
+ # this will conflict with PR Replacement Apply and Resample #5436
230
264
def matmul_grid_matrix (left : NdarrayOrTensor , right : NdarrayOrTensor ):
231
265
if not is_grid_shaped (left ):
232
266
raise ValueError (
@@ -248,6 +282,7 @@ def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
248
282
return matmul_matrix_grid (inv_matrix , left )
249
283
250
284
285
+ # this will conflict with PR Replacement Apply and Resample #5436
251
286
def matmul_grid_matrix_slow (left : NdarrayOrTensor , right : NdarrayOrTensor ):
252
287
if not is_grid_shaped (left ):
253
288
raise ValueError (
@@ -270,13 +305,6 @@ def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor):
270
305
return result
271
306
272
307
308
+ # this will conflict with PR Replacement Apply and Resample #5436
273
309
def matmul_matrix_matrix (left : NdarrayOrTensor , right : NdarrayOrTensor ):
274
310
return left @ right
275
-
276
-
277
- # this will conflict with PR Replacement Apply and Resample #5436
278
- def apply_align_corners (matrix , spatial_size , factory ):
279
- inflated_spatial_size = tuple (s + 1 for s in spatial_size )
280
- scale_factors = tuple (s / i for s , i in zip (spatial_size , inflated_spatial_size ))
281
- scale_mat = factory .scale (scale_factors )
282
- return matmul (scale_mat , matrix )
0 commit comments