Add ROCm5.2/AMDGPU support for PyTorch 1.10 #1

Closed
wants to merge 59 commits into from
ad22804
[release/1.10] Pin builder and xla repo (#65433)
malfet Sep 21, 2021
0e857bf
[1.10] Remove torch.vmap (#65496)
zou3519 Sep 24, 2021
c05547f
Fix test reporting git merge-base (#65787)
zhouzhuojie Sep 28, 2021
1fa17a2
Fix the slowdown of _object_to_tensor since 1.9 (#65721) (#65835)
malfet Sep 29, 2021
13666d2
[DataPipe] Fix deepcopy filehandle for Mapper and in-place modificati…
ejguan Oct 1, 2021
6aadfda
[ci] try installing libgnutls to fix cert error (#65934) (#65979)
malfet Oct 1, 2021
ecfcb8f
Binary building wthout python fix (#66031) (#66117)
n-v-k Oct 5, 2021
4731f33
Fix Windows ninja builds when MAX_JOBS is specified (#65444) (#66155)
malfet Oct 5, 2021
5f3eee1
Fix backward compatibility tests (#66186)
malfet Oct 6, 2021
2b46c95
[iOS][CI] Update dev certs (#66004) (#66188)
malfet Oct 6, 2021
4e3ebeb
[DataPipe] DataPipe Fix and Deprecation Warnings for Release 1.10 (#6…
NivekT Oct 6, 2021
ecbf5a7
Tweak `file_diff_from_base` for release/1.10 branch (#66202)
malfet Oct 6, 2021
5f1a434
Added option to update parameters using state_dict in AveragedModel (…
prabhat00155 Oct 6, 2021
49f52b6
Revert "Added option to update parameters using state_dict in Average…
prabhat00155 Oct 8, 2021
a27906c
Convert Sampler back to lazily construction (#63646) (#65926)
ejguan Oct 8, 2021
1774a6a
[ONNX] Deprecate various args (#65962)
garymm Oct 8, 2021
9509e8a
Fix cosine similarity dim checks (#66214)
Oct 8, 2021
c3ea586
fix normal with empty std (#66524)
Oct 14, 2021
4a514dd
Call `PyArray_Check` only if NumPy is available (#66433) (#66629)
malfet Oct 14, 2021
3c134b8
Disable .numpy() and .tolist() for tensor subclasses subclasses and f…
anjali411 Oct 14, 2021
cc360fa
Delete extraneous whitespaces
malfet Oct 14, 2021
ddf3092
Disable .numpy() and .tolist() for tensor subclasses subclasses and f…
anjali411 Oct 14, 2021
b544cbd
Handle shared memory cases in MathBitFallback (#66667)
malfet Oct 15, 2021
36449ea
(torch/elastic) add fqdn hostname to error printout (#66182) (#66662)
kiukchung Oct 15, 2021
03c7e47
Use preview version of kineto with roctracer support
mwootton Nov 1, 2021
f9456e2
Update kineto commit
mwootton Nov 1, 2021
eb539d9
related commits for apex and torchvision
jeffdaily Nov 1, 2021
bb47ac7
rocblas alt impl during backward pass only
jeffdaily Nov 9, 2021
9ceebba
Revert "Replace `internal::GRAIN_SIZE` by `grain_size` (parameter). (…
pruthvistony Dec 5, 2021
af7d4f0
Add amdgpu repos for ROCm4.5 install (#886)
jithunnair-amd Dec 11, 2021
b8491d6
Use --no-check-certificate flag for valgrind install in CentOS7 only …
jithunnair-amd Oct 6, 2021
15e0a12
Update gfx archs
jithunnair-amd May 12, 2021
d839e39
add support for ubuntu 20.04 to CI docker images
jeffdaily Oct 19, 2021
c3fedb2
Set variables for base urls to allow configurability
jithunnair-amd Dec 14, 2021
f6ced82
[ROCm] use hipCUB chained iterator
jaglinux Jan 10, 2022
fd60586
Merge pull request #891 from jaglinux/rocprim_chained_iterators_1.10
jithunnair-amd Jan 10, 2022
0f78002
Export c10::hip::HIPCachingAllocatorMasqueradingAsCUDA::get() in libt…
jithunnair-amd Jan 20, 2022
613a636
Disable caffe2 build (#901)
jithunnair-amd Feb 5, 2022
089849b
Add ROCm5.0/AMDGPU support (#904)
WBobby Feb 10, 2022
7bb56db
Fix JIT path for Pytorch extensions and other hipify fixes (#903)
jithunnair-amd Feb 14, 2022
599891e
Cherry-pick the commit to make TORCH_(CUDABLAS|CUSOLVER)_CHECK usable…
hubertlu-tw Feb 16, 2022
8f1516f
Add AMDGPU version for ROCm5.0.1
jithunnair-amd Feb 16, 2022
5ef474f
Hipify bug fix for header_include_paths being passed in as None from …
jithunnair-amd Feb 19, 2022
fbe849f
Remove gfx1030 from list of default targets for PyTorch since Navi21 …
jithunnair-amd Mar 10, 2022
c12b6a2
[WIP][resubmit] Don't #define NUM_THREADS (#68008)
zasdfgbnm Nov 9, 2021
da64fe2
Add amdgpu version support for ROCm5.1 (#980)
WBobby Mar 29, 2022
3414a8e
Add ROCm5.1.1/AMDGPU support (#985)
WBobby Apr 7, 2022
539d476
[ROCm] revert cat operator performance work-around (#987)
jeffdaily Apr 8, 2022
cf85f6b
Enable atomicAddNoRet() for all gfx targets (#992)
rraminen Apr 11, 2022
3af6016
Updated handling of PYTORCH_ROCM_ARCH while building base docker
pruthvistony Apr 19, 2022
1f19f03
Properly import LooseVersion (#996)
jeffdaily Apr 14, 2022
ed9b160
[ROCm] use ncclAllToAll for rocm
KyleCZH Apr 22, 2022
4a5e2b0
Deactive ncclAllToAll since degradation was observed on a Hayabusa sy…
liligwu May 18, 2022
85ace43
Merge pull request #1011 from ROCmSoftwarePlatform/release/1.10_rever…
liligwu May 18, 2022
5fa4b1f
Add ROCm5.1.3/AMDGPU support
yanyao-wang May 20, 2022
3c9bd05
Merge pull request #1017 from WBobby/release/1.10
sunway513 May 21, 2022
0352824
[ROCm] update cmake package DIR paths (#77087)
jeffdaily May 10, 2022
e301e9b
Add ROCm5.2/AMDGPU support for PyTorch 1.10
yanyao-wang Jun 27, 2022
5e03866
Add ROCm5.2/AMDGPU support for PyTorch 1.10
yanyao-wang Jul 14, 2022
14 changes: 8 additions & 6 deletions .circleci/config.yml

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions .circleci/docker/build.sh
@@ -329,6 +329,7 @@ docker build \
--build-arg "NINJA_VERSION=${NINJA_VERSION:-}" \
--build-arg "KATEX=${KATEX:-}" \
--build-arg "ROCM_VERSION=${ROCM_VERSION:-}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx900;gfx906;gfx908}" \
-f $(dirname ${DOCKERFILE})/Dockerfile \
-t "$tmp_tag" \
"$@" \
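The added build argument relies on Bash's `${VAR:-default}` expansion, so an unset or empty `PYTORCH_ROCM_ARCH` falls back to `gfx900;gfx906;gfx908`. The sketch below illustrates that behaviour. The helper names `rocm_arch_list` and `amdgpu_target_flags` are hypothetical and not part of this PR, and turning the list into `--amdgpu-target` flags is only one assumed way a semicolon-separated value like this could be consumed.

```bash
#!/usr/bin/env bash

# Hypothetical helper: mirrors the default used in the --build-arg line.
rocm_arch_list() {
    echo "${PYTORCH_ROCM_ARCH:-gfx900;gfx906;gfx908}"
}

# Hypothetical helper: turn "gfx900;gfx906" into one flag per architecture.
amdgpu_target_flags() {
    local archs=() flags=() arch
    # Split the argument on ';' into an array; IFS is changed for `read` only.
    IFS=';' read -ra archs <<< "$1"
    for arch in "${archs[@]}"; do
        flags+=("--amdgpu-target=${arch}")
    done
    # Join the flags with single spaces.
    echo "${flags[*]}"
}

amdgpu_target_flags "$(rocm_arch_list)"
```

Exporting `PYTORCH_ROCM_ARCH` before the call overrides the default, which is the same mechanism the `docker build` invocation depends on.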
13 changes: 10 additions & 3 deletions .circleci/docker/common/install_base.sh
@@ -11,8 +11,13 @@ install_ubuntu() {
# "$UBUNTU_VERSION" == "18.04"
if [[ "$UBUNTU_VERSION" == "18.04"* ]]; then
cmake3="cmake=3.10*"
maybe_libiomp_dev="libiomp-dev"
elif [[ "$UBUNTU_VERSION" == "20.04"* ]]; then
cmake3="cmake=3.16*"
maybe_libiomp_dev=""
else
cmake3="cmake=3.5*"
maybe_libiomp_dev="libiomp-dev"
fi

# Install common dependencies
@@ -33,7 +38,7 @@ install_ubuntu() {
git \
libatlas-base-dev \
libc6-dbg \
libiomp-dev \
${maybe_libiomp_dev} \
libyaml-dev \
libz-dev \
libjpeg-dev \
@@ -107,11 +112,13 @@ case "$ID" in
esac

# Install Valgrind separately since the apt-get version is too old.
OS_VERSION=$(grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"')
if [[ $ID == centos && $OS_VERSION == 7 ]]; then WGET_FLAG="--no-check-certificate" ; else WGET_FLAG=""; fi
mkdir valgrind_build && cd valgrind_build
VALGRIND_VERSION=3.16.1
if ! wget http://valgrind.org/downloads/valgrind-${VALGRIND_VERSION}.tar.bz2
if ! wget $WGET_FLAG http://valgrind.org/downloads/valgrind-${VALGRIND_VERSION}.tar.bz2
then
wget https://sourceware.org/ftp/valgrind/valgrind-${VALGRIND_VERSION}.tar.bz2
wget $WGET_FLAG https://sourceware.org/ftp/valgrind/valgrind-${VALGRIND_VERSION}.tar.bz2
fi
tar -xjf valgrind-${VALGRIND_VERSION}.tar.bz2
cd valgrind-${VALGRIND_VERSION}
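The Valgrind hunk reads `VERSION_ID` from `/etc/os-release` and passes `--no-check-certificate` to `wget` only on CentOS 7. A minimal sketch of that selection logic follows. The function names `os_version_id` and `wget_flags_for` are hypothetical, and `sed` is used here in place of the PR's `grep -oP` purely as an assumption to avoid depending on PCRE support.

```bash
#!/usr/bin/env bash

# Hypothetical helper: print VERSION_ID from an os-release style file,
# with surrounding double quotes removed.
os_version_id() {
    local os_release_file="${1:-/etc/os-release}"
    sed -n 's/^VERSION_ID=//p' "$os_release_file" | tr -d '"'
}

# Hypothetical helper: print the extra wget flag for a distro ID and
# version, or nothing when no extra flag is needed.
wget_flags_for() {
    local id="$1" version="$2"
    if [[ "$id" == centos && "$version" == 7 ]]; then
        echo "--no-check-certificate"
    fi
}

# Usage (commented out so the sketch has no network side effects):
# wget $(wget_flags_for "$ID" "$(os_version_id)") "$url"
```

As in the PR, the flag is expanded unquoted so that an empty value disappears from the command line instead of becoming an empty argument.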
36 changes: 32 additions & 4 deletions .circleci/docker/common/install_rocm.sh
@@ -14,7 +14,7 @@ install_magma() {
cp make.inc-examples/make.inc.hip-gcc-mkl make.inc
echo 'LIBDIR += -L$(MKLROOT)/lib' >> make.inc
echo 'LIB += -Wl,--enable-new-dtags -Wl,--rpath,/opt/rocm/lib -Wl,--rpath,$(MKLROOT)/lib -Wl,--rpath,/opt/rocm/magma/lib' >> make.inc
echo 'DEVCCFLAGS += --amdgpu-target=gfx803 --amdgpu-target=gfx900 --amdgpu-target=gfx906 --amdgpu-target=gfx908 --gpu-max-threads-per-block=256' >> make.inc
echo 'DEVCCFLAGS += --amdgpu-target=gfx900 --amdgpu-target=gfx906 --amdgpu-target=gfx908 --amdgpu-target=gfx90a --gpu-max-threads-per-block=256' >> make.inc
# hipcc with openmp flag may cause isnan() on __device__ not to be found; depending on context, compiler may attempt to match with host definition
sed -i 's/^FOPENMP/#FOPENMP/g' make.inc
export PATH="${PATH}:/opt/rocm/bin"
@@ -29,27 +29,42 @@ ver() {
printf "%3d%03d%03d%03d" $(echo "$1" | tr '.' ' ');
}

# Map ROCm version to AMDGPU version
declare -A AMDGPU_VERSIONS=( ["4.5.2"]="21.40.2" ["5.0"]="21.50" ["5.0.1"]="21.50.1" ["5.1"]="22.10" ["5.1.1"]="22.10.1" ["5.1.3"]="22.10.3" ["5.2"]="22.20" ["5.2.1"]="22.20.1" )

install_ubuntu() {
apt-get update
if [[ $UBUNTU_VERSION == 18.04 ]]; then
# gpg-agent is not available by default on 18.04
apt-get install -y --no-install-recommends gpg-agent
fi
if [[ $UBUNTU_VERSION == 20.04 ]]; then
# gpg-agent is not available by default on 20.04
apt-get install -y --no-install-recommends gpg-agent
fi
apt-get install -y kmod
apt-get install -y wget

# Need the libc++1 and libc++abi1 libraries to allow torch._C to load at runtime
apt-get install -y libc++1
apt-get install -y libc++abi1

if [[ $(ver $ROCM_VERSION) -ge $(ver 4.5) ]]; then
# Add amdgpu repository
UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'`
local amdgpu_baseurl="https://repo.radeon.com/amdgpu/${AMDGPU_VERSIONS[$ROCM_VERSION]}/ubuntu"
echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list
fi

ROCM_REPO="ubuntu"
if [[ $(ver $ROCM_VERSION) -lt $(ver 4.2) ]]; then
ROCM_REPO="xenial"
fi

# Add rocm repository
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
echo "deb [arch=amd64] http://repo.radeon.com/rocm/apt/${ROCM_VERSION} ${ROCM_REPO} main" > /etc/apt/sources.list.d/rocm.list
local rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}"
echo "deb [arch=amd64] ${rocm_baseurl} ${ROCM_REPO} main" > /etc/apt/sources.list.d/rocm.list
apt-get update --allow-insecure-repositories

DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
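The hunk above combines two pieces: `ver()` turns a dotted version into a number that `-ge` and `-lt` can compare, and `AMDGPU_VERSIONS` maps a ROCm release to the matching amdgpu repository version. One consequence is that a ROCm version without a map entry (for example `4.5`, since only `4.5.2` is listed) passes the `-ge 4.5` check but expands to an empty path segment in the repository URL. The sketch below reuses `ver()` and the map from the diff. `amdgpu_version_for` is a hypothetical helper, not part of the PR, showing one way such a lookup could fail loudly.

```bash
#!/usr/bin/env bash

# Same helper as in the diff: each dotted component is padded to three
# digits so the result can be compared as a single integer.
ver() {
    printf "%3d%03d%03d%03d" $(echo "$1" | tr '.' ' ');
}

# Same mapping as in the diff.
declare -A AMDGPU_VERSIONS=( ["4.5.2"]="21.40.2" ["5.0"]="21.50" ["5.0.1"]="21.50.1" ["5.1"]="22.10" ["5.1.1"]="22.10.1" ["5.1.3"]="22.10.3" ["5.2"]="22.20" ["5.2.1"]="22.20.1" )

# Hypothetical helper: print the amdgpu version for a ROCm version, or
# return non-zero with a message on stderr when there is no mapping.
amdgpu_version_for() {
    local rocm_version="$1"
    if [[ -z "$rocm_version" || -z "${AMDGPU_VERSIONS[$rocm_version]:-}" ]]; then
        echo "no AMDGPU version mapped for ROCm '${rocm_version}'" >&2
        return 1
    fi
    echo "${AMDGPU_VERSIONS[$rocm_version]}"
}

if [[ $(ver 5.2) -ge $(ver 4.5) ]]; then
    echo "amdgpu repo version: $(amdgpu_version_for 5.2)"
fi
```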
@@ -86,11 +101,24 @@ install_centos() {
yum install -y epel-release
yum install -y dkms kernel-headers-`uname -r` kernel-devel-`uname -r`

if [[ $(ver $ROCM_VERSION) -ge $(ver 4.5) ]]; then
# Add amdgpu repository
local amdgpu_baseurl="https://repo.radeon.com/amdgpu/${AMDGPU_VERSIONS[$ROCM_VERSION]}/rhel/7.9/main/x86_64"
echo "[AMDGPU]" > /etc/yum.repos.d/amdgpu.repo
echo "name=AMDGPU" >> /etc/yum.repos.d/amdgpu.repo
echo "baseurl=${amdgpu_baseurl}" >> /etc/yum.repos.d/amdgpu.repo
echo "enabled=1" >> /etc/yum.repos.d/amdgpu.repo
echo "gpgcheck=1" >> /etc/yum.repos.d/amdgpu.repo
echo "gpgkey=http://repo.radeon.com/rocm/rocm.gpg.key" >> /etc/yum.repos.d/amdgpu.repo
fi

local rocm_baseurl="http://repo.radeon.com/rocm/yum/${ROCM_VERSION}"
echo "[ROCm]" > /etc/yum.repos.d/rocm.repo
echo "name=ROCm" >> /etc/yum.repos.d/rocm.repo
echo "baseurl=http://repo.radeon.com/rocm/yum/${ROCM_VERSION}" >> /etc/yum.repos.d/rocm.repo
echo "baseurl=${rocm_baseurl}" >> /etc/yum.repos.d/rocm.repo
echo "enabled=1" >> /etc/yum.repos.d/rocm.repo
echo "gpgcheck=0" >> /etc/yum.repos.d/rocm.repo
echo "gpgcheck=1" >> /etc/yum.repos.d/rocm.repo
echo "gpgkey=http://repo.radeon.com/rocm/rocm.gpg.key" >> /etc/yum.repos.d/rocm.repo

yum update -y

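The CentOS hunk builds each `.repo` file with a series of `echo` redirections and switches `gpgcheck` on together with a `gpgkey` entry. The same file layout can be sketched with a single here-document, as below. `write_yum_repo` and its parameters are hypothetical names introduced for illustration, and the URLs in the usage comment are copied from the diff.

```bash
#!/usr/bin/env bash

# Hypothetical helper: write a yum repository definition with GPG
# checking enabled, using the same keys the PR writes line by line.
write_yum_repo() {
    local repo_file="$1" repo_id="$2" baseurl="$3" gpgkey="$4"
    cat > "$repo_file" <<EOF
[${repo_id}]
name=${repo_id}
baseurl=${baseurl}
enabled=1
gpgcheck=1
gpgkey=${gpgkey}
EOF
}

# Usage (commented out because it writes under /etc):
# write_yum_repo /etc/yum.repos.d/rocm.repo ROCm \
#     "http://repo.radeon.com/rocm/yum/${ROCM_VERSION}" \
#     "http://repo.radeon.com/rocm/rocm.gpg.key"
```

A here-document keeps the repository definition readable in one place, at the cost of a slightly larger change than the `echo` lines the PR keeps.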
2 changes: 1 addition & 1 deletion .circleci/scripts/binary_checkout.sh
@@ -61,7 +61,7 @@ git --no-pager log --max-count 1
popd

# Clone the Builder master repo
retry git clone -q https://github.com/pytorch/builder.git "$BUILDER_ROOT"
retry git clone -q https://github.com/pytorch/builder.git -b release/1.10 "$BUILDER_ROOT"
pushd "$BUILDER_ROOT"
echo "Using builder from "
git --no-pager log --max-count 1
11 changes: 6 additions & 5 deletions .circleci/scripts/binary_ios_test.sh
@@ -8,22 +8,23 @@ cd ${PROJ_ROOT}/ios/TestApp
# install fastlane
sudo gem install bundler && bundle install
# install certificates
echo "${IOS_CERT_KEY}" >> cert.txt
echo "${IOS_CERT_KEY_2022}" >> cert.txt
base64 --decode cert.txt -o Certificates.p12
rm cert.txt
bundle exec fastlane install_cert
bundle exec fastlane install_root_cert
bundle exec fastlane install_dev_cert
# install the provisioning profile
PROFILE=PyTorch_CI_2021.mobileprovision
PROFILE=PyTorch_CI_2022.mobileprovision
PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
mkdir -pv "${PROVISIONING_PROFILES}"
cd "${PROVISIONING_PROFILES}"
echo "${IOS_SIGN_KEY}" >> cert.txt
echo "${IOS_SIGN_KEY_2022}" >> cert.txt
base64 --decode cert.txt -o ${PROFILE}
rm cert.txt
# run the ruby build script
if ! [ -x "$(command -v xcodebuild)" ]; then
echo 'Error: xcodebuild is not installed.'
exit 1
fi
PROFILE=PyTorch_CI_2021
PROFILE=PyTorch_CI_2022
ruby ${PROJ_ROOT}/scripts/xcode_build.rb -i ${PROJ_ROOT}/build_ios/install -x ${PROJ_ROOT}/ios/TestApp/TestApp.xcodeproj -p ${IOS_PLATFORM} -c ${PROFILE} -t ${IOS_DEV_TEAM_ID} -f Accelerate,MetalPerformanceShaders,CoreML
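The script above stores each base64-encoded secret in a temporary `cert.txt`, decodes it with `base64 --decode cert.txt -o <file>` (the `-o` form appears to be the macOS variant of the tool), and then deletes the temporary file. A minimal sketch of the same decode step without the intermediate file is shown below. `decode_secret_to_file` is a hypothetical helper, and the sample value is plain test data, not a real certificate.

```bash
#!/usr/bin/env bash

# Hypothetical helper: decode a base64 string straight into a file by
# piping it through stdin, so no temporary cert.txt is left to clean up.
decode_secret_to_file() {
    local encoded="$1" out_file="$2"
    printf '%s' "$encoded" | base64 --decode > "$out_file"
}

# Usage with the variables named in the diff (commented out because the
# secrets only exist in CI):
# decode_secret_to_file "${IOS_CERT_KEY_2022}" Certificates.p12
# decode_secret_to_file "${IOS_SIGN_KEY_2022}" "${PROFILE}"
```

Reading from stdin and writing to stdout keeps the command line the same for implementations that lack an `-o` option.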
11 changes: 6 additions & 5 deletions .circleci/verbatim-sources/job-specs/job-specs-custom.yml
@@ -467,16 +467,17 @@
# install fastlane
sudo gem install bundler && bundle install
# install certificates
echo ${IOS_CERT_KEY} >> cert.txt
echo ${IOS_CERT_KEY_2022} >> cert.txt
base64 --decode cert.txt -o Certificates.p12
rm cert.txt
bundle exec fastlane install_cert
bundle exec fastlane install_root_cert
bundle exec fastlane install_dev_cert
# install the provisioning profile
PROFILE=PyTorch_CI_2021.mobileprovision
PROFILE=PyTorch_CI_2022.mobileprovision
PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
mkdir -pv "${PROVISIONING_PROFILES}"
cd "${PROVISIONING_PROFILES}"
echo ${IOS_SIGN_KEY} >> cert.txt
echo ${IOS_SIGN_KEY_2022} >> cert.txt
base64 --decode cert.txt -o ${PROFILE}
rm cert.txt
- run:
@@ -535,7 +536,7 @@
command: |
set -e
PROJ_ROOT=/Users/distiller/project
PROFILE=PyTorch_CI_2021
PROFILE=PyTorch_CI_2022
# run the ruby build script
if ! [ -x "$(command -v xcodebuild)" ]; then
echo 'Error: xcodebuild is not installed.'
3 changes: 2 additions & 1 deletion .circleci/verbatim-sources/job-specs/pytorch-job-specs.yml
@@ -158,7 +158,8 @@ jobs:
}

if is_vanilla_build; then
echo "apt-get update && apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
echo "apt-get update || apt-get install libgnutls30" | docker exec -u root -i "$id" bash
echo "apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
echo "cd workspace/build; qemu-x86_64 -g 2345 -cpu Broadwell -E ATEN_CPU_CAPABILITY=default ./bin/basic --gtest_filter=BasicTest.BasicTestCPU & gdb ./bin/basic -ex 'set pagination off' -ex 'target remote :2345' -ex 'continue' -ex 'bt' -ex='set confirm off' -ex 'quit \$_isvoid(\$_exitcode)'" | docker exec -u jenkins -i "$id" bash
else
echo "Skipping for ${BUILD_ENVIRONMENT}"
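In the new command the two `apt-get` calls are joined with `||`, so `apt-get install libgnutls30` runs only when `apt-get update` exits with a non-zero status. The sketch below reproduces that short-circuit behaviour with stand-in steps instead of real package commands. `simulate_update` and `run_step` are hypothetical names used only for illustration.

```bash
#!/usr/bin/env bash

# Hypothetical simulation: record which steps run when the first step
# exits with the given status, using the same `a || b` shape as the diff.
simulate_update() {
    local update_status="$1"
    local steps=()
    # Append the step name to the log and return the requested status.
    run_step() { steps+=("$1"); return "$2"; }
    run_step "update" "$update_status" || run_step "install-libgnutls30" 0
    echo "${steps[*]}"
}

simulate_update 0   # prints: update
simulate_update 1   # prints: update install-libgnutls30
```

The following `apt-get install -y qemu-user gdb` line is a separate `docker exec` call, so it runs regardless of how the first one ended.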
2 changes: 1 addition & 1 deletion .gitmodules
@@ -135,7 +135,7 @@
url = https://github.com/NVIDIA/cudnn-frontend.git
[submodule "third_party/kineto"]
path = third_party/kineto
url = https://github.com/pytorch/kineto
url = https://github.com/mwootton/kineto
[submodule "third_party/pocketfft"]
path = third_party/pocketfft
url = https://github.com/mreineck/pocketfft
6 changes: 3 additions & 3 deletions .jenkins/pytorch/common_utils.sh
@@ -63,9 +63,9 @@ function get_pr_change_files() {
function file_diff_from_base() {
# The fetch may fail on Docker hosts, this fetch is necessary for GHA
set +e
git fetch origin master --quiet
git fetch origin release/1.10 --quiet
set -e
git diff --name-only "$(git merge-base origin/master HEAD)" > "$1"
git diff --name-only "$(git merge-base origin/release/1.10 HEAD)" > "$1"
}

function get_bazel() {
@@ -99,5 +99,5 @@ function checkout_install_torchvision() {
}

function clone_pytorch_xla() {
git clone --recursive https://github.com/pytorch/xla.git
git clone --recursive -b r1.10 https://github.com/pytorch/xla.git
}
4 changes: 3 additions & 1 deletion .jenkins/pytorch/test.sh
@@ -230,6 +230,8 @@ test_aten() {
test_without_numpy() {
pushd "$(dirname "${BASH_SOURCE[0]}")"
python -c "import sys;sys.path.insert(0, 'fake_numpy');from unittest import TestCase;import torch;x=torch.randn(3,3);TestCase().assertRaises(RuntimeError, lambda: x.numpy())"
# Regression test for https://github.com/pytorch/pytorch/issues/66353
python -c "import sys;sys.path.insert(0, 'fake_numpy');import torch;print(torch.tensor([torch.tensor(0.), torch.tensor(1.)]))"
popd
}

@@ -424,7 +426,7 @@ test_backward_compatibility() {
python -m venv venv
# shellcheck disable=SC1091
. venv/bin/activate
pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip_install --pre torch -f https://download.pytorch.org/whl/test/cpu/torch_nightly.html
pip show torch
python dump_all_function_schemas.py --filename nightly_schemas.txt
deactivate
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -148,7 +148,7 @@ option(BUILD_BINARY "Build C++ binaries" OFF)
option(BUILD_DOCS "Build Caffe2 documentation" OFF)
option(BUILD_CUSTOM_PROTOBUF "Build and use Caffe2's own protobuf under third_party" ON)
option(BUILD_PYTHON "Build Python binaries" ON)
option(BUILD_CAFFE2 "Master flag to build Caffe2" ON)
option(BUILD_CAFFE2 "Master flag to build Caffe2" OFF)
option(BUILD_LITE_INTERPRETER "Master flag to build Lite Interpreter" OFF)
cmake_dependent_option(
BUILD_CAFFE2_OPS "Build Caffe2 operators" ON
12 changes: 2 additions & 10 deletions aten/src/ATen/ConjugateFallback.cpp
@@ -2,21 +2,12 @@
#include <ATen/native/MathBitFallThroughLists.h>

namespace at {

namespace native {
struct ConjFallback : MathOpFallback {
ConjFallback() : MathOpFallback(DispatchKey::Conjugate, "conjugate") {}
bool is_bit_set(const Tensor& tensor) override {
return tensor.is_conj();
}
void _set_bit(const Tensor& tensor, bool value) override {
return tensor._set_conj(value);
}
Tensor resolve_bit(const Tensor& tensor) override {
return at::resolve_conj(tensor);
}
Tensor& math_op_(Tensor& tensor) override {
return at::conj_physical_(tensor);
}
};

void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
@@ -60,4 +51,5 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
}

}
} // namespace at
14 changes: 14 additions & 0 deletions aten/src/ATen/Context.cpp
@@ -254,6 +254,20 @@ bool NoTF32Guard::should_disable_tf32() {
return override_allow_tf32_flag;
}

thread_local bool BackwardPassGuard::is_backward_pass_;

BackwardPassGuard::BackwardPassGuard() {
is_backward_pass_ = true;
}

BackwardPassGuard::~BackwardPassGuard() {
is_backward_pass_ = false;
}

bool BackwardPassGuard::is_backward_pass() {
return is_backward_pass_;
}

bool Context::areVmapFallbackWarningsEnabled() const {
return display_vmap_fallback_warnings_;
}
8 changes: 8 additions & 0 deletions aten/src/ATen/Context.h
@@ -380,4 +380,12 @@ struct TORCH_API NoTF32Guard {
bool changed = false;
};

struct TORCH_API BackwardPassGuard {
BackwardPassGuard();
~BackwardPassGuard();
static bool is_backward_pass();
private:
static thread_local bool is_backward_pass_;
};

} // namespace at
2 changes: 1 addition & 1 deletion aten/src/ATen/TensorIterator.cpp
@@ -663,7 +663,7 @@ void TensorIteratorBase::for_each(loop2d_t loop, int64_t grain_size) {
int64_t numel = this->numel();
if (numel == 0) {
return;
} else if (numel < grain_size || at::get_num_threads() == 1) {
} else if (numel < internal::GRAIN_SIZE || at::get_num_threads() == 1) {
return serial_for_each(loop, {0, numel});
} else {
at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {