1
- from typing import Iterable , Union
1
+ from typing import Any , Iterable
2
2
3
3
from pkg_resources import packaging # type: ignore[attr-defined]
4
4
@@ -25,8 +25,7 @@ class TorchVersion(str):
25
25
TorchVersion('1.10.0a') > '1.2'
26
26
TorchVersion('1.10.0a') > '1.2.1'
27
27
"""
28
- # fully qualified type names here to appease mypy
29
- def _convert_to_version (self , inp : Union [packaging .version .Version , str , Iterable ]) -> packaging .version .Version :
28
+ def _convert_to_version (self , inp : Any ) -> Any :
30
29
if isinstance (inp , Version ):
31
30
return inp
32
31
elif isinstance (inp , str ):
@@ -42,44 +41,16 @@ def _convert_to_version(self, inp: Union[packaging.version.Version, str, Iterabl
42
41
else :
43
42
raise InvalidVersion (inp )
44
43
45
- def __gt__ (self , cmp ) :
44
+ def _cmp_wrapper (self , cmp : Any , method : str ) -> bool :
46
45
try :
47
- return Version (self ). __gt__ (self ._convert_to_version (cmp ))
46
+ return getattr ( Version (self ), method ) (self ._convert_to_version (cmp ))
48
47
except InvalidVersion :
49
48
# Fall back to regular string comparison if dealing with an invalid
50
49
# version like 'parrot'
51
- return super (). __gt__ (cmp )
50
+ return getattr ( super (), method ) (cmp )
52
51
53
- def __lt__ (self , cmp ):
54
- try :
55
- return Version (self ).__lt__ (self ._convert_to_version (cmp ))
56
- except InvalidVersion :
57
- # Fall back to regular string comparison if dealing with an invalid
58
- # version like 'parrot'
59
- return super ().__lt__ (cmp )
60
-
61
- def __eq__ (self , cmp ):
62
- try :
63
- return Version (self ).__eq__ (self ._convert_to_version (cmp ))
64
- except InvalidVersion :
65
- # Fall back to regular string comparison if dealing with an invalid
66
- # version like 'parrot'
67
- return super ().__eq__ (cmp )
68
52
69
- def __ge__ (self , cmp ):
70
- try :
71
- return Version (self ).__ge__ (self ._convert_to_version (cmp ))
72
- except InvalidVersion :
73
- # Fall back to regular string comparison if dealing with an invalid
74
- # version like 'parrot'
75
- return super ().__ge__ (cmp )
76
-
77
- def __le__ (self , cmp ):
78
- try :
79
- return Version (self ).__le__ (self ._convert_to_version (cmp ))
80
- except InvalidVersion :
81
- # Fall back to regular string comparison if dealing with an invalid
82
- # version like 'parrot'
83
- return super ().__le__ (cmp )
53
+ for cmp_method in ["__gt__" , "__lt__" , "__eq__" , "__ge__" , "__le__" ]:
54
+ setattr (TorchVersion , cmp_method , lambda x , y , method = cmp_method : x ._cmp_wrapper (y , method ))
84
55
85
56
__version__ = TorchVersion (internal_version )
0 commit comments