@@ -390,7 +390,7 @@ def _affine_bounding_box_xyxy(
             device=device,
         )
         new_points = torch.matmul(points, transposed_affine_matrix)
-        tr, _ = torch.min(new_points, dim=0, keepdim=True)
+        tr = torch.amin(new_points, dim=0, keepdim=True)
         # Translate bounding boxes
         out_bboxes.sub_(tr.repeat((1, 2)))
         # Estimate meta-data for image with inverted=True and with center=[0,0]
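Side note on the change above: torch.min(t, dim=...) returns a (values, indices) pair, so the old line had to unpack and discard the indices, whereas torch.amin returns only the values. A minimal, illustrative check (not part of the patch):

    import torch

    new_points = torch.tensor([[2.0, -1.0], [0.5, 3.0], [-4.0, 7.0]])
    values, _ = torch.min(new_points, dim=0, keepdim=True)  # old form: (values, indices)
    tr = torch.amin(new_points, dim=0, keepdim=True)        # new form: values only
    assert torch.equal(values, tr)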
@@ -701,7 +701,7 @@ def pad_image_tensor(
     # internally.
     torch_padding = _parse_pad_padding(padding)
 
-    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
         raise ValueError(
             f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
             f"but got `'{padding_mode}'`."
@@ -917,17 +917,17 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #
-
+    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
     )
     theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
 
     d = 0.5
     base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
-    x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
+    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device)
     base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
+    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
     base_grid[..., 1].copy_(y_grid)
     base_grid[..., 2].fill_(1)
 
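The hunk above packs the perspective coefficients into theta1 (the two numerators) and theta2 (the shared denominator), evaluated over a grid of pixel centers; the TODO asks whether those tensors should be stored transposed. A standalone sketch of that mapping, under the assumption that the surrounding helper then normalizes these source coordinates for grid_sample; names below are illustrative, not the library's:

    import torch

    def perspective_points(coeffs, ow, oh, dtype=torch.float32, device="cpu"):
        # x_out = (c0*x + c1*y + c2) / (c6*x + c7*y + 1)
        # y_out = (c3*x + c4*y + c5) / (c6*x + c7*y + 1)
        theta1 = torch.tensor([coeffs[:3], coeffs[3:6]], dtype=dtype, device=device)          # 2x3 numerators
        theta2 = torch.tensor([[coeffs[6], coeffs[7], 1.0]] * 2, dtype=dtype, device=device)  # 2x3 denominator rows
        d = 0.5  # sample at pixel centers
        base_grid = torch.empty(oh, ow, 3, dtype=dtype, device=device)
        base_grid[..., 0].copy_(torch.linspace(d, ow + d - 1.0, steps=ow, device=device))
        base_grid[..., 1].copy_(torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1))
        base_grid[..., 2].fill_(1)
        pts = base_grid.reshape(-1, 3)          # (oh*ow, 3) homogeneous pixel coordinates
        num = pts @ theta1.T                    # numerators of (x_out, y_out)
        den = pts @ theta2.T                    # shared denominator, duplicated per coordinate
        return (num / den).reshape(oh, ow, 2)   # source location for every output pixel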
@@ -1059,6 +1059,7 @@ def perspective_bounding_box(
         (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
     ]
 
+    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
         dtype=dtype,
@@ -1165,14 +1166,17 @@ def elastic_image_tensor(
         return image
 
     shape = image.shape
+    device = image.device
 
     if image.ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
     else:
         needs_unsquash = False
 
-    output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
+    image_height, image_width = shape[-2:]
+    grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
+    output = _FT._apply_grid_transform(image, grid, interpolation.value, fill)
 
     if needs_unsquash:
         output = output.reshape(shape)
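The replacement above builds the sampling grid explicitly (an identity grid plus the displacement field) and then applies the grid transform. The helpers named in the hunk (_create_identity_grid, _FT._apply_grid_transform) are internal and not shown here, so the following is an assumption-laden sketch of the underlying technique using public PyTorch ops, assuming a floating-point image and a displacement given in normalized coordinates:

    import torch
    import torch.nn.functional as F

    def elastic_sketch(image, displacement):
        # image: (N, C, H, W) float tensor; displacement: (1, H, W, 2) offsets in [-1, 1] coords
        n, _, h, w = image.shape
        ys = torch.linspace(-1.0, 1.0, h, device=image.device)
        xs = torch.linspace(-1.0, 1.0, w, device=image.device)
        gy, gx = torch.meshgrid(ys, xs, indexing="ij")
        identity = torch.stack([gx, gy], dim=-1).unsqueeze(0)  # (1, H, W, 2), x first
        grid = (identity + displacement.to(image.device)).expand(n, h, w, 2)
        # align_corners=True matches the linspace(-1, 1) identity grid used here
        return F.grid_sample(image, grid, mode="bilinear", align_corners=True)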
@@ -1505,8 +1509,7 @@ def five_crop_image_tensor(
     image_height, image_width = image.shape[-2:]
 
     if crop_width > image_width or crop_height > image_height:
-        msg = "Requested crop size {} is bigger than input size {}"
-        raise ValueError(msg.format(size, (image_height, image_width)))
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
 
     tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
     tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width)
@@ -1525,8 +1528,7 @@ def five_crop_image_pil(
     image_height, image_width = get_spatial_size_image_pil(image)
 
     if crop_width > image_width or crop_height > image_height:
-        msg = "Requested crop size {} is bigger than input size {}"
-        raise ValueError(msg.format(size, (image_height, image_width)))
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
 
     tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
     tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
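Both five_crop hunks only reword the size-check message; the crop placement is unchanged. For reference, the corner arithmetic behind the tl/tr calls in the two hunks above can be written out as a small hypothetical helper (for illustration only; the library's center crop may round differently):

    def five_crop_boxes(image_height, image_width, crop_height, crop_width):
        # (top, left) corner of each crop_height x crop_width window
        if crop_width > image_width or crop_height > image_height:
            raise ValueError(
                f"Requested crop size {(crop_height, crop_width)} is bigger than "
                f"input size {(image_height, image_width)}"
            )
        tl = (0, 0)
        tr = (0, image_width - crop_width)
        bl = (image_height - crop_height, 0)
        br = (image_height - crop_height, image_width - crop_width)
        center = ((image_height - crop_height) // 2, (image_width - crop_width) // 2)
        return [tl, tr, bl, br, center]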