@@ -404,9 +404,13 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
404
404
405
405
406
406
def _apply_grid_transform (
407
- float_img : torch .Tensor , grid : torch .Tensor , mode : str , fill : datapoints .FillTypeJIT
407
+ img : torch .Tensor , grid : torch .Tensor , mode : str , fill : datapoints .FillTypeJIT
408
408
) -> torch .Tensor :
409
409
410
+ # We are using context knowledge that grid should have float dtype
411
+ fp = img .dtype == grid .dtype
412
+ float_img = img if fp else img .to (grid .dtype )
413
+
410
414
shape = float_img .shape
411
415
if shape [0 ] > 1 :
412
416
# Apply same grid to a batch of images
@@ -433,7 +437,9 @@ def _apply_grid_transform(
433
437
# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
434
438
float_img = float_img .sub_ (fill_img ).mul_ (mask ).add_ (fill_img )
435
439
436
- return float_img
440
+ img = float_img .round_ ().to (img .dtype ) if not fp else float_img
441
+
442
+ return img
437
443
438
444
439
445
def _assert_grid_transform_inputs (
@@ -511,7 +517,6 @@ def affine_image_tensor(
511
517
512
518
shape = image .shape
513
519
ndim = image .ndim
514
- fp = torch .is_floating_point (image )
515
520
516
521
if ndim > 4 :
517
522
image = image .reshape ((- 1 ,) + shape [- 3 :])
@@ -535,13 +540,10 @@ def affine_image_tensor(
535
540
536
541
_assert_grid_transform_inputs (image , matrix , interpolation .value , fill , ["nearest" , "bilinear" ])
537
542
538
- dtype = image .dtype if fp else torch .float32
543
+ dtype = image .dtype if torch . is_floating_point ( image ) else torch .float32
539
544
theta = torch .tensor (matrix , dtype = dtype , device = image .device ).reshape (1 , 2 , 3 )
540
545
grid = _affine_grid (theta , w = width , h = height , ow = width , oh = height )
541
- output = _apply_grid_transform (image if fp else image .to (dtype ), grid , interpolation .value , fill = fill )
542
-
543
- if not fp :
544
- output = output .round_ ().to (image .dtype )
546
+ output = _apply_grid_transform (image , grid , interpolation .value , fill = fill )
545
547
546
548
if needs_unsquash :
547
549
output = output .reshape (shape )
@@ -612,7 +614,7 @@ def _affine_bounding_box_xyxy(
612
614
# Single point structure is similar to
613
615
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
614
616
points = bounding_box [:, [[0 , 1 ], [2 , 1 ], [2 , 3 ], [0 , 3 ]]].reshape (- 1 , 2 )
615
- points = torch .cat ([points , torch .ones (points .shape [0 ], 1 , device = points . device )], dim = - 1 )
617
+ points = torch .cat ([points , torch .ones (points .shape [0 ], 1 , device = device , dtype = dtype )], dim = - 1 )
616
618
# 2) Now let's transform the points using affine matrix
617
619
transformed_points = torch .matmul (points , transposed_affine_matrix )
618
620
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
@@ -797,19 +799,15 @@ def rotate_image_tensor(
797
799
matrix = _get_inverse_affine_matrix (center_f , - angle , [0.0 , 0.0 ], 1.0 , [0.0 , 0.0 ])
798
800
799
801
if image .numel () > 0 :
800
- fp = torch .is_floating_point (image )
801
802
image = image .reshape (- 1 , num_channels , height , width )
802
803
803
804
_assert_grid_transform_inputs (image , matrix , interpolation .value , fill , ["nearest" , "bilinear" ])
804
805
805
806
ow , oh = _compute_affine_output_size (matrix , width , height ) if expand else (width , height )
806
- dtype = image .dtype if fp else torch .float32
807
+ dtype = image .dtype if torch . is_floating_point ( image ) else torch .float32
807
808
theta = torch .tensor (matrix , dtype = dtype , device = image .device ).reshape (1 , 2 , 3 )
808
809
grid = _affine_grid (theta , w = width , h = height , ow = ow , oh = oh )
809
- output = _apply_grid_transform (image if fp else image .to (dtype ), grid , interpolation .value , fill = fill )
810
-
811
- if not fp :
812
- output = output .round_ ().to (image .dtype )
810
+ output = _apply_grid_transform (image , grid , interpolation .value , fill = fill )
813
811
814
812
new_height , new_width = output .shape [- 2 :]
815
813
else :
@@ -1237,9 +1235,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
1237
1235
1238
1236
d = 0.5
1239
1237
base_grid = torch .empty (1 , oh , ow , 3 , dtype = dtype , device = device )
1240
- x_grid = torch .linspace (d , ow + d - 1.0 , steps = ow , device = device )
1238
+ x_grid = torch .linspace (d , ow + d - 1.0 , steps = ow , device = device , dtype = dtype )
1241
1239
base_grid [..., 0 ].copy_ (x_grid )
1242
- y_grid = torch .linspace (d , oh + d - 1.0 , steps = oh , device = device ).unsqueeze_ (- 1 )
1240
+ y_grid = torch .linspace (d , oh + d - 1.0 , steps = oh , device = device , dtype = dtype ).unsqueeze_ (- 1 )
1243
1241
base_grid [..., 1 ].copy_ (y_grid )
1244
1242
base_grid [..., 2 ].fill_ (1 )
1245
1243
@@ -1283,7 +1281,6 @@ def perspective_image_tensor(
1283
1281
1284
1282
shape = image .shape
1285
1283
ndim = image .ndim
1286
- fp = torch .is_floating_point (image )
1287
1284
1288
1285
if ndim > 4 :
1289
1286
image = image .reshape ((- 1 ,) + shape [- 3 :])
@@ -1304,12 +1301,9 @@ def perspective_image_tensor(
1304
1301
)
1305
1302
1306
1303
oh , ow = shape [- 2 :]
1307
- dtype = image .dtype if fp else torch .float32
1304
+ dtype = image .dtype if torch . is_floating_point ( image ) else torch .float32
1308
1305
grid = _perspective_grid (perspective_coeffs , ow = ow , oh = oh , dtype = dtype , device = image .device )
1309
- output = _apply_grid_transform (image if fp else image .to (dtype ), grid , interpolation .value , fill = fill )
1310
-
1311
- if not fp :
1312
- output = output .round_ ().to (image .dtype )
1306
+ output = _apply_grid_transform (image , grid , interpolation .value , fill = fill )
1313
1307
1314
1308
if needs_unsquash :
1315
1309
output = output .reshape (shape )
@@ -1494,8 +1488,12 @@ def elastic_image_tensor(
1494
1488
1495
1489
shape = image .shape
1496
1490
ndim = image .ndim
1491
+
1497
1492
device = image .device
1498
- fp = torch .is_floating_point (image )
1493
+ dtype = image .dtype if torch .is_floating_point (image ) else torch .float32
1494
+ # We are aware that if input image dtype is uint8 and displacement is float64 then
1495
+ # displacement will be casted to float32 and all computations will be done with float32
1496
+ # We can fix this later if needed
1499
1497
1500
1498
if ndim > 4 :
1501
1499
image = image .reshape ((- 1 ,) + shape [- 3 :])
@@ -1506,12 +1504,12 @@ def elastic_image_tensor(
1506
1504
else :
1507
1505
needs_unsquash = False
1508
1506
1509
- image_height , image_width = shape [- 2 :]
1510
- grid = _create_identity_grid ((image_height , image_width ), device = device ).add_ (displacement .to (device ))
1511
- output = _apply_grid_transform (image if fp else image .to (torch .float32 ), grid , interpolation .value , fill = fill )
1507
+ if displacement .dtype != dtype or displacement .device != device :
1508
+ displacement = displacement .to (dtype = dtype , device = device )
1512
1509
1513
- if not fp :
1514
- output = output .round_ ().to (image .dtype )
1510
+ image_height , image_width = shape [- 2 :]
1511
+ grid = _create_identity_grid ((image_height , image_width ), device = device , dtype = dtype ).add_ (displacement )
1512
+ output = _apply_grid_transform (image , grid , interpolation .value , fill = fill )
1515
1513
1516
1514
if needs_unsquash :
1517
1515
output = output .reshape (shape )
@@ -1531,13 +1529,13 @@ def elastic_image_pil(
1531
1529
return to_pil_image (output , mode = image .mode )
1532
1530
1533
1531
1534
- def _create_identity_grid (size : Tuple [int , int ], device : torch .device ) -> torch .Tensor :
1532
+ def _create_identity_grid (size : Tuple [int , int ], device : torch .device , dtype : torch . dtype ) -> torch .Tensor :
1535
1533
sy , sx = size
1536
- base_grid = torch .empty (1 , sy , sx , 2 , device = device )
1537
- x_grid = torch .linspace ((- sx + 1 ) / sx , (sx - 1 ) / sx , sx , device = device )
1534
+ base_grid = torch .empty (1 , sy , sx , 2 , device = device , dtype = dtype )
1535
+ x_grid = torch .linspace ((- sx + 1 ) / sx , (sx - 1 ) / sx , sx , device = device , dtype = dtype )
1538
1536
base_grid [..., 0 ].copy_ (x_grid )
1539
1537
1540
- y_grid = torch .linspace ((- sy + 1 ) / sy , (sy - 1 ) / sy , sy , device = device ).unsqueeze_ (- 1 )
1538
+ y_grid = torch .linspace ((- sy + 1 ) / sy , (sy - 1 ) / sy , sy , device = device , dtype = dtype ).unsqueeze_ (- 1 )
1541
1539
base_grid [..., 1 ].copy_ (y_grid )
1542
1540
1543
1541
return base_grid
@@ -1552,7 +1550,11 @@ def elastic_bounding_box(
1552
1550
return bounding_box
1553
1551
1554
1552
# TODO: add in docstring about approximation we are doing for grid inversion
1555
- displacement = displacement .to (bounding_box .device )
1553
+ device = bounding_box .device
1554
+ dtype = bounding_box .dtype if torch .is_floating_point (bounding_box ) else torch .float32
1555
+
1556
+ if displacement .dtype != dtype or displacement .device != device :
1557
+ displacement = displacement .to (dtype = dtype , device = device )
1556
1558
1557
1559
original_shape = bounding_box .shape
1558
1560
bounding_box = (
@@ -1563,7 +1565,7 @@ def elastic_bounding_box(
1563
1565
# Or add spatial_size arg and check displacement shape
1564
1566
spatial_size = displacement .shape [- 3 ], displacement .shape [- 2 ]
1565
1567
1566
- id_grid = _create_identity_grid (spatial_size , bounding_box . device )
1568
+ id_grid = _create_identity_grid (spatial_size , device = device , dtype = dtype )
1567
1569
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
1568
1570
# This is not an exact inverse of the grid
1569
1571
inv_grid = id_grid .sub_ (displacement )
0 commit comments