diff --git a/setup.py b/setup.py index fedbc370f72..753a50ffeed 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ import shutil import subprocess import sys +import warnings import torch from pkg_resources import DistributionNotFound, get_distribution, parse_version @@ -138,7 +139,6 @@ def get_extensions(): + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp")) + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp")) ) - source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm")) print("Compiling extensions with following flags:") force_cuda = os.getenv("FORCE_CUDA", "0") == "1" @@ -204,8 +204,15 @@ def get_extensions(): define_macros += [("WITH_HIP", None)] nvcc_flags = [] extra_compile_args["nvcc"] = nvcc_flags - elif torch.backends.mps.is_available() or force_mps: - sources += source_mps + + # FIXME: MPS build breaks custom ops registration, so it was disabled. + # See https://github.com/pytorch/vision/issues/8456. + # TODO: Fix MPS build, remove warning below, and put back commented-out elif block.V + if force_mps: + warnings.warn("MPS build is temporarily disabled!!!!") + # elif torch.backends.mps.is_available() or force_mps: + # source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm")) + # sources += source_mps if sys.platform == "win32": define_macros += [("torchvision_EXPORTS", None)]