From 15b075668db7b258285d88ded5c08b7c28d6a747 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 6 Jun 2024 10:43:16 +0100 Subject: [PATCH 1/2] Remove broken MPS build --- setup.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index fedbc370f72..74f94fc95e7 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ import shutil import subprocess import sys +import warnings import torch from pkg_resources import DistributionNotFound, get_distribution, parse_version @@ -204,8 +205,14 @@ def get_extensions(): define_macros += [("WITH_HIP", None)] nvcc_flags = [] extra_compile_args["nvcc"] = nvcc_flags - elif torch.backends.mps.is_available() or force_mps: - sources += source_mps + + # FIXME: MPS build breaks custom ops registration, so it was disabled. + # See https://github.com/pytorch/vision/issues/8456. + # TODO: Fix MPS build, remove warning below, and put back commented-out elif block.V + if force_mps: + warnings.warn("MPS build is temporarily disabled!!!!") + # elif torch.backends.mps.is_available() or force_mps: + # sources += source_mps if sys.platform == "win32": define_macros += [("torchvision_EXPORTS", None)] From 40c71ed3f55a12d0397c487a30b03ac89dfc5e29 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 6 Jun 2024 10:48:20 +0100 Subject: [PATCH 2/2] lint --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 74f94fc95e7..753a50ffeed 100644 --- a/setup.py +++ b/setup.py @@ -139,7 +139,6 @@ def get_extensions(): + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp")) + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp")) ) - source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm")) print("Compiling extensions with following flags:") force_cuda = os.getenv("FORCE_CUDA", "0") == "1" @@ -212,6 +211,7 @@ def get_extensions(): if force_mps: warnings.warn("MPS build is temporarily disabled!!!!") # elif torch.backends.mps.is_available() or force_mps: + # source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm")) # sources += source_mps if sys.platform == "win32":