Skip to content

Commit b217165

Browse files
authored
Add compatibility checks for C++ extensions (#2467)
* Add compatibility checks for C++ extensions * Fix lint
1 parent 2cc20d7 commit b217165

File tree

7 files changed

+32
-0
lines changed

7 files changed

+32
-0
lines changed

torchvision/extension.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
_HAS_OPS = False
22

33

4+
def _has_ops():
5+
return False
6+
7+
48
def _register_extensions():
59
import os
610
import importlib
@@ -23,10 +27,26 @@ def _register_extensions():
2327
try:
2428
_register_extensions()
2529
_HAS_OPS = True
30+
31+
def _has_ops(): # noqa: F811
32+
return True
2633
except (ImportError, OSError):
2734
pass
2835

2936

37+
def _assert_has_ops():
38+
if not _has_ops():
39+
raise RuntimeError(
40+
"Couldn't load custom C++ ops. This can happen if your PyTorch and "
41+
"torchvision versions are incompatible, or if you had errors while compiling "
42+
"torchvision from source. For further information on the compatible versions, check "
43+
"https://github.com/pytorch/vision#installation for the compatibility matrix. "
44+
"Please check your PyTorch version with torch.__version__ and your torchvision "
45+
"version with torchvision.__version__ and verify if they are compatible, and if not "
46+
"please reinstall torchvision so that it matches your PyTorch install."
47+
)
48+
49+
3050
def _check_cuda_version():
3151
"""
3252
Make sure that CUDA versions match between the pytorch install and torchvision install

torchvision/ops/boxes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torch import Tensor
44
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
55
import torchvision
6+
from torchvision.extension import _assert_has_ops
67

78

89
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
@@ -37,6 +38,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
3738
of the elements that have been kept
3839
by NMS, sorted in decreasing order of scores
3940
"""
41+
_assert_has_ops()
4042
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
4143

4244

torchvision/ops/deform_conv.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch.nn.parameter import Parameter
77
from torch.nn.modules.utils import _pair
88
from torch.jit.annotations import Optional, Tuple
9+
from torchvision.extension import _assert_has_ops
910

1011

1112
def deform_conv2d(
@@ -51,6 +52,7 @@ def deform_conv2d(
5152
>>> torch.Size([4, 5, 8, 8])
5253
"""
5354

55+
_assert_has_ops()
5456
out_channels = weight.shape[0]
5557
if bias is None:
5658
bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)

torchvision/ops/ps_roi_align.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torch.nn.modules.utils import _pair
55
from torch.jit.annotations import List, Tuple
66

7+
from torchvision.extension import _assert_has_ops
78
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
89

910

@@ -38,6 +39,7 @@ def ps_roi_align(
3839
Returns:
3940
output (Tensor[K, C, output_size[0], output_size[1]])
4041
"""
42+
_assert_has_ops()
4143
check_roi_boxes_shape(boxes)
4244
rois = boxes
4345
output_size = _pair(output_size)

torchvision/ops/ps_roi_pool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torch.nn.modules.utils import _pair
55
from torch.jit.annotations import List, Tuple
66

7+
from torchvision.extension import _assert_has_ops
78
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
89

910

@@ -32,6 +33,7 @@ def ps_roi_pool(
3233
Returns:
3334
output (Tensor[K, C, output_size[0], output_size[1]])
3435
"""
36+
_assert_has_ops()
3537
check_roi_boxes_shape(boxes)
3638
rois = boxes
3739
output_size = _pair(output_size)

torchvision/ops/roi_align.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torch.nn.modules.utils import _pair
55
from torch.jit.annotations import List, BroadcastingList2
66

7+
from torchvision.extension import _assert_has_ops
78
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
89

910

@@ -41,6 +42,7 @@ def roi_align(
4142
Returns:
4243
output (Tensor[K, C, output_size[0], output_size[1]])
4344
"""
45+
_assert_has_ops()
4446
check_roi_boxes_shape(boxes)
4547
rois = boxes
4648
output_size = _pair(output_size)

torchvision/ops/roi_pool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torch.nn.modules.utils import _pair
55
from torch.jit.annotations import List, BroadcastingList2
66

7+
from torchvision.extension import _assert_has_ops
78
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
89

910

@@ -31,6 +32,7 @@ def roi_pool(
3132
Returns:
3233
output (Tensor[K, C, output_size[0], output_size[1]])
3334
"""
35+
_assert_has_ops()
3436
check_roi_boxes_shape(boxes)
3537
rois = boxes
3638
output_size = _pair(output_size)

0 commit comments

Comments
 (0)