Skip to content

Commit 3ed27a9

Browse files
malfetfacebook-github-bot
authored andcommitted
[BE] Refactor repetitions into TorchVersion._cmp_wrapper` (pytorch#71344)
Summary: First step towards pytorch#71280 Pull Request resolved: pytorch#71344 Reviewed By: b0noI Differential Revision: D33594463 Pulled By: malfet fbshipit-source-id: 0295f0d9f0342f05a390b2bd4aa0a5958c76579b
1 parent c43e028 commit 3ed27a9

File tree

1 file changed

+7
-36
lines changed

1 file changed

+7
-36
lines changed

torch/torch_version.py

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable, Union
1+
from typing import Any, Iterable
22

33
from pkg_resources import packaging # type: ignore[attr-defined]
44

@@ -25,8 +25,7 @@ class TorchVersion(str):
2525
TorchVersion('1.10.0a') > '1.2'
2626
TorchVersion('1.10.0a') > '1.2.1'
2727
"""
28-
# fully qualified type names here to appease mypy
29-
def _convert_to_version(self, inp: Union[packaging.version.Version, str, Iterable]) -> packaging.version.Version:
28+
def _convert_to_version(self, inp: Any) -> Any:
3029
if isinstance(inp, Version):
3130
return inp
3231
elif isinstance(inp, str):
@@ -42,44 +41,16 @@ def _convert_to_version(self, inp: Union[packaging.version.Version, str, Iterabl
4241
else:
4342
raise InvalidVersion(inp)
4443

45-
def __gt__(self, cmp):
44+
def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
4645
try:
47-
return Version(self).__gt__(self._convert_to_version(cmp))
46+
return getattr(Version(self), method)(self._convert_to_version(cmp))
4847
except InvalidVersion:
4948
# Fall back to regular string comparison if dealing with an invalid
5049
# version like 'parrot'
51-
return super().__gt__(cmp)
50+
return getattr(super(), method)(cmp)
5251

53-
def __lt__(self, cmp):
54-
try:
55-
return Version(self).__lt__(self._convert_to_version(cmp))
56-
except InvalidVersion:
57-
# Fall back to regular string comparison if dealing with an invalid
58-
# version like 'parrot'
59-
return super().__lt__(cmp)
60-
61-
def __eq__(self, cmp):
62-
try:
63-
return Version(self).__eq__(self._convert_to_version(cmp))
64-
except InvalidVersion:
65-
# Fall back to regular string comparison if dealing with an invalid
66-
# version like 'parrot'
67-
return super().__eq__(cmp)
6852

69-
def __ge__(self, cmp):
70-
try:
71-
return Version(self).__ge__(self._convert_to_version(cmp))
72-
except InvalidVersion:
73-
# Fall back to regular string comparison if dealing with an invalid
74-
# version like 'parrot'
75-
return super().__ge__(cmp)
76-
77-
def __le__(self, cmp):
78-
try:
79-
return Version(self).__le__(self._convert_to_version(cmp))
80-
except InvalidVersion:
81-
# Fall back to regular string comparison if dealing with an invalid
82-
# version like 'parrot'
83-
return super().__le__(cmp)
53+
for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]:
54+
setattr(TorchVersion, cmp_method, lambda x, y, method=cmp_method: x._cmp_wrapper(y, method))
8455

8556
__version__ = TorchVersion(internal_version)

0 commit comments

Comments
 (0)