diff --git a/aten/src/ATen/THLongStorageView.h b/aten/src/ATen/THLongStorageView.h index 27654e09a189f4..a958a5e0193115 100644 --- a/aten/src/ATen/THLongStorageView.h +++ b/aten/src/ATen/THLongStorageView.h @@ -67,7 +67,7 @@ class THLongStorageView { storage.scalar_type = at::CTypeToScalarType>::to(); storage.refcount = 0; storage.flag = 0; - storage.allocator = nullptr; + storage.allocatorVoidPtr = nullptr; storage.allocatorContext = nullptr; } private: diff --git a/aten/src/ATen/cuda/PinnedMemoryAllocator.cpp b/aten/src/ATen/cuda/PinnedMemoryAllocator.cpp index 52aaa0df346a25..55e4ce0e77ed69 100644 --- a/aten/src/ATen/cuda/PinnedMemoryAllocator.cpp +++ b/aten/src/ATen/cuda/PinnedMemoryAllocator.cpp @@ -3,6 +3,7 @@ #include #include +#include #include diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index ba84579df6d6f3..092681d0898228 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -10,6 +10,7 @@ #include #include "THC/THC.h" +#include #if AT_CUDNN_ENABLED() #include "ATen/cudnn/cudnn-wrapper.h" diff --git a/aten/src/ATen/templates/StorageDerived.cpp b/aten/src/ATen/templates/StorageDerived.cpp index ce9a6211d8ea61..014512f97e2d8a 100644 --- a/aten/src/ATen/templates/StorageDerived.cpp +++ b/aten/src/ATen/templates/StorageDerived.cpp @@ -13,7 +13,7 @@ namespace at { ${Storage}::${Storage}(Context* context): storage(${THStorage}_new(${state})), context(context) {} -${Storage}::${Storage}(Context* context, ${THStorage}* storage): +${Storage}::${Storage}(Context* context, THStorage* storage): storage(storage), context(context) {} ${Storage}::${Storage}(Context* context, size_t storage_size) diff --git a/aten/src/ATen/templates/StorageDerived.h b/aten/src/ATen/templates/StorageDerived.h index 1e50210e98d41b..d97d397c8e7cae 100644 --- a/aten/src/ATen/templates/StorageDerived.h +++ b/aten/src/ATen/templates/StorageDerived.h @@ -16,7 +16,7 @@ struct Allocator; struct ${Storage} final : public Storage { public: explicit ${Storage}(Context* context); - ${Storage}(Context* context, ${THStorage} *wrapped); + ${Storage}(Context* context, THStorage *wrapped); ${Storage}(Context* context, size_t size); ${Storage}(Context* context, size_t size, Allocator* allocator); ${Storage}(Context* context, @@ -50,7 +50,7 @@ struct ${Storage} final : public Storage { protected: friend struct ${Type}; - ${THStorage} *storage; + THStorage *storage; Context* context; }; diff --git a/aten/src/TH/THStorage.cpp b/aten/src/TH/THStorage.cpp index f51d0930be7f99..a73681b7384f28 100644 --- a/aten/src/TH/THStorage.cpp +++ b/aten/src/TH/THStorage.cpp @@ -1,3 +1,5 @@ +#include + #include "THStorage.hpp" #include "generic/THStorage.cpp" @@ -13,6 +15,8 @@ #include "THGenerateHalfType.h" void THStorage_free(THStorage *storage) { + AT_ASSERT(storage->backend == at::kCPU); + if(!storage) return; @@ -21,7 +25,7 @@ void THStorage_free(THStorage *storage) { if(--storage->refcount == 0) { if(storage->flag & TH_STORAGE_FREEMEM) { - storage->allocator->free(storage->allocatorContext, storage->data_ptr); + static_cast(storage->allocatorVoidPtr)->free(storage->allocatorContext, storage->data_ptr); } if(storage->flag & TH_STORAGE_VIEW) { THStorage_free(storage->view); @@ -65,3 +69,30 @@ THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElemen } return copy; } + +THStorage* THStorage_new(at::ScalarType scalar_type) +{ + return THStorage_newWithSize(scalar_type, 0); +} + +THStorage* THStorage_newWithSize(at::ScalarType scalar_type, ptrdiff_t size) +{ + return THStorage_newWithAllocator(scalar_type, size, &THDefaultAllocator, nullptr); +} + +THStorage* THStorage_newWithAllocator(at::ScalarType scalar_type, ptrdiff_t size, + THAllocator *allocator, + void *allocatorContext) +{ + THStorage *storage = static_cast(THAlloc(sizeof(THStorage))); + storage->backend = at::kCPU; + storage->scalar_type = scalar_type; + storage->data_ptr = allocator->malloc(allocatorContext, at::elementSize(scalar_type)*size); + storage->size = size; + new (&storage->refcount) std::atomic(1); + storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; + storage->allocatorVoidPtr = allocator; + storage->allocatorContext = allocatorContext; + storage->device = INT_MIN; // device is not meaningful on CPU + return storage; +} diff --git a/aten/src/TH/THStorage.hpp b/aten/src/TH/THStorage.hpp index 19e9a98ad9fb09..d5a68ec181a664 100644 --- a/aten/src/TH/THStorage.hpp +++ b/aten/src/TH/THStorage.hpp @@ -5,21 +5,23 @@ #include "THStorage.h" -#include "ATen/ScalarType.h" -#include "ATen/ScalarTypeUtils.h" +#include +#include #include "THTypeConversion.hpp" #include typedef struct THStorage { + at::Backend backend; // kCPU or kCUDA only at::ScalarType scalar_type; void *data_ptr; ptrdiff_t size; std::atomic refcount; char flag; - THAllocator *allocator; + void *allocatorVoidPtr; // Either THDeviceAllocator or THCDeviceAllocator void *allocatorContext; struct THStorage *view; + int device; template inline T * data() const { @@ -36,3 +38,9 @@ typedef struct THStorage return static_cast(this->data_ptr); } } THStorage; + +TH_API THStorage* THStorage_new(at::ScalarType scalar_type); +TH_API THStorage* THStorage_newWithSize(at::ScalarType scalar_type, ptrdiff_t size); +TH_API THStorage* THStorage_newWithAllocator(at::ScalarType scalar_type, ptrdiff_t size, + THAllocator *allocator, + void *allocatorContext); diff --git a/aten/src/TH/generic/THStorage.cpp b/aten/src/TH/generic/THStorage.cpp index cdd70f47f7e6e2..7b163cbbfe491a 100644 --- a/aten/src/TH/generic/THStorage.cpp +++ b/aten/src/TH/generic/THStorage.cpp @@ -21,29 +21,22 @@ size_t THStorage_(elementSize)() THStorage* THStorage_(new)(void) { - return THStorage_(newWithSize)(0); + return THStorage_new(at::CTypeToScalarType>::to()); } THStorage* THStorage_(newWithSize)(ptrdiff_t size) { - return THStorage_(newWithAllocator)(size, &THDefaultAllocator, NULL); + return THStorage_newWithSize(at::CTypeToScalarType>::to(), size); } THStorage* THStorage_(newWithAllocator)(ptrdiff_t size, THAllocator *allocator, void *allocatorContext) { - THStorage *storage = static_cast(THAlloc(sizeof(THStorage))); - storage->scalar_type = at::CTypeToScalarType>::to(); - storage->data_ptr = allocator->malloc(allocatorContext, sizeof(real)*size); - storage->size = size; - new (&storage->refcount) std::atomic(1); - storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; - storage->allocator = allocator; - storage->allocatorContext = allocatorContext; - return storage; + return THStorage_newWithAllocator(at::CTypeToScalarType>::to(), size, allocator, allocatorContext); } + THStorage* THStorage_(newWithMapping)(const char *filename, ptrdiff_t size, int flags) { THMapAllocatorContext *ctx = THMapAllocatorContext_new(filename, flags); @@ -142,28 +135,34 @@ THStorage* THStorage_(newWithDataAndAllocator)(real* data, ptrdiff_t size, THAllocator* allocator, void* allocatorContext) { THStorage *storage = static_cast(THAlloc(sizeof(THStorage))); + storage->backend = at::kCPU; storage->scalar_type = at::CTypeToScalarType>::to(); storage->data_ptr = data; storage->size = size; storage->refcount = 1; storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; - storage->allocator = allocator; + storage->allocatorVoidPtr = allocator; storage->allocatorContext = allocatorContext; + storage->device = 0; return storage; } void THStorage_(resize)(THStorage *storage, ptrdiff_t size) { + AT_ASSERT(storage->backend == at::kCPU); + + auto* th_allocator = static_cast(storage->allocatorVoidPtr); + if(storage->flag & TH_STORAGE_RESIZABLE) { - if(storage->allocator->realloc == NULL) { + if(th_allocator->realloc == NULL) { /* case when the allocator does not have a realloc defined */ real *old_data = THStorage_(data)(storage); ptrdiff_t old_size = storage->size; if (size == 0) { storage->data_ptr = NULL; } else { - storage->data_ptr = storage->allocator->malloc( + storage->data_ptr = th_allocator->malloc( storage->allocatorContext, sizeof(real)*size); } @@ -176,10 +175,10 @@ void THStorage_(resize)(THStorage *storage, ptrdiff_t size) if (copy_size > 0) { memcpy(THStorage_(data)(storage), old_data, sizeof(real)*copy_size); } - storage->allocator->free(storage->allocatorContext, old_data); + th_allocator->free(storage->allocatorContext, old_data); } } else { - storage->data_ptr = storage->allocator->realloc( + storage->data_ptr = th_allocator->realloc( storage->allocatorContext, THStorage_(data)(storage), sizeof(real)*size); @@ -215,17 +214,19 @@ void THStorage_(swap)(THStorage *storage1, THStorage *storage2) void *data_ptr; ptrdiff_t size; char flag; - THAllocator *allocator; + void *allocatorVoidPtr; void *allocatorContext; struct THStorage *view; + int device; SWAP(data_ptr); SWAP(size); SWAP(flag); // don't swap refcount! - SWAP(allocator); + SWAP(allocatorVoidPtr); SWAP(allocatorContext); SWAP(view); + SWAP(device); #undef SWAP } diff --git a/aten/src/THC/THCGeneral.cpp b/aten/src/THC/THCGeneral.cpp index 114b967f7d309f..6f4fdf070d2f6a 100644 --- a/aten/src/THC/THCGeneral.cpp +++ b/aten/src/THC/THCGeneral.cpp @@ -6,6 +6,7 @@ #include "THCStream.h" #include "THCThreadLocal.h" #include "THCTensorRandom.h" +#include "THCGeneral.hpp" #include #include diff --git a/aten/src/THC/THCGeneral.h.in b/aten/src/THC/THCGeneral.h.in index 1b4e115a1fab4e..7d9f5fcd952c72 100644 --- a/aten/src/THC/THCGeneral.h.in +++ b/aten/src/THC/THCGeneral.h.in @@ -47,6 +47,7 @@ struct THCRNGState; /* Random number generator state. */ typedef struct THCStream THCStream; typedef struct THCState THCState; +struct THCState; typedef struct _THCDeviceAllocator { cudaError_t (*malloc)( void*, void**, size_t, cudaStream_t); @@ -70,54 +71,6 @@ typedef struct _THCCudaResourcesPerDevice { size_t scratchSpacePerStream; } THCCudaResourcesPerDevice; - -/* Global state to be held in the cutorch table. */ -struct THCState { - struct THCRNGState* rngState; - struct cudaDeviceProp* deviceProperties; - /* Set of all allocated resources. blasHandles and sparseHandles do not have - a default and must be explicitly initialized. We always initialize 1 - blasHandle and 1 sparseHandle but we can use more. - */ - THCCudaResourcesPerDevice* resourcesPerDevice; - /* Captured number of devices upon startup; convenience for bounds checking */ - int numDevices; - int numUserBlasHandles; - int numUserSparseHandles; - - /* Allocator using cudaMallocHost. */ - THAllocator* cudaHostAllocator; - THAllocator* cudaUVAAllocator; - THCDeviceAllocator* cudaDeviceAllocator; - - /* Index of the current selected BLAS handle. The actual BLAS handle used - depends on the current device. */ - THCThreadLocal/**/ currentPerDeviceBlasHandle; - /* Index of the current selected sparse handle. The actual sparse handle used - depends on the current device. */ - THCThreadLocal/**/ currentPerDeviceSparseHandle; - /* Array of thread locals containing the current stream for each device */ - THCThreadLocal* currentStreams; - - /* Table of enabled peer-to-peer access between directed pairs of GPUs. - If i accessing allocs on j is enabled, p2pAccess[i][j] is 1; 0 otherwise. */ - int** p2pAccessEnabled; - - /* Is direct cross-kernel p2p access allowed? Normally, only cross-GPU - copies are allowed via p2p if p2p access is enabled at all for - the pair of GPUs in question, but if this flag is true, then - all cross-GPU access checks are disabled, allowing kernels to - directly access memory on another GPUs. - Note that p2p access must exist and be enabled for the pair of - GPUs in question. */ - int p2pKernelAccessEnabled; - - void (*cutorchGCFunction)(void *data); - void *cutorchGCData; - ptrdiff_t heapSoftmax; - ptrdiff_t heapDelta; -}; - THC_API THCState* THCState_alloc(void); THC_API void THCState_free(THCState* state); diff --git a/aten/src/THC/THCGeneral.hpp b/aten/src/THC/THCGeneral.hpp new file mode 100644 index 00000000000000..495e1cc338b776 --- /dev/null +++ b/aten/src/THC/THCGeneral.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include "THCGeneral.h" + +/* Global state of THC. */ +struct THCState { + struct THCRNGState* rngState; + struct cudaDeviceProp* deviceProperties; + /* Set of all allocated resources. blasHandles and sparseHandles do not have + a default and must be explicitly initialized. We always initialize 1 + blasHandle and 1 sparseHandle but we can use more. + */ + THCCudaResourcesPerDevice* resourcesPerDevice; + /* Captured number of devices upon startup; convenience for bounds checking */ + int numDevices; + int numUserBlasHandles; + int numUserSparseHandles; + + /* Allocator using cudaMallocHost. */ + THAllocator* cudaHostAllocator; + THAllocator* cudaUVAAllocator; + THCDeviceAllocator* cudaDeviceAllocator; + + /* Index of the current selected BLAS handle. The actual BLAS handle used + depends on the current device. */ + THCThreadLocal/**/ currentPerDeviceBlasHandle; + /* Index of the current selected sparse handle. The actual sparse handle used + depends on the current device. */ + THCThreadLocal/**/ currentPerDeviceSparseHandle; + /* Array of thread locals containing the current stream for each device */ + THCThreadLocal* currentStreams; + + /* Table of enabled peer-to-peer access between directed pairs of GPUs. + If i accessing allocs on j is enabled, p2pAccess[i][j] is 1; 0 otherwise. */ + int** p2pAccessEnabled; + + /* Is direct cross-kernel p2p access allowed? Normally, only cross-GPU + copies are allowed via p2p if p2p access is enabled at all for + the pair of GPUs in question, but if this flag is true, then + all cross-GPU access checks are disabled, allowing kernels to + directly access memory on another GPUs. + Note that p2p access must exist and be enabled for the pair of + GPUs in question. */ + int p2pKernelAccessEnabled; + + void (*cutorchGCFunction)(void *data); + void *cutorchGCData; + ptrdiff_t heapSoftmax; + ptrdiff_t heapDelta; +}; diff --git a/aten/src/THC/THCStorage.cpp b/aten/src/THC/THCStorage.cpp index 6547beb9d6ede8..93d58bd379bf88 100644 --- a/aten/src/THC/THCStorage.cpp +++ b/aten/src/THC/THCStorage.cpp @@ -34,9 +34,10 @@ THCStorage* THCStorage_newWithAllocator(THCState *state, THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage)); memset(storage, 0, sizeof(THCStorage)); new (&storage->refcount) std::atomic(1); + storage->backend = at::kCUDA; storage->scalar_type = scalar_type; storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; - storage->allocator = allocator; + storage->allocatorVoidPtr = allocator; storage->allocatorContext = allocatorContext; storage->size = size; storage->device = device; @@ -67,14 +68,17 @@ void THCStorage_retain(THCState *state, THCStorage *self) void THCStorage_free(THCState *state, THCStorage *self) { + AT_ASSERT(self->backend == at::kCUDA); + if(!(self->flag & TH_STORAGE_REFCOUNTED)) return; if (--self->refcount == 0) { if(self->flag & TH_STORAGE_FREEMEM) { + auto* thc_device_allocator = static_cast(self->allocatorVoidPtr); THCudaCheck( - (*self->allocator->free)(self->allocatorContext, self->data_ptr)); + (*thc_device_allocator->free)(self->allocatorContext, self->data_ptr)); } if(self->flag & TH_STORAGE_VIEW) { THCStorage_free(state, self->view); @@ -86,8 +90,10 @@ void THCStorage_free(THCState *state, THCStorage *self) void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size) { + AT_ASSERT(self->backend == at::kCUDA); + THArgCheck(size >= 0, 2, "invalid size"); - THAssert(self->allocator != NULL); + THAssert(self->allocatorVoidPtr != NULL); int device; THCudaCheck(cudaGetDevice(&device)); @@ -96,9 +102,11 @@ void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size) size_t elementSize = at::elementSize(self->scalar_type); - if (self->allocator->realloc) { + auto* thc_device_allocator = static_cast(self->allocatorVoidPtr); + + if (thc_device_allocator->realloc) { void * data_ptr = self->data_ptr; - cudaError_t err = (*self->allocator->realloc)( + cudaError_t err = (*thc_device_allocator->realloc)( self->allocatorContext, (void**)&(data_ptr), self->size * elementSize, @@ -115,7 +123,7 @@ void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size) { if(self->flag & TH_STORAGE_FREEMEM) { THCudaCheck( - (*self->allocator->free)(self->allocatorContext, self->data_ptr)); + (*thc_device_allocator->free)(self->allocatorContext, self->data_ptr)); } self->data_ptr = NULL; self->size = 0; @@ -125,7 +133,7 @@ void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size) { void *data = NULL; cudaError_t err = - (*self->allocator->malloc)(self->allocatorContext, + (*thc_device_allocator->malloc)(self->allocatorContext, (void**)&(data), size * elementSize, THCState_getCurrentStreamOnDevice(state, device)); @@ -142,7 +150,7 @@ void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size) THCState_getCurrentStream(state))); if(self->flag & TH_STORAGE_FREEMEM) { THCudaCheck( - (*self->allocator->free)(self->allocatorContext, self->data_ptr)); + (*thc_device_allocator->free)(self->allocatorContext, self->data_ptr)); } } @@ -154,4 +162,4 @@ void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size) int THCStorage_getDevice(THCState* state, const THCStorage* storage) { return storage->device; -} \ No newline at end of file +} diff --git a/aten/src/THC/THCStorage.hpp b/aten/src/THC/THCStorage.hpp index 67876386eda9af..f626d75c9daf16 100644 --- a/aten/src/THC/THCStorage.hpp +++ b/aten/src/THC/THCStorage.hpp @@ -4,6 +4,7 @@ // read Note [TH abstraction violation] #include "THCStorage.h" +#include #include "ATen/ScalarType.h" #include "ATen/ScalarTypeUtils.h" @@ -16,34 +17,6 @@ struct CTypeToScalarType<__half> : public CTypeToScalarType {}; } -typedef struct THCStorage -{ - at::ScalarType scalar_type; - void *data_ptr; - ptrdiff_t size; - std::atomic refcount; - char flag; - THCDeviceAllocator *allocator; - void *allocatorContext; - struct THCStorage *view; - int device; - - template - inline T * data() const { - auto scalar_type_T = at::CTypeToScalarType::to(); - if (scalar_type != scalar_type_T) { - AT_ERROR("Attempt to access Storage having data type ", at::toString(scalar_type), - " as data type ", at::toString(scalar_type_T)); - } - return unsafe_data(); - } - - template - inline T * unsafe_data() const { - return static_cast(this->data_ptr); - } -} THCStorage; - THC_API THCStorage* THCStorage_new(THCState *state, at::ScalarType scalar_type); THC_API THCStorage* THCStorage_newWithSize(THCState *state, at::ScalarType scalar_type, ptrdiff_t size); diff --git a/aten/src/THC/generic/THCStorage.cpp b/aten/src/THC/generic/THCStorage.cpp index acaf1914fd4067..6be2aa9b009468 100644 --- a/aten/src/THC/generic/THCStorage.cpp +++ b/aten/src/THC/generic/THCStorage.cpp @@ -108,12 +108,13 @@ THCStorage* THCStorage_(newWithDataAndAllocator)( THCDeviceAllocator *allocator, void *allocatorContext) { THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage)); memset(storage, 0, sizeof(THCStorage)); + storage->backend = at::kCUDA; storage->scalar_type = at::CTypeToScalarType::to(); storage->data_ptr = data; storage->size = size; storage->refcount = 1; storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; - storage->allocator = allocator; + storage->allocatorVoidPtr = allocator; storage->allocatorContext = allocatorContext; int device; if (data) { diff --git a/aten/src/THC/generic/THCStorage.h b/aten/src/THC/generic/THCStorage.h index 0d443e7fcdae78..7899b45d127901 100644 --- a/aten/src/THC/generic/THCStorage.h +++ b/aten/src/THC/generic/THCStorage.h @@ -6,7 +6,7 @@ #define TH_STORAGE_RESIZABLE 2 #define TH_STORAGE_FREEMEM 4 -typedef struct THCStorage THCStorage; +#define THCStorage THStorage // These used to be distinct types; for some measure of backwards compatibility and documentation // alias these to the single THCStorage type. diff --git a/caffe2/operators/reduction_ops.cc b/caffe2/operators/reduction_ops.cc index 306b85fd96665f..6f043eb4c5678a 100644 --- a/caffe2/operators/reduction_ops.cc +++ b/caffe2/operators/reduction_ops.cc @@ -296,7 +296,9 @@ bool SumElementsGradientOp::RunOnDevice() #endif { auto& X = Input(0); - const auto& sum_grad = Input(1); + // Copy Input(1) from Context to CPUContext + CPUContext context; + TensorCPU sum_grad(Input(1), &context); auto* dX = Output(0); dX->ResizeLike(X); DCHECK_EQ(sum_grad.size(), 1); diff --git a/caffe2/opt/backend_cutting.cc b/caffe2/opt/backend_cutting.cc index 058016b63cb00c..2e292f031251a5 100644 --- a/caffe2/opt/backend_cutting.cc +++ b/caffe2/opt/backend_cutting.cc @@ -43,9 +43,9 @@ std::string ShowNode(NodeRef node) { return MakeString("Tensor: ", nn_tensor->getName()); } else if (nn::is(node)) { const auto* nn_op = nn::get(node); - const auto* op_def = dyn_cast(nn_op->getAnnotation())->getOperatorDef(); - CAFFE_ENFORCE(op_def); - return MakeString("Op: ", op_def->type()); + const auto& op_def = + dyn_cast(nn_op->getAnnotation())->getOperatorDef(); + return MakeString("Op: ", op_def.type()); } else { CAFFE_THROW("Known node"); } @@ -106,8 +106,9 @@ void Explore( if (nn::is(node)) { const auto* nn_op = nn::get(node); - const auto* op_def = dyn_cast(nn_op->getAnnotation())->getOperatorDef(); - bool wanted = context->predicate(*op_def); + const auto& op_def = + dyn_cast(nn_op->getAnnotation())->getOperatorDef(); + bool wanted = context->predicate(op_def); wanted = context->find_supported ? wanted : (!wanted); if (!wanted) { context->frontier.emplace(node); @@ -190,8 +191,9 @@ caffe2::NetDef ConvertToC2Net( if (nn::is(node)) { const auto* nn_op = nn::get(node); assert(isa(nn_op->getAnnotation()) && "Cannot get caffe2 op from NNOp"); - const auto* op_def = dyn_cast(nn_op->getAnnotation())->getOperatorDef(); - net.add_op()->CopyFrom(*op_def); + const auto& op_def = + dyn_cast(nn_op->getAnnotation())->getOperatorDef(); + net.add_op()->CopyFrom(op_def); } } for (const auto kv : sub.external_input_refs) { @@ -282,7 +284,7 @@ void ReplaceSubgraph( g->createEdge(op_node, tensor_node); } - op_node->resetData(convertToNeuralNetOperator(&op)); + op_node->resetData(convertToNeuralNetOperator(op)); } } diff --git a/caffe2/opt/converter.cc b/caffe2/opt/converter.cc index 5f791013dfd8bc..40dddffad7e5c3 100644 --- a/caffe2/opt/converter.cc +++ b/caffe2/opt/converter.cc @@ -89,13 +89,13 @@ std::vector getDilations(std::map argMap) { namespace caffe2 { -std::unique_ptr -convertToNeuralNetOperator(caffe2::OperatorDef* op) { - auto argMap = getArgumentsFromOperator(*op); +std::unique_ptr convertToNeuralNetOperator( + const caffe2::OperatorDef& op) { + auto argMap = getArgumentsFromOperator(op); std::unique_ptr nnOp; - if (op->type() == "Conv") { + if (op.type() == "Conv") { auto kernelShape = getKernelShape(argMap); nnOp = util::make_unique(kernelShape); auto c = dyn_cast(nnOp.get()); @@ -106,29 +106,29 @@ convertToNeuralNetOperator(caffe2::OperatorDef* op) { } - if (op->type() == "Relu") { + if (op.type() == "Relu") { nnOp = util::make_unique(); } - if (op->type() == "AveragePool") { + if (op.type() == "AveragePool") { auto kernelShape = getKernelShape(argMap); nnOp = util::make_unique(kernelShape); } - if (op->type() == "MaxPool") { + if (op.type() == "MaxPool") { auto kernelShape = getKernelShape(argMap); nnOp = util::make_unique(kernelShape); } - if (op->type() == "Sum") { + if (op.type() == "Sum") { nnOp = util::make_unique(); } - if (op->type() == "SpatialBN") { + if (op.type() == "SpatialBN") { nnOp = util::make_unique(); } - if (op->type() == "Concat") { + if (op.type() == "Concat") { nnOp = util::make_unique(); auto c = dyn_cast(nnOp.get()); if (argMap.count("axis")) { @@ -143,15 +143,15 @@ convertToNeuralNetOperator(caffe2::OperatorDef* op) { } } - if (op->type() == "Flatten") { + if (op.type() == "Flatten") { nnOp = util::make_unique(); } - if (op->type() == "BatchGather") { + if (op.type() == "BatchGather") { nnOp = util::make_unique(); } - if (op->type() == "BatchMatMul") { + if (op.type() == "BatchMatMul") { nnOp = util::make_unique(); auto c = dyn_cast(nnOp.get()); if (argMap.count("trans_a")) { @@ -172,7 +172,7 @@ convertToNeuralNetOperator(caffe2::OperatorDef* op) { } if (!nnOp) { - nnOp = util::make_unique(op->type()); + nnOp = util::make_unique(op.type()); } // Generic attributes associated with Ops here @@ -181,7 +181,7 @@ convertToNeuralNetOperator(caffe2::OperatorDef* op) { auto annotation = util::make_unique(); annotation->setOperatorDef(op); - auto device_name = op->device_option().node_name(); + auto device_name = op.device_option().node_name(); if (device_name != "") { annotation->setDevice(device_name); } @@ -379,7 +379,7 @@ repr::NNModule convertToNNModule(caffe2::NetDef &net, std::unordered_mapresetData(convertToNeuralNetOperator(&op)); + opNode->resetData(convertToNeuralNetOperator(op)); auto currentBasicBlock = bbNode->mutableData()->get(); currentBasicBlock->pushInstructionNode(opNode); } @@ -444,7 +444,7 @@ caffe2::OperatorDef convertToOperatorDef(repr::NNGraph::NodeRef instrNode) { } else { switch (annotation->getKind()) { case repr::Annotation::AnnotationKind::Caffe2: - op = *dyn_cast(annotation)->getOperatorDef(); + op = dyn_cast(annotation)->getOperatorDef(); break; default: op.set_type("__NOMNIGRAPH_CONVERSION_ERROR__"); diff --git a/caffe2/opt/converter.h b/caffe2/opt/converter.h index 04fe9857eb3373..74f2f3e03c7373 100644 --- a/caffe2/opt/converter.h +++ b/caffe2/opt/converter.h @@ -1,11 +1,12 @@ #ifndef CAFFE2_OPT_CONVERTER_H #define CAFFE2_OPT_CONVERTER_H +#include "caffe2/core/common.h" +#include "caffe2/core/logging.h" +#include "caffe2/proto/caffe2.pb.h" #include "nomnigraph/Graph/Graph.h" #include "nomnigraph/Representations/ControlFlow.h" #include "nomnigraph/Representations/NeuralNet.h" -#include "caffe2/core/common.h" -#include "caffe2/proto/caffe2.pb.h" #include @@ -21,16 +22,21 @@ class Caffe2Annotation : public nom::repr::Annotation { void setDevice(std::string device) { Device = device; } const std::string getDevice() const { return Device; } - void setOperatorDef(caffe2::OperatorDef* opDef) { + void setOperatorDef(const caffe2::OperatorDef& opDef) { OpDef = opDef; + OpDefExists = true; } - const caffe2::OperatorDef* getOperatorDef() const { - assert(OpDef && "OperatorDef was never set. Use Caffe2Annotation::setOperatorDef."); + const caffe2::OperatorDef& getOperatorDef() const { + CAFFE_ENFORCE( + OpDefExists, + "OperatorDef was never set. Use Caffe2Annotation::setOperatorDef."); return OpDef; } caffe2::OperatorDef* getMutableOperatorDef() { - assert(OpDef && "OperatorDef was never set. Use Caffe2Annotation::setOperatorDef."); - return OpDef; + CAFFE_ENFORCE( + OpDefExists, + "OperatorDef was never set. Use Caffe2Annotation::setOperatorDef."); + return &OpDef; } static bool classof(const Annotation *A) { @@ -39,7 +45,8 @@ class Caffe2Annotation : public nom::repr::Annotation { private: std::string Device = ""; - caffe2::OperatorDef* OpDef = nullptr; + caffe2::OperatorDef OpDef; + bool OpDefExists = false; }; nom::repr::NNModule convertToNNModule(caffe2::NetDef &net, std::unordered_map* blobMapOut = nullptr); @@ -51,7 +58,8 @@ caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule&); // are not reflected in changes to external_input or external_output. caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule&, const caffe2::NetDef& oldNet); -std::unique_ptr convertToNeuralNetOperator(caffe2::OperatorDef* op); +std::unique_ptr convertToNeuralNetOperator( + const caffe2::OperatorDef& op); } // namespace caffe2 diff --git a/caffe2/opt/mobile.cc b/caffe2/opt/mobile.cc index 8cc43811e9e2e6..6d0006818789bb 100644 --- a/caffe2/opt/mobile.cc +++ b/caffe2/opt/mobile.cc @@ -109,14 +109,14 @@ void fuseNNPACKConvRelu(repr::NNModule* nn) { if (!annotation || !isa(annotation)) { return false; } - const auto* op = dyn_cast(annotation)->getOperatorDef(); + const auto& op = dyn_cast(annotation)->getOperatorDef(); // We only want to fuse for fast NNPACK convs - if (op->engine() != "NNPACK") { + if (op.engine() != "NNPACK") { return false; } caffe2::string algo = "AUTO"; - for (const auto arg : op->arg()) { + for (const auto arg : op.arg()) { if (arg.name() == "algo") { algo = arg.s(); } diff --git a/caffe2/opt/optimize_ideep.cc b/caffe2/opt/optimize_ideep.cc index 7a7d963609e074..5a6643c2aa67ae 100644 --- a/caffe2/opt/optimize_ideep.cc +++ b/caffe2/opt/optimize_ideep.cc @@ -15,16 +15,16 @@ void OptimizeForIdeep(repr::NNModule* nn) { if (!annotation || !isa(annotation)) { return false; } - const auto* op = dyn_cast(annotation)->getOperatorDef(); + const auto& op = dyn_cast(annotation)->getOperatorDef(); // We only want to fuse for IDEEP convs - if (op->device_option().device_type() != DeviceType::IDEEP) { + if (op.device_option().device_type() != DeviceType::IDEEP) { return false; } // IDEEP doesn't support fusion group conv int group = - ArgumentHelper::GetSingleArgument(*op, "group", 1); + ArgumentHelper::GetSingleArgument(op, "group", 1); if (group != 1) { return false; } diff --git a/test/common_nn.py b/test/common_nn.py index 7f6c2ac202137e..2f5b6a213b91fa 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -21,11 +21,12 @@ PRECISION = 1e-5 -def get_size_average(m): +def get_reduction(m): result = getattr(m, 'reduction', None) - if result is not None: - return result is 'elementwise_mean' - return getattr(m, 'sizeAverage', None) + if result is None: + result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False) + assert result is not None + return result def get_weight(m): @@ -246,19 +247,19 @@ def get_weight(m): ] -def kldivloss_reference(input, target, size_average=True, reduce=True): +def kldivloss_reference(input, target, reduction='elementwise_mean'): safe_target = target * (target > 0).type_as(target) safe_target_log = (safe_target + (target <= 0).type_as(target)).log() result = safe_target * (safe_target_log - input) - if reduce and size_average: + if reduction == 'elementwise_mean': return result.mean() - elif reduce: + elif reduction == 'sum': return result.sum() return result def nlllossNd_reference(input, target, weight=None, ignore_index=-100, - size_average=True, reduce=True): + reduction='elementwise_mean'): assert input.dim() >= 3 N = input.size(0) C = input.size(1) @@ -276,15 +277,15 @@ def nlllossNd_reference(input, target, weight=None, ignore_index=-100, output[tup] = -input[tuple(input_index)] * norm total_weight += norm - if reduce and size_average: + if reduction == 'elementwise_mean': return output.sum() / total_weight - elif reduce: + elif reduction == 'sum': return output.sum() return output def nllloss_reference(input, target, weight=None, ignore_index=-100, - size_average=True, reduce=True): + reduction='elementwise_mean'): def nll_loss_helper(input, target, weight, ignore_index): if target == ignore_index: @@ -297,22 +298,22 @@ def nll_loss_helper(input, target, weight, ignore_index): for i, t in zip(input, target)] losses, weights = zip(*losses_and_weights) losses_tensor = input.new_tensor(losses) - if reduce and size_average: + if reduction == 'elementwise_mean': return sum(losses_tensor) / sum(weights) - elif reduce: + elif reduction == 'sum': return sum(losses_tensor) else: return losses_tensor -def smoothl1loss_reference(input, target, size_average=True, reduce=True): +def smoothl1loss_reference(input, target, reduction='elementwise_mean'): abs_diff = (input - target).abs() ge_one_mask = (abs_diff >= 1).type_as(abs_diff) lt_one_mask = (abs_diff < 1).type_as(abs_diff) output = ge_one_mask * (abs_diff - 0.5) + lt_one_mask * 0.5 * (abs_diff ** 2) - if reduce and size_average: + if reduction == 'elementwise_mean': return output.mean() - elif reduce: + elif reduction == 'sum': return output.sum() return output @@ -333,7 +334,7 @@ def _multilabelmarginloss_reference(input, target): return sum -def multilabelmarginloss_reference(input, target, size_average=True, reduce=True): +def multilabelmarginloss_reference(input, target, reduction='elementwise_mean'): if input.dim() == 1: n = 1 dim = input.size(0) @@ -346,30 +347,30 @@ def multilabelmarginloss_reference(input, target, size_average=True, reduce=True for i in range(0, n): output[i] = _multilabelmarginloss_reference(input[i], target[i]) - if reduce and size_average: + if reduction == 'elementwise_mean': return output.mean() / dim - elif reduce: + elif reduction == 'sum': return output.sum() / dim return output / dim -def hingeembeddingloss_reference(input, target, margin=1.0, size_average=True, reduce=True): +def hingeembeddingloss_reference(input, target, margin=1.0, reduction='elementwise_mean'): margin_clamp = (margin - input).clamp(min=0).type_as(input) output = torch.where(target == 1, input, margin_clamp) - if reduce and size_average: + if reduction == 'elementwise_mean': return output.mean() - elif reduce: + elif reduction == 'sum': return output.sum() return output -def softmarginloss_reference(input, target, size_average=True, reduce=True): +def softmarginloss_reference(input, target, reduction='elementwise_mean'): output = (1 + (-input * target).exp()).log() - if reduce and size_average: + if reduction == 'elementwise_mean': return output.mean() - elif reduce: + elif reduction == 'sum': return output.sum() return output @@ -385,8 +386,7 @@ def _multimarginloss_reference(input, target_idx, p, margin, weight): return output -def multimarginloss_reference(input, target, p=1, margin=1, weight=None, size_average=True, - reduce=True): +def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='elementwise_mean'): if input.dim() == 1: n = 1 dim = input.size(0) @@ -400,14 +400,14 @@ def multimarginloss_reference(input, target, p=1, margin=1, weight=None, size_av for x in range(0, n): output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight) - if reduce and size_average: + if reduction == 'elementwise_mean': return output.mean() / dim - elif reduce: + elif reduction == 'sum': return output.sum() / dim return output / dim -def cosineembeddingloss_reference(input1, input2, target, margin=0, size_average=True, reduce=True): +def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='elementwise_mean'): def _cos(a, b): cos = a.new(a.size(0)) for i in range(0, a.size(0)): @@ -416,15 +416,15 @@ def _cos(a, b): output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0)) - if reduce and size_average: + if reduction == 'elementwise_mean': return output.mean() - elif reduce: + elif reduction == 'sum': return output.sum() return output def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, - size_average=True, reduce=True): + reduction='elementwise_mean'): d_p = torch.pairwise_distance(anchor, positive, p, eps) d_n = torch.pairwise_distance(anchor, negative, p, eps) if swap: @@ -432,18 +432,18 @@ def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps d_n = torch.min(d_n, d_s) output = torch.clamp(margin + d_p - d_n, min=0.0) - if reduce and size_average: + if reduction == 'elementwise_mean': return output.mean() - elif reduce: + elif reduction == 'sum': return output.sum() return output -def marginrankingloss_reference(input1, input2, target, margin=0, size_average=True, reduce=True): +def marginrankingloss_reference(input1, input2, target, margin=0, reduction='elementwise_mean'): output = (-target * (input1 - input2) + margin).clamp(min=0) - if reduce and size_average: + if reduction == 'elementwise_mean': return output.mean() - elif reduce: + elif reduction == 'sum': return output.sum() return output @@ -476,12 +476,12 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_fn=lambda: torch.rand(15, 10).log(), target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(), reference_fn=lambda i, t, m: - nllloss_reference(i, t, size_average=get_size_average(m)), - check_no_size_average=True + nllloss_reference(i, t, reduction=get_reduction(m)), + check_sum_reduction=True ), dict( module_name='NLLLoss', - constructor_args=(None, True, 2), + constructor_args=(None, None, 2), input_fn=lambda: torch.rand(15, 10).log(), target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(), reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2), @@ -498,7 +498,7 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T ), dict( module_name='NLLLoss', - constructor_args_fn=lambda: (torch.rand(10), True, 2), + constructor_args_fn=lambda: (torch.rand(10), None, 2), input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(), reference_fn=lambda i, t, m: @@ -507,7 +507,7 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T ), dict( module_name='NLLLoss', - constructor_args_fn=lambda: (torch.rand(10), True, -1), + constructor_args_fn=lambda: (torch.rand(10), None, -1), input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), target_fn=lambda: torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1, reference_fn=lambda i, t, m: @@ -519,22 +519,23 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_fn=lambda: torch.rand(10, 10).log(), target_fn=lambda: torch.rand(10, 10), reference_fn=lambda i, t, m: - kldivloss_reference(i, t, get_size_average(m), reduce=True), - check_no_size_average=True, + kldivloss_reference(i, t, get_reduction(m)), + check_sum_reduction=True, ), dict( module_name='MSELoss', input_size=(2, 3, 4, 5), target_size=(2, 3, 4, 5), - reference_fn=lambda i, t, m: (i - t).abs().pow(2).sum() / (i.numel() if get_size_average(m) else 1), - check_no_size_average=True, + reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel() + if get_reduction(m) == 'elementwise_mean' else 1)), + check_sum_reduction=True, ), dict( module_name='BCELoss', input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), target_fn=lambda: torch.randn(15, 10).gt(0).double(), reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() / - (i.numel() if get_size_average(m) else 1), + (i.numel() if get_reduction(m) else 1), check_gradgrad=False, ), dict( @@ -543,7 +544,7 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), target_fn=lambda: torch.randn(15, 10).gt(0).double(), reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() / - (i.numel() if get_size_average(m) else 1), + (i.numel() if get_reduction(m) else 1), desc='weights', check_gradgrad=False, ), @@ -564,8 +565,8 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_size=(10,), target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1), reference_fn=lambda i, t, m: - hingeembeddingloss_reference(i, t, size_average=get_size_average(m)), - check_no_size_average=True, + hingeembeddingloss_reference(i, t, reduction=get_reduction(m)), + check_sum_reduction=True, ), dict( module_name='HingeEmbeddingLoss', @@ -573,18 +574,18 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_size=(10,), target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1), reference_fn=lambda i, t, m: - hingeembeddingloss_reference(i, t, margin=0.5, size_average=get_size_average(m)), + hingeembeddingloss_reference(i, t, margin=0.5, reduction=get_reduction(m)), desc='margin', - check_no_size_average=True, + check_sum_reduction=True, ), dict( module_name='MultiLabelMarginLoss', input_size=(10,), target_fn=lambda: torch.rand(10).mul(10).floor().long(), reference_fn=lambda i, t, m: - multilabelmarginloss_reference(i, t, size_average=get_size_average(m)), + multilabelmarginloss_reference(i, t, reduction=get_reduction(m)), desc="1d", - check_no_size_average=True, + check_sum_reduction=True, check_gradgrad=False, ), dict( @@ -592,8 +593,8 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_size=(5, 10), target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(), reference_fn=lambda i, t, m: - multilabelmarginloss_reference(i, t, size_average=get_size_average(m)), - check_no_size_average=True, + multilabelmarginloss_reference(i, t, reduction=get_reduction(m)), + check_sum_reduction=True, check_gradgrad=False, ), dict( @@ -608,8 +609,8 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_size=(5, 10), target_fn=lambda: torch.rand(5).mul(8).floor().long(), reference_fn=lambda i, t, m: - multimarginloss_reference(i, t, size_average=get_size_average(m)), - check_no_size_average=True, + multimarginloss_reference(i, t, reduction=get_reduction(m)), + check_sum_reduction=True, check_gradgrad=False, ), dict( @@ -617,9 +618,9 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_size=(10,), target_fn=lambda: torch.rand(1).mul(8).floor().long(), reference_fn=lambda i, t, m: - multimarginloss_reference(i, t, size_average=get_size_average(m)), + multimarginloss_reference(i, t, reduction=get_reduction(m)), desc='1d', - check_no_size_average=True, + check_sum_reduction=True, check_gradgrad=False, ), dict( @@ -628,9 +629,9 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_fn=lambda: torch.rand(5, 10).clamp_(1e-2, 1 - 1e-2), target_fn=lambda: torch.rand(5).mul(8).floor().long(), reference_fn=lambda i, t, m: - multimarginloss_reference(i, t, p=2, size_average=get_size_average(m)), + multimarginloss_reference(i, t, p=2, reduction=get_reduction(m)), desc='p', - check_no_size_average=True, + check_sum_reduction=True, check_gradgrad=False, ), dict( @@ -640,9 +641,9 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_size=(5, 10), target_fn=lambda: torch.rand(5).mul(8).floor().long(), reference_fn=lambda i, t, m: - multimarginloss_reference(i, t, margin=0.5, size_average=get_size_average(m)), + multimarginloss_reference(i, t, margin=0.5, reduction=get_reduction(m)), desc='margin', - check_no_size_average=True, + check_sum_reduction=True, check_gradgrad=False, ), dict( @@ -652,34 +653,34 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_size=(5, 10), target_fn=lambda: torch.rand(5).mul(8).floor().long(), reference_fn=lambda i, t, m: - multimarginloss_reference(i, t, weight=get_weight(m), size_average=get_size_average(m)), + multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)), desc='weights', - check_no_size_average=True, + check_sum_reduction=True, check_gradgrad=False, ), dict( module_name='SmoothL1Loss', input_size=(5, 10), target_size=(5, 10), - check_no_size_average=True, + check_sum_reduction=True, reference_fn=lambda i, t, m: - smoothl1loss_reference(i, t, size_average=get_size_average(m)), + smoothl1loss_reference(i, t, reduction=get_reduction(m)), ), dict( module_name='SoftMarginLoss', input_size=(5, 5), target_fn=lambda: torch.randn(5, 5).sign(), reference_fn=lambda i, t, m: - softmarginloss_reference(i, t, size_average=get_size_average(m)), - check_no_size_average=True, + softmarginloss_reference(i, t, reduction=get_reduction(m)), + check_sum_reduction=True, ), dict( module_name='CosineEmbeddingLoss', input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)), target_fn=lambda: torch.randn(15).sign(), reference_fn=lambda i, t, m: - cosineembeddingloss_reference(i[0], i[1], t, size_average=get_size_average(m)), - check_no_size_average=True, + cosineembeddingloss_reference(i[0], i[1], t, reduction=get_reduction(m)), + check_sum_reduction=True, ), dict( module_name='CosineEmbeddingLoss', @@ -687,17 +688,17 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)), target_fn=lambda: torch.randn(15).sign(), reference_fn=lambda i, t, m: - cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, size_average=get_size_average(m)), + cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, reduction=get_reduction(m)), desc='margin', - check_no_size_average=True, + check_sum_reduction=True, ), dict( module_name='MarginRankingLoss', input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)), target_fn=lambda: torch.randn(50).sign(), reference_fn=lambda i, t, m: - marginrankingloss_reference(i[0], i[1], t, size_average=get_size_average(m)), - check_no_size_average=True, + marginrankingloss_reference(i[0], i[1], t, reduction=get_reduction(m)), + check_sum_reduction=True, ), dict( module_name='MarginRankingLoss', @@ -705,9 +706,9 @@ def marginrankingloss_reference(input1, input2, target, margin=0, size_average=T input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)), target_fn=lambda: torch.randn(50).sign(), reference_fn=lambda i, t, m: - marginrankingloss_reference(i[0], i[1], t, margin=0.5, size_average=get_size_average(m)), + marginrankingloss_reference(i[0], i[1], t, margin=0.5, reduction=get_reduction(m)), desc='margin', - check_no_size_average=True, + check_sum_reduction=True, ), ] diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp index ac31a6517c88c8..66400a998d2d0b 100644 --- a/test/cpp/api/sequential.cpp +++ b/test/cpp/api/sequential.cpp @@ -213,4 +213,34 @@ TEST_CASE("sequential") { .sum() .toCFloat() == 10); } + + SECTION("extend() pushes modules from other Sequential") { + struct A : torch::nn::Module { int forward(int x) { return x; } }; + struct B : torch::nn::Module { int forward(int x) { return x; } }; + struct C : torch::nn::Module { int forward(int x) { return x; } }; + struct D : torch::nn::Module { int forward(int x) { return x; } }; + Sequential a(A{}, B{}); + Sequential b(C{}, D{}); + a.extend(b); + + REQUIRE(a.size() == 4); + REQUIRE(a[0]->is()); + REQUIRE(a[1]->is()); + REQUIRE(a[2]->is()); + REQUIRE(a[3]->is()); + + REQUIRE(b.size() == 2); + REQUIRE(b[0]->is()); + REQUIRE(b[1]->is()); + + std::vector> c = {std::make_shared(), + std::make_shared()}; + b.extend(c); + + REQUIRE(b.size() == 4); + REQUIRE(b[0]->is()); + REQUIRE(b[1]->is()); + REQUIRE(b[2]->is()); + REQUIRE(b[3]->is()); + } } diff --git a/test/test_jit.py b/test/test_jit.py index 6b0e276088953f..f186f42d48e3b8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -70,10 +70,10 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh) ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) - ingate = F.sigmoid(ingate) - forgetgate = F.sigmoid(forgetgate) + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) cellgate = F.tanh(cellgate) - outgate = F.sigmoid(outgate) + outgate = torch.sigmoid(outgate) cy = (forgetgate * cx) + (ingate * cellgate) hy = outgate * F.tanh(cy) @@ -4476,7 +4476,7 @@ def reparameterize(self, mu, logvar): def decode(self, z): h3 = F.relu(self.fc3(z)) - return F.sigmoid(self.fc4(h3)) + return torch.sigmoid(self.fc4(h3)) def forward(self, x): mu, logvar = self.encode(x.view(-1, 784)) diff --git a/test/test_nn.py b/test/test_nn.py index 8c52bd1ab30a61..c25863edb15414 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -34,7 +34,7 @@ from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \ TEST_CUDNN_VERSION from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \ - module_tests, criterion_tests, loss_reference_fns, get_size_average, \ + module_tests, criterion_tests, loss_reference_fns, get_reduction, \ get_weight, smoothl1loss_reference, kldivloss_reference @@ -4270,8 +4270,8 @@ def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss(self): self.assertEqual(nn.BCEWithLogitsLoss()(output, target), nn.BCELoss()(sigmoid(output), target)) - self.assertEqual(nn.BCEWithLogitsLoss(reduce=False)(output, target), - nn.BCELoss(reduce=False)(sigmoid(output), target)) + self.assertEqual(nn.BCEWithLogitsLoss(reduction='none')(output, target), + nn.BCELoss(reduction='none')(sigmoid(output), target)) weight = torch.rand(1, dtype=torch.float) self.assertEqual(nn.BCEWithLogitsLoss(weight)(output, target), nn.BCELoss(weight)(sigmoid(output), target)) @@ -4279,7 +4279,7 @@ def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss(self): def test_bce_with_logits_has_correct_grad_at_zero(self): output = torch.zeros(3, 1, requires_grad=True) target = torch.zeros(3, 1) - nn.BCEWithLogitsLoss(size_average=False)(output, target).backward() + nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward() expected_grad = torch.empty(3, 1).fill_(0.5) self.assertEqual(output.grad, expected_grad) @@ -4330,10 +4330,9 @@ def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self): output = torch.zeros(3, 1, requires_grad=True) target = torch.zeros(3, 1) pos_weight = torch.ones(3, 1) - nn.BCEWithLogitsLoss(pos_weight=pos_weight, size_average=False)(output, target).backward() + nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward() expected_grad = torch.empty(3, 1).fill_(0.5) grad = output.grad - print(grad) self.assertEqual(grad, expected_grad) def test_bce_loss_broadcasts_weights(self): @@ -4560,36 +4559,37 @@ def test_cosine_embedding_loss_no_reduce(self): input2 = torch.randn(15, 10, requires_grad=True) target = torch.randn(15).sign() self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss( - x, y, z, reduce=False), (input1, input2, target))) - self.assertEqual(F.cosine_embedding_loss(input1, input2, target, reduce=False), - loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, reduce=False)) + x, y, z, reduction='none'), (input1, input2, target))) + self.assertEqual(F.cosine_embedding_loss(input1, input2, target, reduction='none'), + loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, reduction='none')) def test_cosine_embedding_loss_margin_no_reduce(self): input1 = torch.randn(15, 10, requires_grad=True) input2 = torch.randn(15, 10, requires_grad=True) target = torch.randn(15).sign() self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss( - x, y, z, margin=0.5, reduce=False), (input1, input2, target))) - self.assertEqual(F.cosine_embedding_loss(input1, input2, target, margin=0.5, reduce=False), - loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, margin=0.5, reduce=False)) + x, y, z, margin=0.5, reduction='none'), (input1, input2, target))) + self.assertEqual(F.cosine_embedding_loss(input1, input2, target, margin=0.5, reduction='none'), + loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, + margin=0.5, reduction='none')) def test_margin_ranking_loss_no_reduce(self): input1 = torch.tensor(torch.randn(15).mul(10), requires_grad=True) input2 = torch.tensor(torch.randn(15).mul(10), requires_grad=True) target = torch.randn(15).sign() self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss( - x, y, z, reduce=False), (input1, input2, target))) - self.assertEqual(F.margin_ranking_loss(input1, input2, target, reduce=False), - loss_reference_fns['MarginRankingLoss'](input1, input2, target, reduce=False)) + x, y, z, reduction='none'), (input1, input2, target))) + self.assertEqual(F.margin_ranking_loss(input1, input2, target, reduction='none'), + loss_reference_fns['MarginRankingLoss'](input1, input2, target, reduction='none')) def test_margin_ranking_loss_margin_no_reduce(self): input1 = torch.tensor(torch.randn(15).mul(10), requires_grad=True) input2 = torch.tensor(torch.randn(15).mul(10), requires_grad=True) target = torch.randn(15).sign() self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss( - x, y, z, margin=0.5, reduce=False), (input1, input2, target))) - self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduce=False), - loss_reference_fns['MarginRankingLoss'](input1, input2, target, margin=0.5, reduce=False)) + x, y, z, margin=0.5, reduction='none'), (input1, input2, target))) + self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduction='none'), + loss_reference_fns['MarginRankingLoss'](input1, input2, target, margin=0.5, reduction='none')) def test_triplet_margin_loss(self): input1 = torch.randn(5, 10, requires_grad=True) @@ -4614,39 +4614,18 @@ def test_triplet_margin_loss_no_reduce(self): input2 = torch.randn(5, 10, requires_grad=True) input3 = torch.randn(5, 10, requires_grad=True) self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss( - x1, x2, x3, reduce=False), (input1, input2, input3))) - self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduce=False), - loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduce=False)) + x1, x2, x3, reduction='none'), (input1, input2, input3))) + self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='none'), + loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduction='none')) def test_triplet_margin_loss_swap_no_reduce(self): input1 = torch.randn(5, 10, requires_grad=True) input2 = torch.randn(5, 10, requires_grad=True) input3 = torch.randn(5, 10, requires_grad=True) self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss( - x1, x2, x3, swap=True, reduce=False), (input1, input2, input3))) - self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduce=False), - loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduce=False)) - - def test_loss_reduction_arg(self): - # NB: This is a sanity test to check that the new reduction arg works the same as size_average and reduce - # Remove this when size_average and reduce are deprecated and tests are ported to the new arg - input1 = torch.randn(5, 10, requires_grad=True) - input2 = torch.randn(5, 10, requires_grad=True) - input3 = torch.randn(5, 10, requires_grad=True) - self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss( - x1, x2, x3, reduction='elementwise_mean'), (input1, input2, input3))) - self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='elementwise_mean'), - loss_reference_fns['TripletMarginLoss'](input1, input2, input3)) - - self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss( - x1, x2, x3, reduction='sum'), (input1, input2, input3))) - self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='sum'), - loss_reference_fns['TripletMarginLoss'](input1, input2, input3, size_average=False)) - - self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss( - x1, x2, x3, reduction='none'), (input1, input2, input3))) - self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='none'), - loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduce=False)) + x1, x2, x3, swap=True, reduction='none'), (input1, input2, input3))) + self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'), + loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none')) def test_cosine_similarity(self): input1 = torch.randn(4, 4, requires_grad=True) @@ -5774,8 +5753,8 @@ def forward(self, *args): input_size=(2, 3, 5, 5), target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(), reference_fn=lambda i, t, m: - loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)), - check_no_size_average=True, + loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)), + check_sum_reduction=True, desc='2d' ), dict( @@ -5789,7 +5768,7 @@ def forward(self, *args): ), dict( module_name='NLLLoss', - constructor_args=(None, True, 1), + constructor_args=(None, None, 1), input_size=(2, 3, 5, 5), target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(), reference_fn=lambda i, t, m: @@ -5801,8 +5780,8 @@ def forward(self, *args): input_size=(2, 3, 5, 5, 2, 2), target_fn=lambda: torch.rand(2, 5, 5, 2, 2).mul(3).floor().long(), reference_fn=lambda i, t, m: - loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)), - check_no_size_average=True, + loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)), + check_sum_reduction=True, desc='higher_dim' ), dict( @@ -5810,8 +5789,8 @@ def forward(self, *args): input_size=(2, 3, 5), target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(), reference_fn=lambda i, t, m: - loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)), - check_no_size_average=True, + loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)), + check_sum_reduction=True, desc='dim_is_3' ), dict( @@ -5822,7 +5801,7 @@ def forward(self, *args): ), dict( module_name='PoissonNLLLoss', - constructor_args=(False, True, True), + constructor_args=(False,), input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001), target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(), desc='full_loss', # with sterling approx @@ -5839,16 +5818,17 @@ def forward(self, *args): input_fn=lambda: torch.rand(()).log(), target_fn=lambda: torch.rand(()), reference_fn=lambda i, t, m: - kldivloss_reference(i, t, get_size_average(m), reduce=True), - check_no_size_average=True, + kldivloss_reference(i, t, get_reduction(m)), + check_sum_reduction=True, desc='scalar', ), dict( module_name='MSELoss', input_size=(), target_size=(), - reference_fn=lambda i, t, m: (i - t).abs().pow(2).sum() / (i.numel() if get_size_average(m) else 1), - check_no_size_average=True, + reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / + (i.numel() if get_reduction(m) == 'elementwise_mean' else 1)), + check_sum_reduction=True, desc='scalar' ), dict( @@ -5857,7 +5837,7 @@ def forward(self, *args): input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2), target_fn=lambda: torch.rand(()).gt(0).double(), reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() / - (i.numel() if get_size_average(m) else 1), + (i.numel() if get_reduction(m) == 'elementwise_mean' else 1), desc='scalar_weights', check_gradgrad=False, ), @@ -5867,15 +5847,15 @@ def forward(self, *args): input_size=(), target_fn=lambda: torch.randn(()).gt(0).double().mul_(2).sub(1), desc='scalar_margin', - check_no_size_average=True, + check_sum_reduction=True, ), dict( module_name='SmoothL1Loss', input_size=(), target_size=(), - check_no_size_average=True, + check_sum_reduction=True, reference_fn=lambda i, t, m: - smoothl1loss_reference(i, t, size_average=get_size_average(m)), + smoothl1loss_reference(i, t, reduction=get_reduction(m)), desc='scalar', ), dict( @@ -5884,9 +5864,9 @@ def forward(self, *args): input_fn=lambda: torch.randn(5, 10), target_fn=lambda: torch.rand(5, 10).mul(2).floor(), reference_fn=lambda i, t, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * get_weight(m)).sum() / - (i.numel() if get_size_average(m) else 1), + (i.numel() if get_reduction(m) == 'elementwise_mean' else 1), desc='weights', - check_no_size_average=True, + check_sum_reduction=True, check_gradgrad=False, ), ] @@ -5897,7 +5877,7 @@ def poissonnllloss_no_reduce_test(): return dict( fullname='PoissonNLLLLoss_no_reduce', constructor=wrap_functional( - lambda i: F.poisson_nll_loss(i, t.type_as(i), reduce=False)), + lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.rand(10, 10), pickle=False) @@ -5907,7 +5887,7 @@ def bceloss_no_reduce_test(): return dict( fullname='BCELoss_no_reduce', constructor=wrap_functional( - lambda i: F.binary_cross_entropy(i, t.type_as(i), reduce=False)), + lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()), check_gradgrad=False, @@ -5919,7 +5899,7 @@ def bceloss_no_reduce_scalar_test(): return dict( fullname='BCELoss_no_reduce_scalar', constructor=wrap_functional( - lambda i: F.binary_cross_entropy(i, t.type_as(i), reduce=False)), + lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()), check_gradgrad=False, @@ -5933,7 +5913,7 @@ def bceloss_weights_no_reduce_test(): fullname='BCELoss_weights_no_reduce', constructor=wrap_functional( lambda i: F.binary_cross_entropy(i, t.type_as(i), - weight=weights.type_as(i), reduce=False)), + weight=weights.type_as(i), reduction='none')), input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights, check_gradgrad=False, @@ -5947,7 +5927,7 @@ def bceloss_weights_no_reduce_scalar_test(): fullname='BCELoss_weights_no_reduce_scalar', constructor=wrap_functional( lambda i: F.binary_cross_entropy(i, t.type_as(i), - weight=weights.type_as(i), reduce=False)), + weight=weights.type_as(i), reduction='none')), input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights, check_gradgrad=False, @@ -5960,7 +5940,7 @@ def bce_with_logistic_no_reduce_test(): return dict( fullname='BCEWithLogitsLoss_no_reduce', constructor=wrap_functional( - lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)), + lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), check_gradgrad=False, @@ -5973,7 +5953,7 @@ def bce_with_logistic_no_reduce_scalar_test(): return dict( fullname='BCEWithLogitsLoss_no_reduce_scalar', constructor=wrap_functional( - lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)), + lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), check_gradgrad=False, @@ -5985,10 +5965,10 @@ def kldivloss_with_target_no_reduce_test(): return dict( fullname='KLDivLoss_with_target_no_reduce', constructor=wrap_functional( - lambda t: F.kl_div(i.type_as(t), t, reduce=False)), + lambda t: F.kl_div(i.type_as(t), t, reduction='none')), input_fn=lambda: torch.rand(10, 10), reference_fn=lambda t, _: - loss_reference_fns['KLDivLoss'](i.type_as(t), t, reduce=False), + loss_reference_fns['KLDivLoss'](i.type_as(t), t, reduction='none'), pickle=False) @@ -5997,10 +5977,10 @@ def kldivloss_no_reduce_test(): return dict( fullname='KLDivLoss_no_reduce', constructor=wrap_functional( - lambda i: F.kl_div(i, t.type_as(i), reduce=False)), + lambda i: F.kl_div(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.rand(10, 10).log(), reference_fn=lambda i, _: - loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduce=False), + loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), pickle=False) @@ -6009,10 +5989,10 @@ def kldivloss_no_reduce_scalar_test(): return dict( fullname='KLDivLoss_no_reduce_scalar', constructor=wrap_functional( - lambda i: F.kl_div(i, t.type_as(i), reduce=False)), + lambda i: F.kl_div(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.rand(()).log(), reference_fn=lambda i, _: - loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduce=False), + loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), pickle=False) @@ -6021,7 +6001,7 @@ def l1loss_no_reduce_test(): return dict( fullname='L1Loss_no_reduce', constructor=wrap_functional( - lambda i: F.l1_loss(i, t.type_as(i), reduce=False)), + lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.randn(2, 3, 4), reference_fn=lambda i, m: (i - t.type_as(i)).abs(), pickle=False) @@ -6032,7 +6012,7 @@ def l1loss_no_reduce_scalar_test(): return dict( fullname='L1Loss_no_reduce_scalar', constructor=wrap_functional( - lambda i: F.l1_loss(i, t.type_as(i), reduce=False)), + lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.randn(()), reference_fn=lambda i, m: (i - t.type_as(i)).abs(), pickle=False) @@ -6044,7 +6024,7 @@ def mseloss_no_reduce_test(): return dict( fullname='MSELoss_no_reduce', constructor=wrap_functional( - lambda i: F.mse_loss(i, target.type_as(i), reduce=False)), + lambda i: F.mse_loss(i, target.type_as(i), reduction='none')), input_size=input_size, reference_fn=lambda i, m: (i - target).pow(2), pickle=False) @@ -6056,7 +6036,7 @@ def mseloss_no_reduce_scalar_test(): return dict( fullname='MSELoss_no_reduce_scalar', constructor=wrap_functional( - lambda i: F.mse_loss(i, target.type_as(i), reduce=False)), + lambda i: F.mse_loss(i, target.type_as(i), reduction='none')), input_size=input_size, reference_fn=lambda i, m: (i - target).pow(2), pickle=False) @@ -6064,7 +6044,7 @@ def mseloss_no_reduce_scalar_test(): def nllloss_no_reduce_test(): t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long()) - kwargs = {'reduce': False} + kwargs = {'reduction': 'none'} return dict( fullname='NLLLoss_no_reduce', constructor=wrap_functional( @@ -6077,7 +6057,7 @@ def nllloss_no_reduce_test(): def nllloss_no_reduce_ignore_index_test(): t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long()) - kwargs = {'ignore_index': 2, 'reduce': False} + kwargs = {'ignore_index': 2, 'reduction': 'none'} return dict( fullname='NLLLoss_no_reduce_ignore_index', constructor=wrap_functional( @@ -6093,7 +6073,7 @@ def nllloss_no_reduce_weights_test(): weight = torch.rand(10) def kwargs(i): - return {'weight': weight.type_as(i), 'reduce': False} + return {'weight': weight.type_as(i), 'reduction': 'none'} return dict( fullname='NLLLoss_no_reduce_weights', @@ -6110,7 +6090,7 @@ def nllloss_no_reduce_weights_ignore_index_test(): weight = torch.rand(10) def kwargs(i): - return {'weight': weight.type_as(i), 'reduce': False, + return {'weight': weight.type_as(i), 'reduction': 'none', 'ignore_index': 2} return dict( @@ -6128,7 +6108,7 @@ def nllloss_no_reduce_weights_ignore_index_neg_test(): weight = torch.rand(10) def kwargs(i): - return {'weight': weight.type_as(i), 'reduce': False, + return {'weight': weight.type_as(i), 'reduction': 'none', 'ignore_index': -1} return dict( @@ -6143,7 +6123,7 @@ def kwargs(i): def nllloss2d_no_reduce_test(): t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) - kwargs = {'reduce': False} + kwargs = {'reduction': 'none'} return dict( fullname='NLLLoss2d_no_reduce', constructor=wrap_functional( @@ -6156,7 +6136,7 @@ def nllloss2d_no_reduce_test(): def nllloss2d_no_reduce_ignore_index_test(): t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) - kwargs = {'ignore_index': 1, 'reduce': False} + kwargs = {'ignore_index': 1, 'reduction': 'none'} return dict( fullname='NLLLoss2d_no_reduce_ignore_index', constructor=wrap_functional( @@ -6172,7 +6152,7 @@ def nllloss2d_no_reduce_weights_test(): weight = torch.rand(3) def kwargs(i): - return {'weight': weight.type_as(i), 'reduce': False} + return {'weight': weight.type_as(i), 'reduction': 'none'} return dict( fullname='NLLLoss2d_no_reduce_weights', @@ -6186,7 +6166,7 @@ def kwargs(i): def nlllossNd_no_reduce_test(): t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) - kwargs = {'reduce': False} + kwargs = {'reduction': 'none'} return dict( fullname='NLLLossNd_no_reduce', constructor=wrap_functional( @@ -6199,7 +6179,7 @@ def nlllossNd_no_reduce_test(): def nlllossNd_no_reduce_ignore_index_test(): t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) - kwargs = {'ignore_index': 1, 'reduce': False} + kwargs = {'ignore_index': 1, 'reduction': 'none'} return dict( fullname='NLLLossNd_no_reduce_ignore_index', constructor=wrap_functional( @@ -6215,7 +6195,7 @@ def nlllossNd_no_reduce_weights_test(): weight = torch.rand(3) def kwargs(i): - return {'weight': weight.type_as(i), 'reduce': False} + return {'weight': weight.type_as(i), 'reduction': 'none'} return dict( fullname='NLLLossNd_no_reduce_weights', @@ -6232,10 +6212,10 @@ def smoothl1loss_no_reduce_test(): return dict( fullname='SmoothL1Loss_no_reduce', constructor=wrap_functional( - lambda i: F.smooth_l1_loss(i, t.type_as(i), reduce=False)), + lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.randn(2, 3, 4), reference_fn=lambda i, _: - loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduce=False), + loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'), pickle=False) @@ -6244,10 +6224,10 @@ def smoothl1loss_no_reduce_scalar_test(): return dict( fullname='SmoothL1Loss_no_reduce_scalar', constructor=wrap_functional( - lambda i: F.smooth_l1_loss(i, t.type_as(i), reduce=False)), + lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.randn(()), reference_fn=lambda i, _: - loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduce=False), + loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'), pickle=False) @@ -6256,11 +6236,11 @@ def multilabelmarginloss_1d_no_reduce_test(): return dict( fullname='MultiLabelMarginLoss_1d_no_reduce', constructor=wrap_functional( - lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)), + lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), input_fn=lambda: torch.randn(10), reference_fn=lambda i, _: - loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False), - check_no_size_average=True, + loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, check_gradgrad=False, pickle=False) @@ -6270,11 +6250,11 @@ def multilabelmarginloss_index_neg_test(): return dict( fullname='MultiLabelMarginLoss_index_neg', constructor=wrap_functional( - lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)), + lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), input_fn=lambda: torch.randn(5, 10), reference_fn=lambda i, _: - loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False), - check_no_size_average=True, + loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, check_gradgrad=False, pickle=False) @@ -6284,11 +6264,11 @@ def multilabelmarginloss_no_reduce_test(): return dict( fullname='MultiLabelMarginLoss_no_reduce', constructor=wrap_functional( - lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)), + lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), input_fn=lambda: torch.randn(5, 10), reference_fn=lambda i, _: - loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False), - check_no_size_average=True, + loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, check_gradgrad=False, pickle=False) @@ -6298,11 +6278,11 @@ def hingeembeddingloss_no_reduce_test(): return dict( fullname='HingeEmbeddingLoss_no_reduce', constructor=wrap_functional( - lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduce=False)), + lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.randn(10), reference_fn=lambda i, _: - loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduce=False), - check_no_size_average=True, + loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'), + check_sum_reduction=True, pickle=False) @@ -6311,11 +6291,11 @@ def hingeembeddingloss_margin_no_reduce_test(): return dict( fullname='HingeEmbeddingLoss_margin_no_reduce', constructor=wrap_functional( - lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduce=False)), + lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')), input_fn=lambda: torch.randn(10), reference_fn=lambda i, _: - loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduce=False), - check_no_size_average=True, + loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'), + check_sum_reduction=True, pickle=False) @@ -6324,11 +6304,10 @@ def softmarginloss_no_reduce_test(): return dict( fullname='SoftMarginLoss_no_reduce', constructor=wrap_functional( - lambda i: F.soft_margin_loss(i, t.type_as(i), reduce=False)), + lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.randn(5, 5), reference_fn=lambda i, _: - loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduce=False), - check_no_size_average=True, + loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'), pickle=False) @@ -6337,11 +6316,9 @@ def multilabelsoftmarginloss_no_reduce_test(): return dict( fullname='MultiLabelSoftMarginLoss_no_reduce', constructor=wrap_functional( - lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduce=False)), + lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.randn(5, 10), - reference_fn=lambda i, m: (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) / - (i.numel() if get_size_average(m) else 1)), - check_no_size_average=True, + reference_fn=lambda i, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()), check_gradgrad=False, pickle=False) @@ -6353,11 +6330,10 @@ def multilabelsoftmarginloss_weights_no_reduce_test(): fullname='MultiLabelSoftMarginLoss_weights_no_reduce', constructor=wrap_functional( lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), - weight=weights.type_as(i), reduce=False)), + weight=weights.type_as(i), reduction='none')), input_fn=lambda: torch.randn(5, 10), - reference_fn=lambda i, m: (-((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights) / - (i.numel() if get_size_average(m) else 1)), - check_no_size_average=True, + reference_fn=lambda i, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights), + check_sum_reduction=True, check_gradgrad=False, pickle=False) @@ -6367,11 +6343,11 @@ def multimarginloss_no_reduce_test(): return dict( fullname='MultiMarginLoss_no_reduce', constructor=wrap_functional( - lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduce=False)), + lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), input_fn=lambda: torch.randn(5, 10), reference_fn=lambda i, _: - loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduce=False), - check_no_size_average=True, + loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, check_gradgrad=False, pickle=False) @@ -6381,11 +6357,11 @@ def multimarginloss_1d_no_reduce_test(): return dict( fullname='MultiMarginLoss_1d_no_reduce', constructor=wrap_functional( - lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduce=False)), + lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), input_fn=lambda: torch.randn(10), reference_fn=lambda i, _: - loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduce=False), - check_no_size_average=True, + loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, check_gradgrad=False, pickle=False) @@ -6395,11 +6371,11 @@ def multimarginloss_p_no_reduce_test(): return dict( fullname='MultiMarginLoss_p_no_reduce', constructor=wrap_functional( - lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduce=False)), + lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')), input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2), reference_fn=lambda i, _: - loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduce=False), - check_no_size_average=True, + loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'), + check_sum_reduction=True, check_gradgrad=False, pickle=False) @@ -6409,12 +6385,12 @@ def multimarginloss_margin_no_reduce_test(): return dict( fullname='MultiMarginLoss_margin_no_reduce', constructor=wrap_functional( - lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduce=False)), + lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')), input_fn=lambda: torch.randn(5, 10), reference_fn=lambda i, _: loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), - margin=0.5, reduce=False), - check_no_size_average=True, + margin=0.5, reduction='none'), + check_sum_reduction=True, check_gradgrad=False, pickle=False) @@ -6426,12 +6402,12 @@ def multimarginloss_weights_no_reduce_test(): fullname='MultiMarginLoss_weights_no_reduce', constructor=wrap_functional( lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i), - reduce=False)), + reduction='none')), input_fn=lambda: torch.randn(5, 10), reference_fn=lambda i, _: loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), - weight=weights, reduce=False), - check_no_size_average=True, + weight=weights, reduction='none'), + check_sum_reduction=True, check_gradgrad=False, pickle=False) @@ -7734,18 +7710,18 @@ def eval_constructor(*args, **kwargs): test = NewCriterionTest(**test_params) decorator = test_params.pop('decorator', None) add_test(test, decorator) - if 'check_no_size_average' in test_params: + if 'check_sum_reduction' in test_params: desc = test_params.get('desc', None) - test_params['desc'] = 'no_size_average' if desc is None else desc + '_no_size_average' + test_params['desc'] = 'sum_reduction' if desc is None else desc + '_sum_reduction' - def gen_no_size_average_constructor(constructor): - def no_size_average_constructor(*args, **kwargs): - cons = constructor(*args, size_average=False, **kwargs) + def gen_sum_reduction_constructor(constructor): + def sum_reduction_constructor(*args, **kwargs): + cons = constructor(*args, reduction='sum', **kwargs) return cons - no_size_average_constructor.__name__ = constructor.__name__ - return no_size_average_constructor + sum_reduction_constructor.__name__ = constructor.__name__ + return sum_reduction_constructor - test_params['constructor'] = gen_no_size_average_constructor(test_params['constructor']) + test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor']) test = NewCriterionTest(**test_params) add_test(test, decorator) diff --git a/third_party/onnx b/third_party/onnx index 7537916872738f..72a358d1a585f3 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit 7537916872738fd47c12cebeb683f9e45f8066c2 +Subproject commit 72a358d1a585f32f186f7b585e9b9ce31827364b diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 885abaa9dd699f..74f978becd5bf5 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -19,14 +19,30 @@ #include "copy_utils.h" #include "DynamicTypes.h" +#ifdef USE_CUDA +#include +#endif + #include "generic/Storage.cpp" #include #include "generic/Storage.cpp" #include +// NB: If you ever divest libtorch of USE_CUDA, you'll have to virtualize +// the CUDA call. template<> void THPPointer::free() { - if (ptr) - THStorage_free(ptr); + if (ptr) { + if (ptr->backend == at::kCPU) { + THStorage_free(ptr); + } else { + AT_ASSERT(ptr->backend == at::kCUDA); +#ifdef USE_CUDA + THCStorage_free(at::globalContext().lazyInitCUDA(), ptr); +#else + AT_ERROR("Cannot free THCStorage when not built with CUDA"); +#endif + } + } } diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h index 0b662bb070cb5c..bc6fbacff957b4 100644 --- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h +++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h @@ -26,14 +26,11 @@ class BatchNormImpl : public torch::nn::Cloneable { Tensor forward(Tensor input); Tensor pure_forward(Tensor input, Tensor mean, Tensor variance); - const BatchNormOptions& options() const noexcept; - - private: - BatchNormOptions options_; - Tensor weight_; - Tensor bias_; - Tensor running_mean_; - Tensor running_variance_; + BatchNormOptions options; + Tensor weight; + Tensor bias; + Tensor running_mean; + Tensor running_variance; }; TORCH_MODULE(BatchNorm); diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h index f7a2dd13f81f97..386a1937de7f0d 100644 --- a/torch/csrc/api/include/torch/nn/modules/conv.h +++ b/torch/csrc/api/include/torch/nn/modules/conv.h @@ -35,12 +35,10 @@ class ConvImpl : public torch::nn::Cloneable { explicit ConvImpl(ConvOptions options); void reset() override; - const ConvOptions& options() const noexcept; - protected: - Tensor weight_; - Tensor bias_; - ConvOptions options_; + ConvOptions options; + Tensor weight; + Tensor bias; }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/torch/csrc/api/include/torch/nn/modules/dropout.h b/torch/csrc/api/include/torch/nn/modules/dropout.h index 5a5762e7583554..876b042550b31b 100644 --- a/torch/csrc/api/include/torch/nn/modules/dropout.h +++ b/torch/csrc/api/include/torch/nn/modules/dropout.h @@ -18,16 +18,13 @@ namespace detail { template class DropoutImplBase : public torch::nn::Cloneable { public: - explicit DropoutImplBase(DropoutOptions options); + explicit DropoutImplBase(DropoutOptions options_); void reset() override; Tensor forward(Tensor input); - const DropoutOptions& options() const noexcept; - - protected: virtual Tensor noise_mask(Tensor input) const = 0; - DropoutOptions options_; + DropoutOptions options; }; } // namespace detail diff --git a/torch/csrc/api/include/torch/nn/modules/embedding.h b/torch/csrc/api/include/torch/nn/modules/embedding.h index c345ca4a5bc4f9..861d75220606f6 100644 --- a/torch/csrc/api/include/torch/nn/modules/embedding.h +++ b/torch/csrc/api/include/torch/nn/modules/embedding.h @@ -22,11 +22,9 @@ class EmbeddingImpl : public torch::nn::Cloneable { void reset() override; Tensor forward(Tensor); - const EmbeddingOptions& options() const noexcept; - private: - EmbeddingOptions options_; - Tensor table_; + EmbeddingOptions options; + Tensor table; }; TORCH_MODULE(Embedding); diff --git a/torch/csrc/api/include/torch/nn/modules/linear.h b/torch/csrc/api/include/torch/nn/modules/linear.h index 612a2e48880ac6..40daaef77e37cd 100644 --- a/torch/csrc/api/include/torch/nn/modules/linear.h +++ b/torch/csrc/api/include/torch/nn/modules/linear.h @@ -23,12 +23,10 @@ class LinearImpl : public Cloneable { void reset() override; Tensor forward(Tensor); - const LinearOptions& options() const noexcept; - private: - Tensor weight_; - Tensor bias_; - LinearOptions options_; + LinearOptions options; + Tensor weight; + Tensor bias; }; TORCH_MODULE(Linear); diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h index 431fb4673ee34e..bf3ffb863352be 100644 --- a/torch/csrc/api/include/torch/nn/modules/rnn.h +++ b/torch/csrc/api/include/torch/nn/modules/rnn.h @@ -41,7 +41,7 @@ class RNNImplBase : public torch::nn::Cloneable { enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 }; RNNImplBase( - RNNOptionsBase options, + RNNOptionsBase options_, at::optional cudnn_mode = at::nullopt, int64_t number_of_gates = 1, bool has_cell_state = false); @@ -60,8 +60,20 @@ class RNNImplBase : public torch::nn::Cloneable { /// Recursively moves all parameters to the given device. void to(torch::Device device, bool non_blocking = false) override; + /// Fills the internal flattened parameter buffers passed to cuDNN. Call this + /// method if you mess around with the variable storages and want to use + /// cuDNN. void flatten_parameters_for_cudnn(); + RNNOptionsBase options; + + std::vector w_ih; + std::vector w_hh; + std::vector b_ih; + std::vector b_hh; + + Dropout dropout; + protected: virtual Tensor cell_forward(Tensor input, Tensor state, int64_t layer) = 0; @@ -72,17 +84,9 @@ class RNNImplBase : public torch::nn::Cloneable { bool use_cudnn(Tensor sample) const; Tensor create_dropout_state(Tensor input) const; - RNNOptionsBase options_; - - std::vector ihw_; - std::vector ihb_; - std::vector hhw_; - std::vector hhb_; - int64_t number_of_gates_; bool has_cell_state_; at::optional cudnn_mode_; - Dropout dropout_module_; // This is copied from pytorch, to determine whether weights are flat for the // fast CUDNN route. Otherwise, we have to use non flattened weights, which @@ -119,12 +123,10 @@ class RNNImpl : public detail::RNNImplBase { public: explicit RNNImpl(RNNOptions options); - const RNNOptions& options() const noexcept; + RNNOptions options; private: Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override; - - RNNOptions options_; std::function activation_function_; }; @@ -138,8 +140,6 @@ class LSTMImpl : public detail::RNNImplBase { public: explicit LSTMImpl(LSTMOptions options); - const LSTMOptions& options() const noexcept; - private: Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override; }; @@ -154,8 +154,6 @@ class GRUImpl : public detail::RNNImplBase { public: explicit GRUImpl(GRUOptions options); - const GRUOptions& options() const noexcept; - private: Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override; }; diff --git a/torch/csrc/api/include/torch/nn/modules/sequential.h b/torch/csrc/api/include/torch/nn/modules/sequential.h index 8bde673950261f..3ac73d98a0e40c 100644 --- a/torch/csrc/api/include/torch/nn/modules/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/sequential.h @@ -23,6 +23,7 @@ namespace nn { class Sequential : public Cloneable { public: using Iterator = std::vector>::iterator; + using ConstIterator = std::vector>::const_iterator; /// Constructs the `Sequential` from a pack of modules. Each module can either /// be a plain value (e.g. `Linear`) or a boxed value (e.g. @@ -80,9 +81,7 @@ class Sequential : public Cloneable { static_assert( torch::detail::has_forward::value, "Can only add modules with a forward() method to Sequential"); - modules_.push_back(std::make_shared(std::move(module_ptr))); - const auto index = modules_.size() - 1; - register_module(std::to_string(index), modules_[index]->ptr()); + push_back(std::make_shared(std::move(module_ptr))); } /// Adds a new `Module` to the `Sequential` container, moving or copying it @@ -105,15 +104,37 @@ class Sequential : public Cloneable { push_back(module_holder.get()); } + /// Adds a type-erased `AnyModule` to the `Sequential`. + void push_back(std::shared_ptr any_module) { + modules_.push_back(std::move(any_module)); + const auto index = modules_.size() - 1; + register_module(std::to_string(index), modules_[index]->ptr()); + } + + /// Iterates over the container and calls `push_back()` on each iterated + /// value. + template + void extend(const Container& container) { + for (const auto& module : container) { + push_back(module); + } + } + /// Returns an iterator to the start of the `Sequential`. Iterator begin() { return modules_.begin(); } + ConstIterator begin() const { + return modules_.begin(); + } /// Returns an iterator to the end of the `Sequential`. Iterator end() { return modules_.end(); } + ConstIterator end() const { + return modules_.end(); + } /// Attempts to return the module at the given index as the requested type. /// Throws an exception if the index is out of bounds or the types do not diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp index b54b9554ad5af6..708c0e26cfddc9 100644 --- a/torch/csrc/api/src/nn/modules/batchnorm.cpp +++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp @@ -14,22 +14,22 @@ namespace nn { BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {} BatchNormImpl::BatchNormImpl(BatchNormOptions options) - : options_(std::move(options)) { + : options(std::move(options)) { reset(); } void BatchNormImpl::reset() { - if (options_.affine_) { - weight_ = register_parameter( - "weight", torch::empty({options_.features_}).uniform_()); - bias_ = register_parameter("bias", torch::zeros({options_.features_})); + if (options.affine_) { + weight = register_parameter( + "weight", torch::empty({options.features_}).uniform_()); + bias = register_parameter("bias", torch::zeros({options.features_})); } - if (options_.stateful_) { - running_mean_ = - register_buffer("running_mean", torch::zeros({options_.features_})); - running_variance_ = - register_buffer("running_variance", torch::ones({options_.features_})); + if (options.stateful_) { + running_mean = + register_buffer("running_mean", torch::zeros({options.features_})); + running_variance = + register_buffer("running_variance", torch::ones({options.features_})); } } @@ -38,8 +38,9 @@ Tensor BatchNormImpl::forward(Tensor input) { } Tensor BatchNormImpl::pure_forward(Tensor input, Tensor mean, Tensor variance) { - auto& running_mean = options_.stateful_ ? running_mean_ : mean; - auto& running_variance = options_.stateful_ ? running_variance_ : variance; + auto& running_mean = options.stateful_ ? this->running_mean : mean; + auto& running_variance = + options.stateful_ ? this->running_variance : variance; if (is_training()) { const auto num_channels = input.dim() > 1 ? input.size(1) : 1; @@ -50,19 +51,15 @@ Tensor BatchNormImpl::pure_forward(Tensor input, Tensor mean, Tensor variance) { return torch::batch_norm( input, - weight_, - bias_, + weight, + bias, running_mean, running_variance, is_training(), - options_.momentum_, - options_.eps_, + options.momentum_, + options.eps_, torch::cuda::cudnn_is_available()); } -const BatchNormOptions& BatchNormImpl::options() const noexcept { - return options_; -} - } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp index 9f86bd3ee858ab..85f7fe8a38fafa 100644 --- a/torch/csrc/api/src/nn/modules/conv.cpp +++ b/torch/csrc/api/src/nn/modules/conv.cpp @@ -21,43 +21,43 @@ ConvOptions::ConvOptions( template ConvImpl::ConvImpl(ConvOptions options) - : options_(std::move(options)) { + : options(std::move(options)) { reset(); } template void ConvImpl::reset() { - if (!options_.transposed_) { - for (auto pad : *options_.output_padding_) { + if (!options.transposed_) { + for (auto pad : *options.output_padding_) { AT_CHECK( pad == 0, "Only transposed convolutions support output padding!"); } } std::vector weights_size; - if (options_.transposed_) { - weights_size.push_back(options_.input_channels_); - weights_size.push_back(options_.output_channels_ / options_.groups_); + if (options.transposed_) { + weights_size.push_back(options.input_channels_); + weights_size.push_back(options.output_channels_ / options.groups_); } else { - weights_size.push_back(options_.output_channels_); - weights_size.push_back(options_.input_channels_ / options_.groups_); + weights_size.push_back(options.output_channels_); + weights_size.push_back(options.input_channels_ / options.groups_); } weights_size.insert( weights_size.end(), - options_.kernel_size_->begin(), - options_.kernel_size_->end()); - AT_ASSERT(weights_size.size() == 2 + options_.kernel_size_->size()); - - weight_ = this->register_parameter("weight", torch::empty(weights_size)); - if (options_.with_bias_) { - bias_ = this->register_parameter( - "bias", torch::empty(options_.output_channels_)); + options.kernel_size_->begin(), + options.kernel_size_->end()); + AT_ASSERT(weights_size.size() == 2 + options.kernel_size_->size()); + + weight = this->register_parameter("weight", torch::empty(weights_size)); + if (options.with_bias_) { + bias = this->register_parameter( + "bias", torch::empty(options.output_channels_)); } const auto number_of_features = std::accumulate( - options_.kernel_size_->begin(), - options_.kernel_size_->end(), - options_.input_channels_, + options.kernel_size_->begin(), + options.kernel_size_->end(), + options.input_channels_, std::multiplies{}); const auto stdv = 1.0 / std::sqrt(number_of_features); for (auto& p : this->parameters()) { @@ -65,81 +65,76 @@ void ConvImpl::reset() { } } -template -const ConvOptions& ConvImpl::options() const noexcept { - return options_; -} - Tensor Conv1dImpl::forward(Tensor input) { AT_ASSERT(input.ndimension() == 3); - if (options_.transposed_) { + if (options.transposed_) { return torch::conv_transpose1d( input, - weight_, - bias_, - options_.stride_, - options_.padding_, - options_.output_padding_, - options_.groups_, - options_.dilation_); + weight, + bias, + options.stride_, + options.padding_, + options.output_padding_, + options.groups_, + options.dilation_); } return torch::conv1d( input, - weight_, - bias_, - options_.stride_, - options_.padding_, - options_.dilation_, - options_.groups_); + weight, + bias, + options.stride_, + options.padding_, + options.dilation_, + options.groups_); } Tensor Conv2dImpl::forward(Tensor input) { AT_ASSERT(input.ndimension() == 4); - if (options_.transposed_) { + if (options.transposed_) { return torch::conv_transpose2d( input, - weight_, - bias_, - options_.stride_, - options_.padding_, - options_.output_padding_, - options_.groups_, - options_.dilation_); + weight, + bias, + options.stride_, + options.padding_, + options.output_padding_, + options.groups_, + options.dilation_); } return torch::conv2d( input, - weight_, - bias_, - options_.stride_, - options_.padding_, - options_.dilation_, - options_.groups_); + weight, + bias, + options.stride_, + options.padding_, + options.dilation_, + options.groups_); } Tensor Conv3dImpl::forward(Tensor input) { AT_ASSERT(input.ndimension() == 5); - if (options_.transposed_) { + if (options.transposed_) { return torch::conv_transpose3d( input, - weight_, - bias_, - options_.stride_, - options_.padding_, - options_.output_padding_, - options_.groups_, - options_.dilation_); + weight, + bias, + options.stride_, + options.padding_, + options.output_padding_, + options.groups_, + options.dilation_); } else { return torch::conv3d( input, - weight_, - bias_, - options_.stride_, - options_.padding_, - options_.dilation_, - options_.groups_); + weight, + bias, + options.stride_, + options.padding_, + options.dilation_, + options.groups_); } } diff --git a/torch/csrc/api/src/nn/modules/dropout.cpp b/torch/csrc/api/src/nn/modules/dropout.cpp index bbc19e3c107fe2..3a2e1b18617a35 100644 --- a/torch/csrc/api/src/nn/modules/dropout.cpp +++ b/torch/csrc/api/src/nn/modules/dropout.cpp @@ -11,10 +11,10 @@ namespace torch { namespace nn { namespace detail { template -DropoutImplBase::DropoutImplBase(DropoutOptions options) - : options_(options) { - AT_CHECK(options_.rate_ >= 0, "Dropout rate must not be less than zero"); - AT_CHECK(options_.rate_ <= 1, "Dropout rate must not be greater than one"); +DropoutImplBase::DropoutImplBase(DropoutOptions options_) + : options(options_) { + AT_CHECK(options.rate_ >= 0, "Dropout rate must not be less than zero"); + AT_CHECK(options.rate_ <= 1, "Dropout rate must not be greater than one"); } template @@ -22,22 +22,17 @@ void DropoutImplBase::reset() {} template Tensor DropoutImplBase::forward(Tensor input) { - if (options_.rate_ == 0 || !this->is_training()) { + if (options.rate_ == 0 || !this->is_training()) { return input; } - auto scale = 1.0f / (1.0f - options_.rate_); - auto boolean_mask = noise_mask(input).uniform_(0, 1) > options_.rate_; + auto scale = 1.0f / (1.0f - options.rate_); + auto boolean_mask = noise_mask(input).uniform_(0, 1) > options.rate_; auto noise = boolean_mask.to(input.dtype()).mul_(scale); return input * noise; } -template -const DropoutOptions& DropoutImplBase::options() const noexcept { - return options_; -} - template class DropoutImplBase; template class DropoutImplBase; } // namespace detail diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp index e3570ad02c0aa1..54f6f314e05311 100644 --- a/torch/csrc/api/src/nn/modules/embedding.cpp +++ b/torch/csrc/api/src/nn/modules/embedding.cpp @@ -13,23 +13,18 @@ EmbeddingOptions::EmbeddingOptions(int64_t count, int64_t dimension) : count_(count), dimension_(dimension) {} EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options) - : options_(std::move(options)) { + : options(std::move(options)) { reset(); } void EmbeddingImpl::reset() { - table_ = register_parameter( - "table", torch::empty({options_.count_, options_.dimension_})); - table_.data().normal_(0, 1); + table = register_parameter( + "table", torch::empty({options.count_, options.dimension_})); + table.data().normal_(0, 1); } Tensor EmbeddingImpl::forward(Tensor input) { - return torch::embedding(table_, /*indices=*/input); + return torch::embedding(table, /*indices=*/input); } - -const EmbeddingOptions& EmbeddingImpl::options() const noexcept { - return options_; -} - } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/linear.cpp b/torch/csrc/api/src/nn/modules/linear.cpp index fcda1ec01bd936..dcca7c062901dc 100644 --- a/torch/csrc/api/src/nn/modules/linear.cpp +++ b/torch/csrc/api/src/nn/modules/linear.cpp @@ -9,39 +9,35 @@ namespace torch { namespace nn { LinearOptions::LinearOptions(int64_t in, int64_t out) : in_(in), out_(out) {} -LinearImpl::LinearImpl(LinearOptions options) : options_(std::move(options)) { +LinearImpl::LinearImpl(LinearOptions options) : options(std::move(options)) { reset(); } void LinearImpl::reset() { - weight_ = - register_parameter("weight", torch::empty({options_.out_, options_.in_})); - if (options_.with_bias_) { - bias_ = register_parameter("bias", torch::empty(options_.out_)); + weight = + register_parameter("weight", torch::empty({options.out_, options.in_})); + if (options.with_bias_) { + bias = register_parameter("bias", torch::empty(options.out_)); } - const auto stdv = 1.0 / std::sqrt(weight_.size(1)); + const auto stdv = 1.0 / std::sqrt(weight.size(1)); for (auto& p : parameters()) { p->data().uniform_(-stdv, stdv); } } Tensor LinearImpl::forward(Tensor input) { - if (input.ndimension() == 2 && options_.with_bias_) { + if (input.ndimension() == 2 && options.with_bias_) { // Fused op is marginally faster - AT_ASSERT(input.size(1) == weight_.size(1)); - return {torch::addmm(bias_, input, weight_.t())}; + AT_ASSERT(input.size(1) == weight.size(1)); + return {torch::addmm(bias, input, weight.t())}; } - auto output = input.matmul(weight_.t()); - if (options_.with_bias_) { - output += bias_; + auto output = input.matmul(weight.t()); + if (options.with_bias_) { + output += bias; } return output; } - -const LinearOptions& LinearImpl::options() const noexcept { - return options_; -} } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp index ff710f0011c72d..095e942e06c90c 100644 --- a/torch/csrc/api/src/nn/modules/rnn.cpp +++ b/torch/csrc/api/src/nn/modules/rnn.cpp @@ -47,11 +47,11 @@ RNNOptionsBase::RNNOptionsBase(int64_t input_size, int64_t hidden_size) template RNNImplBase::RNNImplBase( - RNNOptionsBase options, + RNNOptionsBase options_, at::optional cudnn_mode, int64_t number_of_gates, bool has_cell_state) - : options_(options), + : options(options_), number_of_gates_(number_of_gates), has_cell_state_(has_cell_state), cudnn_mode_(cudnn_mode) { @@ -60,36 +60,36 @@ RNNImplBase::RNNImplBase( template void RNNImplBase::reset() { - if (options_.dropout_ > 0.0) { - dropout_module_ = Dropout(options_.dropout_); + if (options.dropout_ > 0.0) { + dropout = Dropout(options.dropout_); } - ihw_.resize(options_.layers_); - hhw_.resize(options_.layers_); - ihb_.resize(options_.layers_); - hhb_.resize(options_.layers_); + w_ih.resize(options.layers_); + w_hh.resize(options.layers_); + b_ih.resize(options.layers_); + b_hh.resize(options.layers_); - const int64_t gate_size = options_.hidden_size_ * number_of_gates_; + const int64_t gate_size = options.hidden_size_ * number_of_gates_; - for (int64_t layer = 0; layer < options_.layers_; ++layer) { + for (int64_t layer = 0; layer < options.layers_; ++layer) { const int64_t input_size = - (layer == 0) ? options_.input_size_ : options_.hidden_size_; - ihw_[layer] = this->register_parameter( + (layer == 0) ? options.input_size_ : options.hidden_size_; + w_ih[layer] = this->register_parameter( "weight_ih_l" + std::to_string(layer), torch::empty({gate_size, input_size})); - hhw_[layer] = this->register_parameter( + w_hh[layer] = this->register_parameter( "weight_hh_l" + std::to_string(layer), - torch::empty({gate_size, options_.hidden_size_})); + torch::empty({gate_size, options.hidden_size_})); - if (options_.with_bias_) { - ihb_[layer] = this->register_parameter( + if (options.with_bias_) { + b_ih[layer] = this->register_parameter( "bias_ih_l" + std::to_string(layer), torch::empty({gate_size})); - hhb_[layer] = this->register_parameter( + b_hh[layer] = this->register_parameter( "bias_hh_l" + std::to_string(layer), torch::empty({gate_size})); } } - const auto stdv = 1.0 / std::sqrt(options_.hidden_size_); + const auto stdv = 1.0 / std::sqrt(options.hidden_size_); for (auto& p : this->parameters()) { p->data().uniform_(-stdv, stdv); } @@ -107,12 +107,12 @@ RNNOutput RNNImplBase::forward(Tensor input, Tensor state) { template std::vector RNNImplBase::flat_weights() const { std::vector flat; - for (int64_t layer = 0; layer < options_.layers_; layer++) { - flat.push_back(ihw_[layer]); - flat.push_back(hhw_[layer]); - if (options_.with_bias_) { - flat.push_back(ihb_[layer]); - flat.push_back(hhb_[layer]); + for (int64_t layer = 0; layer < options.layers_; layer++) { + flat.push_back(w_ih[layer]); + flat.push_back(w_hh[layer]); + if (options.with_bias_) { + flat.push_back(b_ih[layer]); + flat.push_back(b_hh[layer]); } } return flat; @@ -128,11 +128,11 @@ template Tensor RNNImplBase::create_dropout_state(Tensor input) const { static const int64_t dropout_seed = torch::ones({}, torch::kInt64).random_().toCLong(); - if (options_.dropout_ > 0) { + if (options.dropout_ > 0) { torch::DeviceGuard guard(input.device()); return torch::_cudnn_init_dropout_state( input.type().toScalarType(torch::kUInt8), - options_.dropout_, + options.dropout_, this->is_training(), dropout_seed); } @@ -144,16 +144,16 @@ RNNOutput RNNImplBase::autograd_forward(Tensor input, Tensor state) { std::vector new_state; auto has_hidden = state.defined(); auto layer_dimension = has_hidden ? state.ndimension() - 3 : -1; - for (int64_t layer = 0; layer < options_.layers_; layer++) { + for (int64_t layer = 0; layer < options.layers_; layer++) { new_state.push_back( has_hidden ? state.select(layer_dimension, layer) : Tensor()); } auto output = torch::zeros( - {input.size(0), input.size(1), options_.hidden_size_}, input.options()); + {input.size(0), input.size(1), options.hidden_size_}, input.options()); for (int64_t t = 0; t < input.size(0); t++) { auto x = input.select(0, t); - for (int64_t i = 0; i < options_.layers_; i++) { + for (int64_t i = 0; i < options.layers_; i++) { // cell_forward() returns a stacked tensor of one or more cell states. auto layer_output = cell_forward(x, new_state[i], i); // If there are multiple cell states, keep all. If there is only one, @@ -162,8 +162,8 @@ RNNOutput RNNImplBase::autograd_forward(Tensor input, Tensor state) { // x should always be the hidden cell state h, assumed to be the zero-th. x = layer_output[0]; output.select(0, t).copy_(x); - if (options_.dropout_ > 0 && i != options_.layers_ - 1) { - x = dropout_module_->forward(x); + if (options.dropout_ > 0 && i != options.layers_ - 1) { + x = dropout->forward(x); } } } @@ -178,8 +178,8 @@ RNNOutput RNNImplBase::autograd_forward(Tensor input, Tensor state) { template void RNNImplBase::flatten_parameters_for_cudnn() { data_ptrs_.clear(); - const auto any_parameter = ihw_.at(0); - if (!use_cudnn(/*sample=*/ihw_.at(0))) { + const auto any_parameter = w_ih.at(0); + if (!use_cudnn(/*sample=*/w_ih.at(0))) { return; } std::unordered_set unique_data_ptrs; @@ -200,11 +200,11 @@ void RNNImplBase::flatten_parameters_for_cudnn() { NoGradGuard guard; flat_weights_ = torch::_cudnn_rnn_flatten_weight( TensorListView(flat_weights()), - /*weight_stride=*/options_.with_bias_ ? 4 : 2, - options_.input_size_, + /*weight_stride=*/options.with_bias_ ? 4 : 2, + options.input_size_, static_cast(*cudnn_mode_), - options_.hidden_size_, - options_.layers_, + options.hidden_size_, + options.layers_, /*batch_first=*/false, /*bidirectional=*/false); } @@ -225,11 +225,11 @@ RNNOutput RNNImplBase::CUDNN_forward(Tensor input, Tensor state) { } } else { hx = torch::zeros( - {options_.layers_, input.size(1), options_.hidden_size_}, + {options.layers_, input.size(1), options.hidden_size_}, input.options()); if (has_cell_state_) { cx = torch::zeros( - {options_.layers_, input.size(1), options_.hidden_size_}, + {options.layers_, input.size(1), options.hidden_size_}, input.options()); } } @@ -248,15 +248,15 @@ RNNOutput RNNImplBase::CUDNN_forward(Tensor input, Tensor state) { auto cudnn_output = torch::_cudnn_rnn( /*input=*/input, /*weight=*/TensorListView(flat_weights()), - /*weight_stride0=*/options_.with_bias_ ? 4 : 2, + /*weight_stride0=*/options.with_bias_ ? 4 : 2, /*weight_buf=*/flat_weights_, /*hx=*/hx, /*cx=*/cx, /*mode=*/static_cast(*cudnn_mode_), - /*hidden_size=*/options_.hidden_size_, - /*num_layers=*/options_.layers_, + /*hidden_size=*/options.hidden_size_, + /*num_layers=*/options.layers_, /*batch_first=*/false, - /*dropout=*/options_.dropout_, + /*dropout=*/options.dropout_, /*train=*/this->is_training(), /*bidirectional=*/false, /*batch_sizes=*/{}, @@ -318,8 +318,8 @@ RNNImpl::RNNImpl(RNNOptions options) .with_bias(options.with_bias_) .dropout(options.dropout_), /*cudnn_mode=*/static_cast(options.activation_)), - options_(options) { - switch (options_.activation_) { + options(options) { + switch (options.activation_) { case RNNActivation::ReLU: { activation_function_ = torch::relu; break; @@ -334,18 +334,14 @@ RNNImpl::RNNImpl(RNNOptions options) Tensor RNNImpl::cell_forward(Tensor input, Tensor state, int64_t layer) { auto hx = state.defined() ? state - : torch::zeros({input.size(0), options_.hidden_size_}, input.options()); + : torch::zeros({input.size(0), options.hidden_size_}, input.options()); - auto h = linear(input, ihw_[layer], ihb_[layer]) + - linear(hx, hhw_[layer], hhb_[layer]); + auto h = linear(input, w_ih[layer], b_ih[layer]) + + linear(hx, w_hh[layer], b_hh[layer]); return torch::stack(activation_function_(h)); } -const RNNOptions& RNNImpl::options() const noexcept { - return options_; -} - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMImpl::LSTMImpl(LSTMOptions options) @@ -358,13 +354,12 @@ LSTMImpl::LSTMImpl(LSTMOptions options) Tensor LSTMImpl::cell_forward(Tensor input, Tensor state, int64_t layer) { auto hid = state.defined() ? state - : torch::zeros( - {2, input.size(0), options_.hidden_size_}, input.options()); + : torch::zeros({2, input.size(0), options.hidden_size_}, input.options()); auto hx = hid[0]; auto cx = hid[1]; - auto gates = linear(input, ihw_[layer], ihb_[layer]) + - linear(hx, hhw_[layer], hhb_[layer]); + auto gates = linear(input, w_ih[layer], b_ih[layer]) + + linear(hx, w_hh[layer], b_hh[layer]); auto chunked = gates.chunk(4, 1); auto in_gate = chunked[0].sigmoid(); @@ -378,10 +373,6 @@ Tensor LSTMImpl::cell_forward(Tensor input, Tensor state, int64_t layer) { return torch::stack(TensorListView{hy, cy}, 0); } -const LSTMOptions& LSTMImpl::options() const noexcept { - return options_; -} - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUImpl::GRUImpl(GRUOptions options) @@ -393,10 +384,10 @@ GRUImpl::GRUImpl(GRUOptions options) Tensor GRUImpl::cell_forward(Tensor input, Tensor state, int64_t layer) { auto hx = state.defined() ? state - : torch::zeros({input.size(0), options_.hidden_size_}, input.options()); + : torch::zeros({input.size(0), options.hidden_size_}, input.options()); - auto gi = linear(input, ihw_[layer], ihb_[layer]); - auto gh = linear(input, hhw_[layer], hhb_[layer]); + auto gi = linear(input, w_ih[layer], b_ih[layer]); + auto gh = linear(input, w_hh[layer], b_hh[layer]); auto gic = gi.chunk(3, 1); auto ghc = gh.chunk(3, 1); @@ -407,9 +398,5 @@ Tensor GRUImpl::cell_forward(Tensor input, Tensor state, int64_t layer) { return torch::stack(TensorListView(hy)); } - -const GRUOptions& GRUImpl::options() const noexcept { - return options_; -} } // namespace nn } // namespace torch diff --git a/torch/csrc/cuda/Storage.cpp b/torch/csrc/cuda/Storage.cpp index 4e3d1b4ef3dd83..c552af4f11e837 100644 --- a/torch/csrc/cuda/Storage.cpp +++ b/torch/csrc/cuda/Storage.cpp @@ -17,9 +17,3 @@ #define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp" #include - -template<> -void THPPointer::free() { - if (ptr) - THCStorage_free(LIBRARY_STATE ptr); -} diff --git a/torch/csrc/generic/StorageSharing.cpp b/torch/csrc/generic/StorageSharing.cpp index 0242534ab3679e..5284207988cf07 100644 --- a/torch/csrc/generic/StorageSharing.cpp +++ b/torch/csrc/generic/StorageSharing.cpp @@ -11,9 +11,9 @@ static PyObject * THPStorage_(sharedDecref)(THPStorage *self) #ifndef THC_GENERIC_FILE libshm_context *ctx = NULL; THWStorage *storage = self->cdata; - if (storage->allocator == &THManagedSharedAllocator) { + if (storage->allocatorVoidPtr == &THManagedSharedAllocator) { ctx = (libshm_context*)storage->allocatorContext; - } else if (storage->allocator == &THStorageWeakRefAllocator) { + } else if (storage->allocatorVoidPtr == &THStorageWeakRefAllocator) { auto allocator_obj = ((StorageWeakRefAllocator*)storage->allocatorContext); if (allocator_obj->allocator == &THManagedSharedAllocator) ctx = (libshm_context*)allocator_obj->allocatorContext; @@ -32,9 +32,9 @@ static PyObject * THPStorage_(sharedIncref)(THPStorage *self) #ifndef THC_GENERIC_FILE libshm_context *ctx = NULL; THWStorage *storage = self->cdata; - if (storage->allocator == &THManagedSharedAllocator) { + if (storage->allocatorVoidPtr == &THManagedSharedAllocator) { ctx = (libshm_context*)storage->allocatorContext; - } else if (storage->allocator == &THStorageWeakRefAllocator) { + } else if (storage->allocatorVoidPtr == &THStorageWeakRefAllocator) { auto allocator_obj = ((StorageWeakRefAllocator*)storage->allocatorContext); if (allocator_obj->allocator == &THManagedSharedAllocator) ctx = (libshm_context*)allocator_obj->allocatorContext; @@ -96,9 +96,9 @@ static PyObject * THPStorage_(shareFilename)(THPStorage *self) THWStorage *storage = self->cdata; libshm_context *ctx; // Storage is already in shared memory, just return a handle - if (storage->allocator == &THManagedSharedAllocator) { + if (storage->allocatorVoidPtr == &THManagedSharedAllocator) { ctx = (libshm_context*)storage->allocatorContext; - } else if (storage->allocator == &THStorageWeakRefAllocator) { + } else if (storage->allocatorVoidPtr == &THStorageWeakRefAllocator) { auto allocator_obj = ((StorageWeakRefAllocator*)storage->allocatorContext); ctx = (libshm_context*)allocator_obj->allocatorContext; } else { @@ -178,9 +178,9 @@ static PyObject * THPStorage_(shareFd)(THPStorage *self) THWStorage *storage = self->cdata; THMapAllocatorContext *ctx; // Storage is already in shared memory, just return a handle - if (storage->allocator == &THMapAllocator) { + if (storage->allocatorVoidPtr == &THMapAllocator) { ctx = (THMapAllocatorContext*)storage->allocatorContext; - } else if (storage->allocator == &THStorageWeakRefAllocator) { + } else if (storage->allocatorVoidPtr == &THStorageWeakRefAllocator) { auto allocator_obj = ((StorageWeakRefAllocator*)storage->allocatorContext); ctx = (THMapAllocatorContext*)allocator_obj->allocatorContext; } else { @@ -330,9 +330,9 @@ static PyObject * THPStorage_(weakRef)(THPStorage *self, PyObject *weak_ref_clas } bool hasWeakAllocator; #ifdef THC_GENERIC_FILE - hasWeakAllocator = storage->allocator == &THCStorageWeakRefAllocator; + hasWeakAllocator = storage->allocatorVoidPtr == &THCStorageWeakRefAllocator; #else - hasWeakAllocator = storage->allocator == &THStorageWeakRefAllocator; + hasWeakAllocator = storage->allocatorVoidPtr == &THStorageWeakRefAllocator; #endif if (hasWeakAllocator) { auto allocator_obj = ((StorageWeakRefAllocator*)storage->allocatorContext); @@ -346,12 +346,12 @@ static PyObject * THPStorage_(weakRef)(THPStorage *self, PyObject *weak_ref_clas if (!ref) return NULL; #ifdef THC_GENERIC_FILE storage->allocatorContext = new CudaStorageWeakRefAllocator( - ref.get(), storage->allocator, storage->allocatorContext); - storage->allocator = &THCStorageWeakRefAllocator; + ref.get(), static_cast(storage->allocatorVoidPtr), storage->allocatorContext); + storage->allocatorVoidPtr = &THCStorageWeakRefAllocator; #else storage->allocatorContext = new StorageWeakRefAllocator( - ref.get(), storage->allocator, storage->allocatorContext); - storage->allocator = &THStorageWeakRefAllocator; + ref.get(), static_cast(storage->allocatorVoidPtr), storage->allocatorContext); + storage->allocatorVoidPtr = &THStorageWeakRefAllocator; #endif return ref.release(); END_HANDLE_TH_ERRORS @@ -382,9 +382,9 @@ PyObject * THPStorage_(sharedFd)(THPStorage *self) THMapAllocatorContext *ctx = NULL; #ifndef THC_GENERIC_FILE THWStorage *storage = self->cdata; - if (storage->allocator == &THMapAllocator) { + if (storage->allocatorVoidPtr == &THMapAllocator) { ctx = (THMapAllocatorContext*)storage->allocatorContext; - } else if (storage->allocator == &THStorageWeakRefAllocator) { + } else if (storage->allocatorVoidPtr == &THStorageWeakRefAllocator) { auto allocator_obj = ((StorageWeakRefAllocator*)storage->allocatorContext); if (allocator_obj->allocator == &THMapAllocator) { ctx = (THMapAllocatorContext*)allocator_obj->allocatorContext; @@ -416,7 +416,7 @@ PyObject * THPStorage_(isShared)(THPStorage *self) #ifdef THC_GENERIC_FILE Py_RETURN_TRUE; #else - void *allocator = self->cdata->allocator; + void *allocator = self->cdata->allocatorVoidPtr; if (allocator == &THMapAllocator || allocator == &THStorageWeakRefAllocator || allocator == &THManagedSharedAllocator) { diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 6852a71df9f349..9b241f6eba5bde 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -6,7 +6,7 @@ from torch.distributions import constraints from torch.distributions.utils import (_sum_rightmost, broadcast_all, lazy_property) -from torch.nn.functional import pad, sigmoid +from torch.nn.functional import pad __all__ = [ 'AbsTransform', @@ -341,7 +341,7 @@ def __eq__(self, other): return isinstance(other, SigmoidTransform) def _call(self, x): - return sigmoid(x) + return torch.sigmoid(x) def _inverse(self, y): return y.log() - (-y).log1p() @@ -483,7 +483,7 @@ def __eq__(self, other): def _call(self, x): offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1) - z = sigmoid(x - offset.log()) + z = torch.sigmoid(x - offset.log()) z_cumprod = (1 - z).cumprod(-1) y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1) return y @@ -497,7 +497,7 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y): offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1) - z = sigmoid(x - offset.log()) + z = torch.sigmoid(x - offset.log()) detJ = ((1 - z).log() + y[..., :-1].log()).sum(-1) return detJ diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index fab614578849b5..f0f1c17b596ba0 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -116,7 +116,7 @@ def logits_to_probs(logits, is_binary=False): the log probabilities (possibly unnormalized) of the events. """ if is_binary: - return F.sigmoid(logits) + return torch.sigmoid(logits) return F.softmax(logits, dim=-1) diff --git a/torch/legacy/nn/AbsCriterion.py b/torch/legacy/nn/AbsCriterion.py index 4a2ea6896dd9fc..66f7615205d187 100644 --- a/torch/legacy/nn/AbsCriterion.py +++ b/torch/legacy/nn/AbsCriterion.py @@ -18,7 +18,7 @@ def updateOutput(self, input, target): input, target, self.output_tensor, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) self.output = self.output_tensor[0].item() return self.output @@ -31,6 +31,6 @@ def updateGradInput(self, input, target): target, implicit_gradOutput, self.gradInput, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) return self.gradInput diff --git a/torch/legacy/nn/ClassNLLCriterion.py b/torch/legacy/nn/ClassNLLCriterion.py index 50ddcfd28e90f5..33c28e5d21cb96 100644 --- a/torch/legacy/nn/ClassNLLCriterion.py +++ b/torch/legacy/nn/ClassNLLCriterion.py @@ -25,7 +25,7 @@ def updateOutput(self, input, target): input, target, self.output_tensor, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), self.weights, self.total_weight_tensor, self.ignore_index, @@ -44,7 +44,7 @@ def updateGradInput(self, input, target): target, implicit_gradOutput, self.gradInput, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), self.weights, self.total_weight_tensor, self.ignore_index, diff --git a/torch/legacy/nn/ClassSimplexCriterion.py b/torch/legacy/nn/ClassSimplexCriterion.py index b28ce672969b23..1de585147347e1 100644 --- a/torch/legacy/nn/ClassSimplexCriterion.py +++ b/torch/legacy/nn/ClassSimplexCriterion.py @@ -81,7 +81,7 @@ def updateOutput(self, input, target): input, self._target, self.output_tensor, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) self.output = self.output_tensor[0].item() return self.output @@ -95,7 +95,7 @@ def updateGradInput(self, input, target): self._target, implicit_gradOutput, self.gradInput, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) return self.gradInput diff --git a/torch/legacy/nn/DistKLDivCriterion.py b/torch/legacy/nn/DistKLDivCriterion.py index 8c18cf1e2d74a9..5aa175604a05ff 100644 --- a/torch/legacy/nn/DistKLDivCriterion.py +++ b/torch/legacy/nn/DistKLDivCriterion.py @@ -19,7 +19,7 @@ def updateOutput(self, input, target): input, target, self.output_tensor, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) self.output = self.output_tensor[0].item() return self.output @@ -33,6 +33,6 @@ def updateGradInput(self, input, target): target, implicit_gradOutput, self.gradInput, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) return self.gradInput diff --git a/torch/legacy/nn/MSECriterion.py b/torch/legacy/nn/MSECriterion.py index 2422e07e4d87d3..2079d366c2ce6c 100644 --- a/torch/legacy/nn/MSECriterion.py +++ b/torch/legacy/nn/MSECriterion.py @@ -18,7 +18,7 @@ def updateOutput(self, input, target): input, target, self.output_tensor, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) self.output = self.output_tensor[0].item() return self.output @@ -32,6 +32,6 @@ def updateGradInput(self, input, target): target, implicit_gradOutput, self.gradInput, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) return self.gradInput diff --git a/torch/legacy/nn/MultiLabelMarginCriterion.py b/torch/legacy/nn/MultiLabelMarginCriterion.py index 1de12bfeabf015..9ca2a233efdf99 100644 --- a/torch/legacy/nn/MultiLabelMarginCriterion.py +++ b/torch/legacy/nn/MultiLabelMarginCriterion.py @@ -21,7 +21,7 @@ def updateOutput(self, input, target): target, self.output_tensor, self.isTarget, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) self.output = self.output_tensor[0].item() return self.output @@ -36,6 +36,6 @@ def updateGradInput(self, input, target): implicit_gradOutput, self.gradInput, self.isTarget, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) return self.gradInput diff --git a/torch/legacy/nn/MultiMarginCriterion.py b/torch/legacy/nn/MultiMarginCriterion.py index 26b9cff8dcfc45..cc9835c3395f99 100644 --- a/torch/legacy/nn/MultiMarginCriterion.py +++ b/torch/legacy/nn/MultiMarginCriterion.py @@ -26,7 +26,7 @@ def updateOutput(self, input, target): input, target, self.output_tensor, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), self.p, self.weights, self.margin, @@ -43,7 +43,7 @@ def updateGradInput(self, input, target): target, implicit_gradOutput, self.gradInput, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), self.p, self.weights, self.margin, diff --git a/torch/legacy/nn/SmoothL1Criterion.py b/torch/legacy/nn/SmoothL1Criterion.py index c02e7a2b85f7df..714d0b6ed0fe0b 100644 --- a/torch/legacy/nn/SmoothL1Criterion.py +++ b/torch/legacy/nn/SmoothL1Criterion.py @@ -18,7 +18,7 @@ def updateOutput(self, input, target): input, target, self.output_tensor, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) self.output = self.output_tensor[0].item() return self.output @@ -31,6 +31,6 @@ def updateGradInput(self, input, target): target, implicit_gradOutput, self.gradInput, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) return self.gradInput diff --git a/torch/legacy/nn/SoftMarginCriterion.py b/torch/legacy/nn/SoftMarginCriterion.py index e56d8716ce66fe..4bfa37173ce013 100644 --- a/torch/legacy/nn/SoftMarginCriterion.py +++ b/torch/legacy/nn/SoftMarginCriterion.py @@ -18,7 +18,7 @@ def updateOutput(self, input, target): input, target, self.output_tensor, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) self.output = self.output_tensor[0].item() return self.output @@ -31,6 +31,6 @@ def updateGradInput(self, input, target): target, implicit_gradOutput, self.gradInput, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) return self.gradInput diff --git a/torch/legacy/nn/SpatialClassNLLCriterion.py b/torch/legacy/nn/SpatialClassNLLCriterion.py index 382cfea12defc8..8a7e15c8298149 100644 --- a/torch/legacy/nn/SpatialClassNLLCriterion.py +++ b/torch/legacy/nn/SpatialClassNLLCriterion.py @@ -23,7 +23,7 @@ def updateOutput(self, input, target): input, target, self.output_tensor, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), self.weights, self.total_weight_tensor, self.ignore_index, @@ -40,7 +40,7 @@ def updateGradInput(self, input, target): target, implicit_gradOutput, self.gradInput, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), self.weights, self.total_weight_tensor, self.ignore_index, diff --git a/torch/legacy/nn/WeightedMSECriterion.py b/torch/legacy/nn/WeightedMSECriterion.py index 4e034395114e1c..2f0da29077d508 100644 --- a/torch/legacy/nn/WeightedMSECriterion.py +++ b/torch/legacy/nn/WeightedMSECriterion.py @@ -29,7 +29,7 @@ def updateOutput(self, input, target): input, self.buffer, self.output_tensor, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) self.output = self.output_tensor[0].item() return self.output @@ -50,6 +50,6 @@ def updateGradInput(self, input, target): self.buffer, implicit_gradOutput, self.gradInput, - _Reduction.legacy_get_enum(self.sizeAverage, True), + _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False), ) return self.gradInput diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index edbf941683c037..9be0a88e2221e5 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -2,6 +2,7 @@ #include "private/CUDAUtils.hpp" #include +#include #include #include diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py index 3c1002c8373a89..c7f5d10ccd4df3 100644 --- a/torch/nn/_functions/rnn.py +++ b/torch/nn/_functions/rnn.py @@ -2,6 +2,7 @@ from torch.autograd import NestedIOFunction import torch.backends.cudnn as cudnn from .. import functional as F +import torch from .thnn import rnnFusedPointwise as fusedBackend import itertools from functools import partial @@ -18,7 +19,7 @@ def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): - hy = F.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh)) + hy = torch.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh)) return hy @@ -34,13 +35,13 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) - ingate = F.sigmoid(ingate) - forgetgate = F.sigmoid(forgetgate) - cellgate = F.tanh(cellgate) - outgate = F.sigmoid(outgate) + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) cy = (forgetgate * cx) + (ingate * cellgate) - hy = outgate * F.tanh(cy) + hy = outgate * torch.tanh(cy) return hy, cy @@ -58,9 +59,9 @@ def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): i_r, i_i, i_n = gi.chunk(3, 1) h_r, h_i, h_n = gh.chunk(3, 1) - resetgate = F.sigmoid(i_r + h_r) - inputgate = F.sigmoid(i_i + h_i) - newgate = F.tanh(i_n + resetgate * h_n) + resetgate = torch.sigmoid(i_r + h_r) + inputgate = torch.sigmoid(i_i + h_i) + newgate = torch.tanh(i_n + resetgate * h_n) hy = newgate + inputgate * (hidden - newgate) return hy diff --git a/torch/nn/functional.py b/torch/nn/functional.py index dae0c38a522f3d..a6c00b160f8857 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -31,8 +31,10 @@ def get_enum(reduction): # In order to support previous versions, accept boolean size_average and reduce # and convert them into the new constants for now + + # We use these functions in torch/legacy as well, in which case we'll silence the warning @staticmethod - def legacy_get_string(size_average, reduce): + def legacy_get_string(size_average, reduce, emit_warning=True): warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." if size_average is None: @@ -41,18 +43,18 @@ def legacy_get_string(size_average, reduce): reduce = True if size_average and reduce: - warnings.warn(warning.format('elementwise_mean')) - return 'elementwise_mean' + ret = 'elementwise_mean' elif reduce: - warnings.warn(warning.format('sum')) - return 'sum' + ret = 'sum' else: - warnings.warn(warning.format('none')) - return 'none' + ret = 'none' + if emit_warning: + warnings.warn(warning.format(ret)) + return ret @staticmethod - def legacy_get_enum(size_average, reduce): - return _Reduction.get_enum(_Reduction.legacy_get_string(size_average, reduce)) + def legacy_get_enum(size_average, reduce, emit_warning=True): + return _Reduction.get_enum(_Reduction.legacy_get_string(size_average, reduce, emit_warning)) conv1d = _add_docstr(torch.conv1d, r""" @@ -1009,6 +1011,7 @@ def tanh(input): See :class:`~torch.nn.Tanh` for more details. """ + warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.") return input.tanh() @@ -1019,11 +1022,10 @@ def sigmoid(input): See :class:`~torch.nn.Sigmoid` for more details. """ + warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.") return input.sigmoid() -# etc. - def linear(input, weight, bias=None): r""" Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. diff --git a/torch/utils/model_zoo.py b/torch/utils/model_zoo.py index c4c5526aae3028..dfeb83e6d80ff4 100644 --- a/torch/utils/model_zoo.py +++ b/torch/utils/model_zoo.py @@ -81,21 +81,24 @@ def _download_url_to_file(url, dst, hash_prefix, progress): f = tempfile.NamedTemporaryFile(delete=False) try: - sha256 = hashlib.sha256() + if hash_prefix is not None: + sha256 = hashlib.sha256() with tqdm(total=file_size, disable=not progress) as pbar: while True: buffer = u.read(8192) if len(buffer) == 0: break f.write(buffer) - sha256.update(buffer) + if hash_prefix is not None: + sha256.update(buffer) pbar.update(len(buffer)) f.close() - digest = sha256.hexdigest() - if digest[:len(hash_prefix)] != hash_prefix: - raise RuntimeError('invalid hash value (expected "{}", got "{}")' - .format(hash_prefix, digest)) + if hash_prefix is not None: + digest = sha256.hexdigest() + if digest[:len(hash_prefix)] != hash_prefix: + raise RuntimeError('invalid hash value (expected "{}", got "{}")' + .format(hash_prefix, digest)) shutil.move(f.name, dst) finally: f.close()