From 18e298305e0f1b0a3148086b8d697825ef8c5817 Mon Sep 17 00:00:00 2001 From: Adam Lerer Date: Tue, 7 Aug 2018 08:12:43 -0700 Subject: [PATCH 1/9] Increase TCP listen queue size from 64 to 1024 (#10268) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10268 Running torch.distributed.init_process_group fails with more than ~64 processes, with various errors like connection refused or connection reset by peer. After some digging, it looks like the root cause is that all workers have to connect to master via TCP (both in Zeus init and in DataChannelTCP - look for `connect()`), and the listening socket only has a backlog of 64. I increased the backlog to 1024, that seems like enough for reasonable purposes (the hard limit is 65535 in /proc/sys/net/core/somaxconn). There's probably a more correct way to do this that involves retries when connection is refused. Reviewed By: soumith Differential Revision: D9182216 fbshipit-source-id: 2f71c4995841db26c670cec344f1e3c7a80a7936 --- torch/lib/THD/base/ChannelUtils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/lib/THD/base/ChannelUtils.cpp b/torch/lib/THD/base/ChannelUtils.cpp index 971282f7db019..0c5951d8f48f4 100644 --- a/torch/lib/THD/base/ChannelUtils.cpp +++ b/torch/lib/THD/base/ChannelUtils.cpp @@ -16,7 +16,7 @@ namespace thd { namespace { -constexpr int LISTEN_QUEUE_SIZE = 64; +constexpr int LISTEN_QUEUE_SIZE = 1024; void setSocketNoDelay(int socket) { int flag = 1; From 66f7b8abbec547303114a0e755e7f46800bfd839 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Tue, 7 Aug 2018 08:50:18 -0700 Subject: [PATCH 2/9] Better macro name hygiene prefixing. (#10274) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10274 Good C++ libraries don't take up un-namespaced identifiers like DISABLE_COPY_AND_ASSIGN. Re-prefix this. Follow up fix: codemod Caffe2 to use the new macro, delete the forwarding definition Reviewed By: mingzhe09088 Differential Revision: D9181939 fbshipit-source-id: 857d099de1c2c0c4d0c1768c1ab772d59e28977c --- aten/src/ATen/core/Macros.h | 6 ++---- aten/src/ATen/core/TensorTypeIdRegistration.h | 8 ++++---- caffe2/core/common.h | 4 ++++ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/core/Macros.h b/aten/src/ATen/core/Macros.h index dcad67ddb68c8..4b7f094f805d0 100644 --- a/aten/src/ATen/core/Macros.h +++ b/aten/src/ATen/core/Macros.h @@ -23,8 +23,6 @@ // Disable the copy and assignment operator for a class. Note that this will // disable the usage of the class in std containers. -#ifndef DISABLE_COPY_AND_ASSIGN -#define DISABLE_COPY_AND_ASSIGN(classname) \ - classname(const classname&) = delete; \ +#define AT_DISABLE_COPY_AND_ASSIGN(classname) \ + classname(const classname&) = delete; \ classname& operator=(const classname&) = delete -#endif diff --git a/aten/src/ATen/core/TensorTypeIdRegistration.h b/aten/src/ATen/core/TensorTypeIdRegistration.h index a890c7990c4a4..0286115fdc66a 100644 --- a/aten/src/ATen/core/TensorTypeIdRegistration.h +++ b/aten/src/ATen/core/TensorTypeIdRegistration.h @@ -32,7 +32,7 @@ class TensorTypeIdCreator final { static constexpr at::TensorTypeId max_id_ = TensorTypeId( std::numeric_limits::max()); - DISABLE_COPY_AND_ASSIGN(TensorTypeIdCreator); + AT_DISABLE_COPY_AND_ASSIGN(TensorTypeIdCreator); }; class TensorTypeIdRegistry final { @@ -46,7 +46,7 @@ class TensorTypeIdRegistry final { std::unordered_set registeredTypeIds_; std::mutex mutex_; - DISABLE_COPY_AND_ASSIGN(TensorTypeIdRegistry); + AT_DISABLE_COPY_AND_ASSIGN(TensorTypeIdRegistry); }; class TensorTypeIds final { @@ -64,7 +64,7 @@ class TensorTypeIds final { TensorTypeIdCreator creator_; TensorTypeIdRegistry registry_; - DISABLE_COPY_AND_ASSIGN(TensorTypeIds); + AT_DISABLE_COPY_AND_ASSIGN(TensorTypeIds); }; inline constexpr at::TensorTypeId TensorTypeIds::undefined() noexcept { @@ -81,7 +81,7 @@ class TensorTypeIdRegistrar final { private: at::TensorTypeId id_; - DISABLE_COPY_AND_ASSIGN(TensorTypeIdRegistrar); + AT_DISABLE_COPY_AND_ASSIGN(TensorTypeIdRegistrar); }; inline at::TensorTypeId TensorTypeIdRegistrar::id() const noexcept { diff --git a/caffe2/core/common.h b/caffe2/core/common.h index 7d002028b14f3..1ab86f8e3d0cc 100644 --- a/caffe2/core/common.h +++ b/caffe2/core/common.h @@ -64,6 +64,10 @@ using std::vector; #define CAFFE2_USED __attribute__((__used__)) #endif //_MSC_VER +#ifndef DISABLE_COPY_AND_ASSIGN +#define DISABLE_COPY_AND_ASSIGN(classname) AT_DISABLE_COPY_AND_ASSIGN(classname) +#endif + // Define enabled when building for iOS or Android devices #if !defined(CAFFE2_MOBILE) #if defined(__ANDROID__) From ad76fc88073d6c252308376a6a6b986dbab010b3 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Tue, 7 Aug 2018 08:50:19 -0700 Subject: [PATCH 3/9] s/DISABLE_COPY_AND_ASSIGN/AT_DISABLE_COPY_AND_ASSIGN/ (#10275) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10275 Remove forwarding declaration in caffe2/core/common.h ``` codemod -d caffe2 --extensions cc,cpp,cu,cuh,h \\bDISABLE_COPY_AND_ASSIGN AT_DISABLE_COPY_AND_ASSIGN ``` Reviewed By: mingzhe09088 Differential Revision: D9184809 fbshipit-source-id: 958cf5162b0d92b83ea9c2597abb77320ca57ce8 --- caffe2/contrib/nccl/cuda_nccl_gpu.cc | 2 +- caffe2/core/blob.h | 2 +- caffe2/core/common.h | 4 ---- caffe2/core/common_cudnn.h | 4 ++-- caffe2/core/cudnn_wrappers.h | 4 ++-- caffe2/core/db.cc | 2 +- caffe2/core/db.h | 8 ++++---- caffe2/core/dispatch/KernelRegistration.h | 2 +- caffe2/core/hip/common_miopen.h | 2 +- caffe2/core/hip/miopen_wrapper.h | 4 ++-- caffe2/core/hip/net_async_dag_hip.cc | 2 +- caffe2/core/net.h | 2 +- caffe2/core/net_async_base.h | 2 +- caffe2/core/net_async_dag_gpu.cc | 4 ++-- caffe2/core/net_async_dag_gpu.h | 2 +- caffe2/core/net_async_polling.h | 2 +- caffe2/core/net_async_scheduling.h | 2 +- caffe2/core/net_dag.h | 2 +- caffe2/core/net_simple.h | 2 +- caffe2/core/net_simple_async.h | 2 +- caffe2/core/operator.h | 2 +- caffe2/core/registry.h | 2 +- caffe2/core/timer.h | 2 +- caffe2/core/workspace.h | 2 +- caffe2/db/create_db_op.h | 2 +- caffe2/db/leveldb.cc | 2 +- caffe2/db/lmdb.cc | 2 +- caffe2/db/protodb.cc | 2 +- caffe2/mkl/utils/mkl_memory.h | 8 ++++---- caffe2/mobile/contrib/arm-compute/core/net_gl.h | 2 +- caffe2/operators/expand_squeeze_dims_op.h | 2 +- caffe2/operators/partition_ops.h | 4 ++-- caffe2/operators/slice_op.h | 4 ++-- caffe2/queue/blobs_queue_db.cc | 2 +- caffe2/utils/threadpool/WorkersPool.h | 2 +- caffe2/utils/zmq_helper.h | 4 ++-- 36 files changed, 48 insertions(+), 52 deletions(-) diff --git a/caffe2/contrib/nccl/cuda_nccl_gpu.cc b/caffe2/contrib/nccl/cuda_nccl_gpu.cc index aa321e5589d9b..59a796a07a37f 100644 --- a/caffe2/contrib/nccl/cuda_nccl_gpu.cc +++ b/caffe2/contrib/nccl/cuda_nccl_gpu.cc @@ -72,7 +72,7 @@ class NCCLContext { cudaEvent_t master_event_; std::vector events_; - DISABLE_COPY_AND_ASSIGN(NCCLContext); + AT_DISABLE_COPY_AND_ASSIGN(NCCLContext); }; // We share the contexts across multiple operators, hence the diff --git a/caffe2/core/blob.h b/caffe2/core/blob.h index b2e2cce917cd6..c8b8c29d570a9 100644 --- a/caffe2/core/blob.h +++ b/caffe2/core/blob.h @@ -288,7 +288,7 @@ class Blob { void* pointer_ = nullptr; DestroyCall destroy_ = nullptr; - DISABLE_COPY_AND_ASSIGN(Blob); + AT_DISABLE_COPY_AND_ASSIGN(Blob); }; inline void swap(Blob& lhs, Blob& rhs) { diff --git a/caffe2/core/common.h b/caffe2/core/common.h index 1ab86f8e3d0cc..7d002028b14f3 100644 --- a/caffe2/core/common.h +++ b/caffe2/core/common.h @@ -64,10 +64,6 @@ using std::vector; #define CAFFE2_USED __attribute__((__used__)) #endif //_MSC_VER -#ifndef DISABLE_COPY_AND_ASSIGN -#define DISABLE_COPY_AND_ASSIGN(classname) AT_DISABLE_COPY_AND_ASSIGN(classname) -#endif - // Define enabled when building for iOS or Android devices #if !defined(CAFFE2_MOBILE) #if defined(__ANDROID__) diff --git a/caffe2/core/common_cudnn.h b/caffe2/core/common_cudnn.h index ca154a0e65b76..fe54318133ac9 100644 --- a/caffe2/core/common_cudnn.h +++ b/caffe2/core/common_cudnn.h @@ -259,7 +259,7 @@ class cudnnTensorDescWrapper { cudnnTensorFormat_t format_; cudnnDataType_t type_; vector dims_; - DISABLE_COPY_AND_ASSIGN(cudnnTensorDescWrapper); + AT_DISABLE_COPY_AND_ASSIGN(cudnnTensorDescWrapper); }; class cudnnFilterDescWrapper { @@ -313,7 +313,7 @@ class cudnnFilterDescWrapper { StorageOrder order_; cudnnDataType_t type_; vector dims_; - DISABLE_COPY_AND_ASSIGN(cudnnFilterDescWrapper); + AT_DISABLE_COPY_AND_ASSIGN(cudnnFilterDescWrapper); }; diff --git a/caffe2/core/cudnn_wrappers.h b/caffe2/core/cudnn_wrappers.h index c2910e2e65840..b518914e50402 100644 --- a/caffe2/core/cudnn_wrappers.h +++ b/caffe2/core/cudnn_wrappers.h @@ -89,7 +89,7 @@ class CuDNNState { cudaStream_t stream_{nullptr}; CuDNNWorkspace workspace_; size_t gpu_id_{0}; - DISABLE_COPY_AND_ASSIGN(CuDNNState); + AT_DISABLE_COPY_AND_ASSIGN(CuDNNState); }; /** @@ -153,7 +153,7 @@ class CuDNNWrapper { CAFFE2_COMPILE_TIME_MAX_GPUS>; static PerGPUCuDNNStates& cudnn_states(); - DISABLE_COPY_AND_ASSIGN(CuDNNWrapper); + AT_DISABLE_COPY_AND_ASSIGN(CuDNNWrapper); }; }; // namespace caffe2 diff --git a/caffe2/core/db.cc b/caffe2/core/db.cc index 3dd993c925a2d..386787b51c353 100644 --- a/caffe2/core/db.cc +++ b/caffe2/core/db.cc @@ -119,7 +119,7 @@ class MiniDBTransaction : public Transaction { FILE* file_; std::lock_guard lock_; - DISABLE_COPY_AND_ASSIGN(MiniDBTransaction); + AT_DISABLE_COPY_AND_ASSIGN(MiniDBTransaction); }; class MiniDB : public DB { diff --git a/caffe2/core/db.h b/caffe2/core/db.h index 7c5b79df69191..13b29664dac29 100644 --- a/caffe2/core/db.h +++ b/caffe2/core/db.h @@ -52,7 +52,7 @@ class Cursor { */ virtual bool Valid() = 0; - DISABLE_COPY_AND_ASSIGN(Cursor); + AT_DISABLE_COPY_AND_ASSIGN(Cursor); }; /** @@ -71,7 +71,7 @@ class Transaction { */ virtual void Commit() = 0; - DISABLE_COPY_AND_ASSIGN(Transaction); + AT_DISABLE_COPY_AND_ASSIGN(Transaction); }; /** @@ -99,7 +99,7 @@ class DB { protected: Mode mode_; - DISABLE_COPY_AND_ASSIGN(DB); + AT_DISABLE_COPY_AND_ASSIGN(DB); }; // Database classes are registered by their names so we can do optional @@ -285,7 +285,7 @@ class DBReader { uint32_t num_shards_; uint32_t shard_id_; - DISABLE_COPY_AND_ASSIGN(DBReader); + AT_DISABLE_COPY_AND_ASSIGN(DBReader); }; class DBReaderSerializer : public BlobSerializerBase { diff --git a/caffe2/core/dispatch/KernelRegistration.h b/caffe2/core/dispatch/KernelRegistration.h index 9f7f9d194bbb3..9ebc20b7ab0a6 100644 --- a/caffe2/core/dispatch/KernelRegistration.h +++ b/caffe2/core/dispatch/KernelRegistration.h @@ -57,7 +57,7 @@ class KernelRegistrar final { const typename Schema::dispatch::dispatch_key_type dispatch_key_; bool owns_registration_; - DISABLE_COPY_AND_ASSIGN(KernelRegistrar); + AT_DISABLE_COPY_AND_ASSIGN(KernelRegistrar); }; /** diff --git a/caffe2/core/hip/common_miopen.h b/caffe2/core/hip/common_miopen.h index 290ae99b45171..aa9333f8bdfa1 100644 --- a/caffe2/core/hip/common_miopen.h +++ b/caffe2/core/hip/common_miopen.h @@ -164,7 +164,7 @@ class miopenTensorDescWrapper miopenTensorDescriptor_t desc_; miopenDataType_t type_; vector dims_; - DISABLE_COPY_AND_ASSIGN(miopenTensorDescWrapper); + AT_DISABLE_COPY_AND_ASSIGN(miopenTensorDescWrapper); }; } // namespace caffe2 diff --git a/caffe2/core/hip/miopen_wrapper.h b/caffe2/core/hip/miopen_wrapper.h index 2671d4b2a698a..910db8b79d788 100644 --- a/caffe2/core/hip/miopen_wrapper.h +++ b/caffe2/core/hip/miopen_wrapper.h @@ -92,7 +92,7 @@ class MIOPENState hipStream_t stream_{nullptr}; MIOPENWorkspace workspace_; size_t gpu_id_{0}; - DISABLE_COPY_AND_ASSIGN(MIOPENState); + AT_DISABLE_COPY_AND_ASSIGN(MIOPENState); }; /** @@ -157,7 +157,7 @@ class MIOPENWrapper CAFFE2_COMPILE_TIME_MAX_HIP_GPUS>; static PerGPUMIOPENStates& miopen_states(); - DISABLE_COPY_AND_ASSIGN(MIOPENWrapper); + AT_DISABLE_COPY_AND_ASSIGN(MIOPENWrapper); }; }; // namespace caffe2 diff --git a/caffe2/core/hip/net_async_dag_hip.cc b/caffe2/core/hip/net_async_dag_hip.cc index 439501af3ae78..7d10b29e965d4 100644 --- a/caffe2/core/hip/net_async_dag_hip.cc +++ b/caffe2/core/hip/net_async_dag_hip.cc @@ -58,7 +58,7 @@ class ProfiledRange ProfiledRange(const OperatorDef& def, Color color) {} private: - DISABLE_COPY_AND_ASSIGN(ProfiledRange); + AT_DISABLE_COPY_AND_ASSIGN(ProfiledRange); }; } // namespace diff --git a/caffe2/core/net.h b/caffe2/core/net.h index f90028654902f..e901d17e27907 100644 --- a/caffe2/core/net.h +++ b/caffe2/core/net.h @@ -124,7 +124,7 @@ class NetBase : public Observable { string name_; vector events_; std::shared_ptr net_def_; - DISABLE_COPY_AND_ASSIGN(NetBase); + AT_DISABLE_COPY_AND_ASSIGN(NetBase); }; class ExecutorHelper { diff --git a/caffe2/core/net_async_base.h b/caffe2/core/net_async_base.h index c4425ff95093a..09510fdb16ad0 100644 --- a/caffe2/core/net_async_base.h +++ b/caffe2/core/net_async_base.h @@ -125,7 +125,7 @@ class AsyncNetBase : public NetBase { bool use_per_net_pools_; bool is_blocking_; - DISABLE_COPY_AND_ASSIGN(AsyncNetBase); + AT_DISABLE_COPY_AND_ASSIGN(AsyncNetBase); private: void storeExceptionPtr(); diff --git a/caffe2/core/net_async_dag_gpu.cc b/caffe2/core/net_async_dag_gpu.cc index 12bd33ac7e247..867def700863f 100644 --- a/caffe2/core/net_async_dag_gpu.cc +++ b/caffe2/core/net_async_dag_gpu.cc @@ -71,7 +71,7 @@ class ProfiledRange { private: nvtxRangeId_t range_ = 0; - DISABLE_COPY_AND_ASSIGN(ProfiledRange); + AT_DISABLE_COPY_AND_ASSIGN(ProfiledRange); }; #else @@ -81,7 +81,7 @@ class ProfiledRange { ProfiledRange(const OperatorDef& def, Color color) {} private: - DISABLE_COPY_AND_ASSIGN(ProfiledRange); + AT_DISABLE_COPY_AND_ASSIGN(ProfiledRange); }; #endif // ifdef CAFFE2_USE_NVTX diff --git a/caffe2/core/net_async_dag_gpu.h b/caffe2/core/net_async_dag_gpu.h index f447c6bfe8760..8dcd812a1fc8c 100644 --- a/caffe2/core/net_async_dag_gpu.h +++ b/caffe2/core/net_async_dag_gpu.h @@ -32,7 +32,7 @@ class AsyncDAGNet : public DAGNetBase { int stream(const DeviceOption& device_option); static thread_local std::vector stream_counters_; - DISABLE_COPY_AND_ASSIGN(AsyncDAGNet); + AT_DISABLE_COPY_AND_ASSIGN(AsyncDAGNet); }; } // namespace caffe2 diff --git a/caffe2/core/net_async_polling.h b/caffe2/core/net_async_polling.h index dc807bb04b0cc..8b3d6db8d695e 100644 --- a/caffe2/core/net_async_polling.h +++ b/caffe2/core/net_async_polling.h @@ -40,7 +40,7 @@ class AsyncPollingNet : public AsyncNetBase { void reset() override; std::atomic has_chain_failed_; - DISABLE_COPY_AND_ASSIGN(AsyncPollingNet); + AT_DISABLE_COPY_AND_ASSIGN(AsyncPollingNet); }; } // namespace caffe2 diff --git a/caffe2/core/net_async_scheduling.h b/caffe2/core/net_async_scheduling.h index 363872d13ac46..096e7e2b2362a 100644 --- a/caffe2/core/net_async_scheduling.h +++ b/caffe2/core/net_async_scheduling.h @@ -30,7 +30,7 @@ class AsyncSchedulingNet : public AsyncNetBase { std::atomic processed_tasks_num_; - DISABLE_COPY_AND_ASSIGN(AsyncSchedulingNet); + AT_DISABLE_COPY_AND_ASSIGN(AsyncSchedulingNet); }; } // namespace caffe2 diff --git a/caffe2/core/net_dag.h b/caffe2/core/net_dag.h index d941f73e8f0de..5a9996e08819c 100644 --- a/caffe2/core/net_dag.h +++ b/caffe2/core/net_dag.h @@ -84,7 +84,7 @@ class DAGNetBase : public NetBase { mutable std::vector stats_; std::unordered_map> task_timers_; - DISABLE_COPY_AND_ASSIGN(DAGNetBase); + AT_DISABLE_COPY_AND_ASSIGN(DAGNetBase); }; class DAGNet : public DAGNetBase { diff --git a/caffe2/core/net_simple.h b/caffe2/core/net_simple.h index e741a39638825..99060ddb0bcaf 100644 --- a/caffe2/core/net_simple.h +++ b/caffe2/core/net_simple.h @@ -48,7 +48,7 @@ class SimpleNet : public NetBase { vector> operators_; - DISABLE_COPY_AND_ASSIGN(SimpleNet); + AT_DISABLE_COPY_AND_ASSIGN(SimpleNet); }; } // namespace caffe2 diff --git a/caffe2/core/net_simple_async.h b/caffe2/core/net_simple_async.h index cf2a3d4c2a469..b29ae217cdaeb 100644 --- a/caffe2/core/net_simple_async.h +++ b/caffe2/core/net_simple_async.h @@ -43,7 +43,7 @@ class AsyncSimpleNet : public NetBase { vector> operators_; - DISABLE_COPY_AND_ASSIGN(AsyncSimpleNet); + AT_DISABLE_COPY_AND_ASSIGN(AsyncSimpleNet); }; } // namespace caffe2 diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index e6ac302e47fb0..048207b64d75d 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -408,7 +408,7 @@ class OperatorBase : public Observable { // An event used by asynchronous execution. std::unique_ptr event_; - DISABLE_COPY_AND_ASSIGN(OperatorBase); + AT_DISABLE_COPY_AND_ASSIGN(OperatorBase); }; // If your operator does not need any specialized contructor or destructor, diff --git a/caffe2/core/registry.h b/caffe2/core/registry.h index 0c8cdb852f188..f5e0932228a97 100644 --- a/caffe2/core/registry.h +++ b/caffe2/core/registry.h @@ -108,7 +108,7 @@ class Registry { CaffeMap help_message_; std::mutex register_mutex_; - DISABLE_COPY_AND_ASSIGN(Registry); + AT_DISABLE_COPY_AND_ASSIGN(Registry); }; template diff --git a/caffe2/core/timer.h b/caffe2/core/timer.h index 150aabe185ba2..a290ffc4aadc1 100644 --- a/caffe2/core/timer.h +++ b/caffe2/core/timer.h @@ -41,7 +41,7 @@ class Timer { protected: std::chrono::time_point start_time_; - DISABLE_COPY_AND_ASSIGN(Timer); + AT_DISABLE_COPY_AND_ASSIGN(Timer); }; } diff --git a/caffe2/core/workspace.h b/caffe2/core/workspace.h index 4a759b8703dc4..5f04309855fdf 100644 --- a/caffe2/core/workspace.h +++ b/caffe2/core/workspace.h @@ -297,7 +297,7 @@ class Workspace { std::unique_ptr thread_pool_; std::mutex thread_pool_creation_mutex_; - DISABLE_COPY_AND_ASSIGN(Workspace); + AT_DISABLE_COPY_AND_ASSIGN(Workspace); }; } // namespace caffe2 diff --git a/caffe2/db/create_db_op.h b/caffe2/db/create_db_op.h index c2d6060b4f03f..ac7c137cea9aa 100644 --- a/caffe2/db/create_db_op.h +++ b/caffe2/db/create_db_op.h @@ -34,7 +34,7 @@ class CreateDBOp final : public Operator { string db_name_; uint32_t num_shards_; uint32_t shard_id_; - DISABLE_COPY_AND_ASSIGN(CreateDBOp); + AT_DISABLE_COPY_AND_ASSIGN(CreateDBOp); }; } // namespace caffe2 diff --git a/caffe2/db/leveldb.cc b/caffe2/db/leveldb.cc index 6c5eff44fa925..23a188027ece7 100644 --- a/caffe2/db/leveldb.cc +++ b/caffe2/db/leveldb.cc @@ -51,7 +51,7 @@ class LevelDBTransaction : public Transaction { leveldb::DB* db_; std::unique_ptr batch_; - DISABLE_COPY_AND_ASSIGN(LevelDBTransaction); + AT_DISABLE_COPY_AND_ASSIGN(LevelDBTransaction); }; class LevelDB : public DB { diff --git a/caffe2/db/lmdb.cc b/caffe2/db/lmdb.cc index 0af3af0834dc7..2eb65bb7aa738 100644 --- a/caffe2/db/lmdb.cc +++ b/caffe2/db/lmdb.cc @@ -114,7 +114,7 @@ class LMDBTransaction final : public Transaction { MDB_dbi mdb_dbi_; MDB_txn* mdb_txn_; - DISABLE_COPY_AND_ASSIGN(LMDBTransaction); + AT_DISABLE_COPY_AND_ASSIGN(LMDBTransaction); }; class LMDB : public DB { diff --git a/caffe2/db/protodb.cc b/caffe2/db/protodb.cc index 64d5e952f2e4d..2473ad23b6c45 100644 --- a/caffe2/db/protodb.cc +++ b/caffe2/db/protodb.cc @@ -60,7 +60,7 @@ class ProtoDBTransaction : public Transaction { TensorProtos* proto_; std::unordered_set existing_names_; - DISABLE_COPY_AND_ASSIGN(ProtoDBTransaction); + AT_DISABLE_COPY_AND_ASSIGN(ProtoDBTransaction); }; class ProtoDB : public DB { diff --git a/caffe2/mkl/utils/mkl_memory.h b/caffe2/mkl/utils/mkl_memory.h index ac74e8ae070b5..9d9e91a565eb0 100644 --- a/caffe2/mkl/utils/mkl_memory.h +++ b/caffe2/mkl/utils/mkl_memory.h @@ -58,7 +58,7 @@ class PrimitiveWrapper { private: dnnPrimitive_t primitive_ = 0; - DISABLE_COPY_AND_ASSIGN(PrimitiveWrapper); + AT_DISABLE_COPY_AND_ASSIGN(PrimitiveWrapper); }; template @@ -138,7 +138,7 @@ class LayoutWrapper { private: dnnLayout_t layout_ = 0; - DISABLE_COPY_AND_ASSIGN(LayoutWrapper); + AT_DISABLE_COPY_AND_ASSIGN(LayoutWrapper); }; /** @@ -557,7 +557,7 @@ class MKLMemory { // The primitive to use to convert from internal layout to user layout PrimitiveWrapper convert_out_; - DISABLE_COPY_AND_ASSIGN(MKLMemory); + AT_DISABLE_COPY_AND_ASSIGN(MKLMemory); }; template @@ -575,7 +575,7 @@ class MKLWorkspace { private: void* buffer_; - DISABLE_COPY_AND_ASSIGN(MKLWorkspace); + AT_DISABLE_COPY_AND_ASSIGN(MKLWorkspace); }; } // namespace mkl diff --git a/caffe2/mobile/contrib/arm-compute/core/net_gl.h b/caffe2/mobile/contrib/arm-compute/core/net_gl.h index dc8643b8e0191..3b83d3120a56c 100644 --- a/caffe2/mobile/contrib/arm-compute/core/net_gl.h +++ b/caffe2/mobile/contrib/arm-compute/core/net_gl.h @@ -57,7 +57,7 @@ class GLNet : public NetBase { vector> operators_; - DISABLE_COPY_AND_ASSIGN(GLNet); + AT_DISABLE_COPY_AND_ASSIGN(GLNet); }; } // namespace caffe2 diff --git a/caffe2/operators/expand_squeeze_dims_op.h b/caffe2/operators/expand_squeeze_dims_op.h index ef025edd3ddec..69b3307e0cf78 100644 --- a/caffe2/operators/expand_squeeze_dims_op.h +++ b/caffe2/operators/expand_squeeze_dims_op.h @@ -112,7 +112,7 @@ class SqueezeOp : public Operator { vector dims_; public: - DISABLE_COPY_AND_ASSIGN(SqueezeOp); + AT_DISABLE_COPY_AND_ASSIGN(SqueezeOp); }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_ diff --git a/caffe2/operators/partition_ops.h b/caffe2/operators/partition_ops.h index 003653cbc8976..35cf83811fecc 100644 --- a/caffe2/operators/partition_ops.h +++ b/caffe2/operators/partition_ops.h @@ -221,7 +221,7 @@ class PartitionOp : public PartitionOpBase { return true; } - DISABLE_COPY_AND_ASSIGN(PartitionOp); + AT_DISABLE_COPY_AND_ASSIGN(PartitionOp); }; class LengthsPartitionOp : public PartitionOpBase { @@ -287,7 +287,7 @@ class LengthsPartitionOp : public PartitionOpBase { return true; } - DISABLE_COPY_AND_ASSIGN(LengthsPartitionOp); + AT_DISABLE_COPY_AND_ASSIGN(LengthsPartitionOp); vector out_length_; }; diff --git a/caffe2/operators/slice_op.h b/caffe2/operators/slice_op.h index ee591e4b0157c..f6f15d10bc1ec 100644 --- a/caffe2/operators/slice_op.h +++ b/caffe2/operators/slice_op.h @@ -245,7 +245,7 @@ class SliceOp : public Operator { output, data, starts_host_, ends_host_, &context_); } - DISABLE_COPY_AND_ASSIGN(SliceOp); + AT_DISABLE_COPY_AND_ASSIGN(SliceOp); private: std::vector starts_; @@ -304,7 +304,7 @@ class SliceGradientOp : public Operator { } } - DISABLE_COPY_AND_ASSIGN(SliceGradientOp); + AT_DISABLE_COPY_AND_ASSIGN(SliceGradientOp); private: std::vector starts_; diff --git a/caffe2/queue/blobs_queue_db.cc b/caffe2/queue/blobs_queue_db.cc index ef06be9f3fd14..06a6985848ce2 100644 --- a/caffe2/queue/blobs_queue_db.cc +++ b/caffe2/queue/blobs_queue_db.cc @@ -32,7 +32,7 @@ class CreateBlobsQueueDBOp : public Operator { } private: - DISABLE_COPY_AND_ASSIGN(CreateBlobsQueueDBOp); + AT_DISABLE_COPY_AND_ASSIGN(CreateBlobsQueueDBOp); }; REGISTER_CPU_OPERATOR(CreateBlobsQueueDB, CreateBlobsQueueDBOp); diff --git a/caffe2/utils/threadpool/WorkersPool.h b/caffe2/utils/threadpool/WorkersPool.h index 0c621d53854de..27b75d8ccd3a6 100644 --- a/caffe2/utils/threadpool/WorkersPool.h +++ b/caffe2/utils/threadpool/WorkersPool.h @@ -360,7 +360,7 @@ class WorkersPool { counter_to_decrement_when_ready_.Wait(); } - DISABLE_COPY_AND_ASSIGN(WorkersPool); + AT_DISABLE_COPY_AND_ASSIGN(WorkersPool); std::vector>> workers_; // The BlockingCounter used to wait for the workers. BlockingCounter counter_to_decrement_when_ready_; diff --git a/caffe2/utils/zmq_helper.h b/caffe2/utils/zmq_helper.h index be03d98b30364..cfd1d53a98af6 100644 --- a/caffe2/utils/zmq_helper.h +++ b/caffe2/utils/zmq_helper.h @@ -26,7 +26,7 @@ class ZmqContext { private: void* ptr_; - DISABLE_COPY_AND_ASSIGN(ZmqContext); + AT_DISABLE_COPY_AND_ASSIGN(ZmqContext); }; class ZmqMessage { @@ -48,7 +48,7 @@ class ZmqMessage { private: zmq_msg_t msg_; - DISABLE_COPY_AND_ASSIGN(ZmqMessage); + AT_DISABLE_COPY_AND_ASSIGN(ZmqMessage); }; class ZmqSocket { From fe68879832f1198c0e2edb24d6b7b415ab9c87ae Mon Sep 17 00:00:00 2001 From: Roy Li Date: Tue, 7 Aug 2018 09:51:54 -0700 Subject: [PATCH 4/9] Fix dir(torch) for python 3.7 (#10271) Summary: fixes #10160. Pull Request resolved: https://github.com/pytorch/pytorch/pull/10271 Differential Revision: D9188031 Pulled By: li-roy fbshipit-source-id: a3620553a8ba2b7391acdf78dbe58afcdb6c5f7f --- test/test_torch.py | 3 +++ torch/__init__.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/test/test_torch.py b/test/test_torch.py index 8bcd30a7e36a9..e494981abf031 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -143,6 +143,9 @@ def make_contiguous_slice(size, dtype): return tensors + def test_dir(self): + dir(torch) + def test_dot(self): types = { 'torch.DoubleTensor': 1e-8, diff --git a/torch/__init__.py b/torch/__init__.py index 043ca118e7301..e494cdec6cbec 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -251,6 +251,8 @@ def manager_path(): del manager_path for name in dir(_C._VariableFunctions): + if name in ["__dir__", "__doc__"]: + continue globals()[name] = getattr(_C._VariableFunctions, name) ################################################################################ From 20a549b1018a05ce6832702b105908ba7bab742d Mon Sep 17 00:00:00 2001 From: Jorghi12 Date: Tue, 7 Aug 2018 11:06:00 -0700 Subject: [PATCH 5/9] Start using a newer version of rocRand that's PyTorch compatible. Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10280 Differential Revision: D9196349 Pulled By: Jorghi12 fbshipit-source-id: 4147f2e6e3fdd641b026f3761d684437591405be --- docker/caffe2/jenkins/common/install_rocm.sh | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/docker/caffe2/jenkins/common/install_rocm.sh b/docker/caffe2/jenkins/common/install_rocm.sh index 9b69a917c3c2c..f76cf90f92657 100644 --- a/docker/caffe2/jenkins/common/install_rocm.sh +++ b/docker/caffe2/jenkins/common/install_rocm.sh @@ -21,8 +21,6 @@ install_ubuntu() { miopengemm \ rocblas \ hipblas \ - rocrand \ - hcsparse \ rocm-profiler \ cxlactivitylogger @@ -65,6 +63,20 @@ install_hcrng() { dpkg -i /opt/rocm/debians/hcrng.deb } +# This will be removed after merging an upcoming PR. +install_hcsparse() { + mkdir -p /opt/rocm/debians + curl https://s3.amazonaws.com/ossci-linux/hcsparse-master-907a505-Linux.deb -o /opt/rocm/debians/hcsparse.deb + dpkg -i /opt/rocm/debians/hcsparse.deb +} + +# Install an updated version of rocRand that's PyTorch compatible. +install_rocrand() { + mkdir -p /opt/rocm/debians + curl https://s3.amazonaws.com/ossci-linux/rocrand-1.8.0-Linux.deb -o /opt/rocm/debians/rocrand.deb + dpkg -i /opt/rocm/debians/rocrand.deb +} + # Install Python packages depending on the base OS if [ -f /etc/lsb-release ]; then install_ubuntu @@ -77,3 +89,5 @@ fi install_hip_thrust install_hcrng +install_rocrand +install_hcsparse From dcaafdd04b4c728d874402fa909d5dfdd47b7f0c Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Tue, 7 Aug 2018 11:15:02 -0700 Subject: [PATCH 6/9] fix doc of sparse_coo_tensor (#10308) Summary: - fixes #9998 Pull Request resolved: https://github.com/pytorch/pytorch/pull/10308 Differential Revision: D9196423 Pulled By: weiyangfb fbshipit-source-id: 23b4ed96e354ac9aa7c268aad105818a2c6d3bd8 --- torch/_torch_docs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 91cfaff303c0b..9b2031544651b 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -4136,8 +4136,10 @@ def parse_kwargs(desc): r""" sparse_coo_tensor(indices, values, size=None, dtype=None, device=None, requires_grad=False) -> Tensor -Constructs a sparse_coo_tensor with non-zero elements at the given :attr:`indices` with the given -:attr:`values`. +Constructs a sparse tensors in COO(rdinate) format with non-zero elements at the given :attr:`indices` +with the given :attr:`values`. A sparse tensor can be `uncoalesced`, in that case, there are duplicate +coordinates in the indices, and the value at that index is the sum of all duplicate value entries: +`torch.spaerse`_. Args: indices (array_like): Initial data for the tensor. Can be a list, tuple, @@ -4192,6 +4194,8 @@ def parse_kwargs(desc): tensor([], dtype=torch.int64) and values: tensor([]) + +.. _torch.sparse: https://pytorch.org/docs/stable/sparse.html """) add_docstr(torch.sqrt, From db7a2b1f0dade4d70e381cfd6a6f9887fd7ce067 Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Tue, 7 Aug 2018 11:19:54 -0700 Subject: [PATCH 7/9] fix doc for as_tensor (#10309) Summary: - fixes #9914 Pull Request resolved: https://github.com/pytorch/pytorch/pull/10309 Differential Revision: D9196427 Pulled By: weiyangfb fbshipit-source-id: c9a01e42c2e9dbfe2bd94ad14651d9f578751de2 --- torch/_torch_docs.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 9b2031544651b..6243d56da70de 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -425,18 +425,21 @@ def parse_kwargs(desc): Example:: - >>> torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]) - tensor([[ 0.1000, 1.2000], - [ 2.2000, 3.1000], - [ 4.9000, 5.2000]]) - >>> a = numpy.array([1, 2, 3]) - >>> t = torch.from_numpy(a) + >>> t = torch.as_tensor(a) >>> t tensor([ 1, 2, 3]) >>> t[0] = -1 >>> a array([-1, 2, 3]) + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a, device=torch.device('cuda')) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([1, 2, 3]) """.format(**factory_data_common_args)) add_docstr(torch.asin, From 2993c42ee44e9a592c03c45402bb04ff7d34fc2a Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Tue, 7 Aug 2018 12:14:51 -0700 Subject: [PATCH 8/9] Squash some 'invalid escape sequence' warnings. (#10310) Summary: Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/10310 Differential Revision: D9196254 Pulled By: ezyang fbshipit-source-id: 63bb8e52ac6970fe8e11a2d3c491ab58250dc467 --- aten/src/ATen/code_template.py | 4 ++-- aten/src/ATen/preprocess_declarations.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/code_template.py b/aten/src/ATen/code_template.py index 1cebf11839e7c..937269a50fed8 100644 --- a/aten/src/ATen/code_template.py +++ b/aten/src/ATen/code_template.py @@ -11,13 +11,13 @@ class CodeTemplate(object): - substitution_str = '(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})' + substitution_str = r'(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})' # older versions of Python have a bug where \w* does not work, # so we need to replace with the non-shortened version [a-zA-Z0-9_]* # https://bugs.python.org/issue18647 - substitution_str = substitution_str.replace('\w', '[a-zA-Z0-9_]') + substitution_str = substitution_str.replace(r'\w', r'[a-zA-Z0-9_]') subtitution = re.compile(substitution_str, re.MULTILINE) diff --git a/aten/src/ATen/preprocess_declarations.py b/aten/src/ATen/preprocess_declarations.py index 1bc33e533531d..173ac439487d2 100644 --- a/aten/src/ATen/preprocess_declarations.py +++ b/aten/src/ATen/preprocess_declarations.py @@ -124,7 +124,7 @@ def should_generate_out_variant(option): def sanitize_return(option): ret = option['return'] - m = re.match('argument (\d+(,\d+)*)', ret) + m = re.match(r'argument (\d+(,\d+)*)', ret) if m is not None: arguments = [int(x) for x in m.group(1).split(',')] option['return'] = {'kind': 'arguments', 'arguments': arguments} From 9b1a65bec3bbbf9c399fcc9fb9e20d6b40d50f95 Mon Sep 17 00:00:00 2001 From: mruberry Date: Tue, 7 Aug 2018 12:18:15 -0700 Subject: [PATCH 9/9] Extends type and shape tracing with device (#9796) Summary: This PR extends the existing type and shape metadata tracing and verification done in autograd with device information. This expansion of tracing is required for #8354, is likely useful in other scenarios, and is a healthy sanity check, just like type and shape tracing. The precise changes are: - TypeAndShape -> InputMetadata, now includes device() - Creating InputMetadata is simplified to just require a tensor, and callers were updated to use this simpler invocation wherever possible - The gradient accumulator of a variable is now reset when set_data() is called if either the type or device changes, and this reset now locks to avoid contention with acquiring the gradient accumulator - Mismatched devices during backward() will throw a runtime error, just like mismatched type and shape - (Bonus!) Two uninitialized pointers in THCReduce are now initialized (to nullptr) to prevent build warnings fyi colesbury Pull Request resolved: https://github.com/pytorch/pytorch/pull/9796 Reviewed By: goldsborough Differential Revision: D9119325 Pulled By: ezyang fbshipit-source-id: 76d1861b8d4f74db0575ff1f3bd965e18f9463de --- aten/src/THC/THCReduce.cuh | 6 +-- tools/autograd/templates/VariableType.cpp | 4 +- torch/csrc/autograd/engine.cpp | 7 +++ torch/csrc/autograd/function.h | 21 ++++++--- .../autograd/functions/accumulate_grad.cpp | 2 +- torch/csrc/autograd/functions/tensor.cpp | 2 +- torch/csrc/autograd/functions/utils.h | 2 +- torch/csrc/autograd/input_metadata.h | 44 +++++++++++++++++++ torch/csrc/autograd/python_function.cpp | 2 +- .../csrc/autograd/python_legacy_variable.cpp | 2 +- torch/csrc/autograd/type_and_shape.h | 33 -------------- torch/csrc/autograd/variable.cpp | 26 ++++++++--- 12 files changed, 95 insertions(+), 56 deletions(-) create mode 100644 torch/csrc/autograd/input_metadata.h diff --git a/aten/src/THC/THCReduce.cuh b/aten/src/THC/THCReduce.cuh index 1a72ae6ad5674..2ca972144505b 100644 --- a/aten/src/THC/THCReduce.cuh +++ b/aten/src/THC/THCReduce.cuh @@ -517,9 +517,9 @@ bool THC_reduceDim(THCState* state, (TYPE) outElements, init, modifyOp, reduceOp, finalizeOp); \ } \ else \ - { \ - void* stagingData; \ - void* semaphores; \ + { \ + void* stagingData = nullptr; \ + void* semaphores = nullptr; \ \ if(grid.y > 1) \ { \ diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index cd6bf900d1942..fd8a960b6f2fa 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -343,7 +343,7 @@ static void throw_error_out_requires_grad(const char* name) { static void rebase_history(Variable& var, std::shared_ptr grad_fn) { if (grad_fn && var.defined()) { - grad_fn->add_input_metadata(var.type(), var.sizes()); + grad_fn->add_input_metadata(var); var.rebase_history({std::move(grad_fn), 0}); } } @@ -353,7 +353,7 @@ static void rebase_history(ArrayRef vars, std::shared_ptr gr for (auto& var : vars) { if (var.defined()) { // TODO: eliminate const_cast - auto output_nr = grad_fn->add_input_metadata(var.type(), var.sizes()); + auto output_nr = grad_fn->add_input_metadata(var); const_cast(var).rebase_history({grad_fn, output_nr}); } else { grad_fn->add_input_metadata(Function::undefined_input()); diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 74e15f5caefe9..cb024a029620e 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -338,6 +338,13 @@ static void validate_outputs(const edge_list& edges, variable_list& grads, const ss << metadata.type() << " but got " << grads[i].type(); throw std::runtime_error(format_error(ss.str())); } + const auto output_device = output.is_cuda() ? output.get_device() : -1; + if (output_device != metadata.device()) { + std::stringstream ss; + ss << "invalid gradient at index " << i << " - expected device "; + ss << metadata.device() << " but got " << output_device; + throw std::runtime_error(format_error(ss.str())); + } } } diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index bc8ffc0e8357d..b4c90b1489a26 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -5,7 +5,7 @@ #include "torch/csrc/autograd/anomaly_mode.h" #include "torch/csrc/autograd/profiler.h" #include "torch/csrc/autograd/saved_variable.h" -#include "torch/csrc/autograd/type_and_shape.h" +#include "torch/csrc/autograd/input_metadata.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/utils/python_stub.h" #include "torch/csrc/utils/variadic.h" @@ -128,9 +128,18 @@ struct TORCH_API Function : std::enable_shared_from_this { /// Adds the type and shape metadata for a new input. Returns the index of /// of the new input. - uint32_t add_input_metadata(const at::Type& type, at::IntList shape) noexcept { + uint32_t add_input_metadata( + const at::Type& type + , at::IntList shape + , const int64_t device) noexcept { uint32_t input_nr = input_metadata_.size(); - input_metadata_.emplace_back(type, shape); + input_metadata_.emplace_back(type, shape, device); + return input_nr; + } + + uint32_t add_input_metadata(const at::Tensor& t) noexcept { + uint32_t input_nr = input_metadata_.size(); + input_metadata_.emplace_back(t); return input_nr; } @@ -145,7 +154,7 @@ struct TORCH_API Function : std::enable_shared_from_this { return input_metadata_.size(); } - const TypeAndShape& input_metadata(size_t index) const { + const InputMetadata& input_metadata(size_t index) const { return input_metadata_[index]; } @@ -322,7 +331,7 @@ struct TORCH_API Function : std::enable_shared_from_this { std::unique_ptr anomaly_metadata_ = nullptr; std::vector> pre_hooks_; std::vector> post_hooks_; - at::SmallVector input_metadata_; + at::SmallVector input_metadata_; }; /// See Function::is_traceable() for definition. @@ -367,7 +376,7 @@ inline void create_gradient_edge( Variable& variable, std::shared_ptr function) { // Copy before move. - const auto input_nr = function->add_input_metadata(variable.type(), variable.sizes()); + const auto input_nr = function->add_input_metadata(variable); variable.set_gradient_edge({std::move(function), input_nr}); } diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp index 391cf3697decf..fd24f6987642b 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.cpp +++ b/torch/csrc/autograd/functions/accumulate_grad.cpp @@ -19,7 +19,7 @@ namespace torch { namespace autograd { AccumulateGrad::AccumulateGrad(Variable variable_) : Function(/*sequence_nr=*/UINT64_MAX) , variable(std::move(variable_)) { - add_input_metadata(variable.type(), variable.sizes()); + add_input_metadata(variable); } auto AccumulateGrad::apply(variable_list&& grads) -> variable_list { diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp index e0302e11eff5f..d5a94d49985bc 100644 --- a/torch/csrc/autograd/functions/tensor.cpp +++ b/torch/csrc/autograd/functions/tensor.cpp @@ -43,7 +43,7 @@ CopySlices::CopySlices( fn(std::move(fn_)) { // Take the next_edges of fn as our own, except for index 0 which goes // to base instead of the view. - add_input_metadata(base_var.type(), base_var.sizes()); + add_input_metadata(base_var); const auto num_outputs = fn->num_outputs(); next_edges_.reserve(num_outputs); add_next_edge(base_var.gradient_edge()); diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h index bad48b221eaf0..9f9269c8874e0 100644 --- a/torch/csrc/autograd/functions/utils.h +++ b/torch/csrc/autograd/functions/utils.h @@ -54,7 +54,7 @@ inline void set_history( if (grad_fn) { if (variable.defined()) { auto output_nr = - grad_fn->add_input_metadata(variable.type(), variable.sizes()); + grad_fn->add_input_metadata(variable); as_variable_ref(variable).set_gradient_edge({grad_fn, output_nr}); } else { grad_fn->add_input_metadata(Function::undefined_input()); diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h new file mode 100644 index 0000000000000..e421441f872cf --- /dev/null +++ b/torch/csrc/autograd/input_metadata.h @@ -0,0 +1,44 @@ +#pragma once + +#include + +#include + +namespace torch { namespace autograd { + +/// A tensor's type and shape. Each Function records the required type and +/// shape of its inputs. If is_valid() is false, then the corresponding input +/// is not used and may be an undefined tensor. +struct InputMetadata { + InputMetadata() = default; + + InputMetadata(const at::Type& type, at::IntList shape, const int64_t device) + : type_{&type} , shape_{shape}, device_{device} { } + + InputMetadata(const at::Tensor& t) + : InputMetadata(t.type(), t.sizes(), t.is_cuda() ? t.get_device() : - 1) { } + + bool is_valid() const { + return type_ != nullptr; + } + + const at::Type& type() const { + AT_ASSERT(type_); + return *type_; + } + + at::IntList shape() const { + return shape_; + } + + int64_t device() const { + return device_; + } + +private: + const at::Type* type_ = nullptr; + at::DimVector shape_; + const int64_t device_ = -1; +}; + +}} diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index e9d29bd0caa68..a1dca1e2eed9d 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -433,7 +433,7 @@ static void _wrap_outputs(THPFunction *self, // to set_history wins. auto var = as_variable(obj, i); if (cdata) { - auto output_nr = cdata->add_input_metadata(var.type(), var.sizes()); + auto output_nr = cdata->add_input_metadata(var); AT_ASSERT(i == (int)output_nr); } set_history(var, i, is_input, is_modified, is_differentiable); diff --git a/torch/csrc/autograd/python_legacy_variable.cpp b/torch/csrc/autograd/python_legacy_variable.cpp index 56eb1285028af..339e58cde4e56 100644 --- a/torch/csrc/autograd/python_legacy_variable.cpp +++ b/torch/csrc/autograd/python_legacy_variable.cpp @@ -57,7 +57,7 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject Variable var; if (grad_fn) { auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn); - Edge edge(grad_fn_, grad_fn_->add_input_metadata(tensor.type(), tensor.sizes())); + Edge edge(grad_fn_, grad_fn_->add_input_metadata(tensor)); var = make_variable(std::move(tensor), std::move(edge)); } else { var = make_variable(std::move(tensor), requires_grad); diff --git a/torch/csrc/autograd/type_and_shape.h b/torch/csrc/autograd/type_and_shape.h index 97da65ec902f5..e69de29bb2d1d 100644 --- a/torch/csrc/autograd/type_and_shape.h +++ b/torch/csrc/autograd/type_and_shape.h @@ -1,33 +0,0 @@ -#pragma once - -#include - -namespace torch { namespace autograd { - -/// A tensor's type and shape. Each Function records the required type and -/// shape of its inputs. If is_valid() is false, then the corresponding input -/// is not used and may be an undefined tensor. -struct TypeAndShape { - TypeAndShape() : type_(nullptr) {} - - TypeAndShape(const at::Type& type, at::IntList shape) - : type_(&type) , shape_(shape) {} - - bool is_valid() const { - return type_ != nullptr; - } - - const at::Type& type() const { - AT_ASSERT(type_); - return *type_; - } - - at::IntList shape() const { - return shape_; - } - - const at::Type* type_; - at::DimVector shape_; -}; - -}} diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index e7f13d10212ca..f8c88c7ddcdde 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -117,13 +117,22 @@ void Variable::Impl::backward( } void Variable::Impl::set_data(Tensor new_data) { - if (new_data.type() != data_.type()) { - scalar_type_ = new_data.type().scalarType(); - backend_ = new_data.type().backend(); - is_variable_ = true; - // Clear grad_accumulator if it exists, since it stores the old type info. - grad_accumulator_.reset(); + // Resets gradient accumulator if metadata is out of date + std::lock_guard lock(mutex_); + auto prior_accumulator = grad_accumulator_.lock(); + if (prior_accumulator) { + const auto prior_device = prior_accumulator->input_metadata(0).device(); + const auto new_device = new_data.is_cuda() ? new_data.get_device() : -1; + + if (new_data.type() != data_.type() || prior_device != new_device) { + grad_accumulator_.reset(); + } } + + // Updates metadata + scalar_type_ = new_data.type().scalarType(); + backend_ = new_data.type().backend(); + is_variable_ = true; data_ = std::move(new_data); } @@ -160,7 +169,10 @@ std::shared_ptr& Variable::ViewImpl::get_grad_fn() { fn->stride = strides().vec(); fn->storage_offset = data_.storage_offset(); fn->set_next_edges(collect_next_edges(base_)); - fn->add_input_metadata(base_.type(), sizes()); + fn->add_input_metadata( + base_.type() + , sizes() // Note: sizes(), not base_.sizes(), is intentional + , base_.is_cuda() ? base_.get_device() : -1); grad_fn_ = std::move(fn); attr_version = current_version; }