Skip to content

Back out "[pytorch][PR] Revert "Move CreateContext to global registry (#11688)"" #12121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions aten/src/ATen/core/TensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,10 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
return storage_.device_type();
}

/**
 * Returns the device of this tensor, as reported by its underlying storage.
 */
at::Device GetDevice() const {
return storage_.device();
}

/**
* The static context of a tensor intuitively represents the device
* type of a tensor; e.g., a CPU tensor is associated with the
Expand All @@ -376,18 +380,6 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
return ::caffe2::get_static_context(device_type());
}

/* @brief
* Create a context that has the same device_type
* as the tensor.
* Note that this doesn't support passing in argument
* TODO(jerryzh): move this to a global registry
* that can create context for us, and then eliminate
* this method.
*/
std::unique_ptr<at::BaseContext> CreateContext() const {
return GetStaticContext()->CreateContext();
}

/**
* @brief Copies the data from a source tensor, with a context provided to
* carry out the underlying memcpy operation. This method respects
Expand Down Expand Up @@ -429,8 +421,12 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
// knows how to copy between CPU and that context
if (src.device_type() != ::at::DeviceType::CPU || device_type() == ::at::DeviceType::CPU) {
if (!context) {
src.CreateContext()->CopyBytesToDevice(
numel() * itemsize(), src.data(), raw_mutable_data(data_type_), device_type());
CreateContext(src.GetDevice())
->CopyBytesToDevice(
numel() * itemsize(),
src.data(),
raw_mutable_data(data_type_),
device_type());
} else {
CAFFE_ENFORCE(
context->device_type() == src.device_type(),
Expand All @@ -442,8 +438,11 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
// In case source context is CPU, and target context is non-CPU
// We'll have to create a Context from target and perform the
// copy using that context
CreateContext()->CopyBytesFromCPU(
numel() * itemsize(), src.data(), raw_mutable_data(data_type_));
CreateContext(GetDevice())
->CopyBytesFromCPU(
numel() * itemsize(),
src.data(),
raw_mutable_data(data_type_));
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions aten/src/ATen/core/context_base.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
#include <ATen/core/context_base.h>

namespace at {

// Definition of ContextRegistry (declared in context_base.h): a typed registry
// keyed by at::DeviceType whose entries create an at::BaseContext, returned as
// std::unique_ptr<at::BaseContext>, from a single at::Device argument.
C10_DEFINE_TYPED_REGISTRY(
ContextRegistry,
at::DeviceType,
at::BaseContext,
std::unique_ptr,
at::Device);

} // namespace at

namespace caffe2 {

// TODO: rename context.h -> context_cpu.h & context_base.h -> context.h
Expand Down
26 changes: 19 additions & 7 deletions aten/src/ATen/core/context_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
#include <memory>
#include <unordered_map>

#include <ATen/core/DeviceType.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/Device.h>
#include <ATen/core/Error.h>
#include <ATen/core/UniqueVoidPtr.h>
#include <ATen/core/typeid.h>
#include <ATen/core/ATenGeneral.h>
#include <c10/util/Registry.h>

namespace caffe2 {
class Event;
Expand All @@ -31,11 +32,6 @@ class CAFFE2_API BaseStaticContext {

virtual std::pair<void*, DeleterFnPtr> New(size_t nbytes) const = 0;

virtual std::unique_ptr<BaseContext> CreateContext() = 0;

virtual std::unique_ptr<BaseContext> CreateContext(
const caffe2::DeviceOption&) = 0;

virtual DeviceType GetDeviceType() = 0;

/*
Expand Down Expand Up @@ -184,6 +180,22 @@ class CAFFE2_API BaseContext {
}
};

// Context constructor registry.
// Declares ContextRegistry, keyed by at::DeviceType. Each registered entry
// builds an at::BaseContext (handed back as std::unique_ptr) from an
// at::Device. The matching C10_DEFINE_TYPED_REGISTRY is in context_base.cpp.
C10_DECLARE_TYPED_REGISTRY(
ContextRegistry,
at::DeviceType,
at::BaseContext,
std::unique_ptr,
at::Device);

// Registers a context class in ContextRegistry under the device type `type`.
// Forwards to C10_REGISTER_TYPED_CLASS; used inside namespace at, e.g.
//   REGISTER_CONTEXT(DeviceType::CPU, caffe2::CPUContext);
#define REGISTER_CONTEXT(type, ...) \
C10_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__)

/**
 * Creates a context for `device` by looking up the entry registered in
 * ContextRegistry under device.type() and invoking it with `device`.
 *
 * NOTE(review): presumably yields a null pointer when no context has been
 * registered for that device type -- confirm against the c10 registry's
 * Create() before dereferencing the result unconditionally.
 */
inline std::unique_ptr<at::BaseContext> CreateContext(
const at::Device& device) {
return at::ContextRegistry()->Create(device.type(), device);
}

} // namespace at

namespace caffe2 {
Expand Down
5 changes: 2 additions & 3 deletions caffe2/core/blob_serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ void TensorSerializer::Serialize(
const TensorProto::DataType data_type = TypeMetaToDataType(input.meta());
proto.set_data_type(data_type);
StoreDeviceDetail(input, &proto);
auto uniq_ptr = input.GetStaticContext()->CreateContext();
auto uniq_ptr = CreateContext(input.GetDevice());
// A lot of copypaste is error prone. Should we create a macro for this?
switch (data_type) {
case TensorProto_DataType_FLOAT:
Expand Down Expand Up @@ -371,8 +371,7 @@ void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
// We create a local context for deserializing. Since Caffe2 contexts are
// usually lightweight, this should not involve too much overhead.
auto uniq_ptr =
tensor->GetStaticContext()->CreateContext(proto.device_detail());
auto uniq_ptr = CreateContext(OptionToDevice(proto.device_detail()));
auto context = uniq_ptr.get();
context->SwitchToDevice(0);
vector<int64_t> dims;
Expand Down
4 changes: 4 additions & 0 deletions caffe2/core/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#include <process.h>
#endif

namespace at {

// Registers caffe2::CPUContext in ContextRegistry under DeviceType::CPU.
REGISTER_CONTEXT(DeviceType::CPU, caffe2::CPUContext);
} // namespace at
namespace caffe2 {

uint32_t RandomNumberSeed() {
Expand Down
11 changes: 2 additions & 9 deletions caffe2/core/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class CAFFE2_API CPUContext final : public BaseContext {
: RandomNumberSeed()) {
CAFFE_ENFORCE_EQ(option.device_type(), PROTO_CPU);
}
// Constructs the context from an at::Device: the device is converted with
// DeviceToOption and construction is delegated to another CPUContext
// constructor taking that result.
explicit CPUContext(const at::Device& device)
: CPUContext(DeviceToOption(device)) {}

~CPUContext() noexcept override {}

Expand Down Expand Up @@ -192,15 +194,6 @@ class CAFFE2_API CPUStaticContext : public BaseStaticContext {
return data_and_deleter;
}

std::unique_ptr<BaseContext> CreateContext() override {
return caffe2::make_unique<CPUContext>();
}

std::unique_ptr<BaseContext> CreateContext(
const DeviceOption& option) override {
return caffe2::make_unique<CPUContext>(option);
}

DeviceType GetDeviceType() override {
return CPU;
}
Expand Down
1 change: 1 addition & 0 deletions caffe2/core/context_base.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "context_base.h"

namespace caffe2 {

} // namespace caffe2
5 changes: 5 additions & 0 deletions caffe2/core/context_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ CAFFE2_DEFINE_int(
128,
"The threshold in MB on how frequently to report memory changes");

namespace at {

// Registers caffe2::CUDAContext in ContextRegistry under DeviceType::CUDA.
REGISTER_CONTEXT(DeviceType::CUDA, caffe2::CUDAContext);
} // namespace at

namespace caffe2 {

ThreadLocalCUDAObjects& CUDAContext::getCudaObjects() {
Expand Down
15 changes: 2 additions & 13 deletions caffe2/core/context_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
// The default cuda context constructor.
explicit CUDAContext(const int gpu_id = -1);
explicit CUDAContext(const DeviceOption& option);
// Constructs the context from an at::Device: the device is converted with
// DeviceToOption and construction is delegated to another CUDAContext
// constructor taking that result.
explicit CUDAContext(const at::Device& device)
: CUDAContext(DeviceToOption(device)) {}

~CUDAContext() override {
if (curand_generator_) {
Expand Down Expand Up @@ -385,19 +387,6 @@ class CAFFE2_CUDA_API CUDAStaticContext final : public BaseStaticContext {
public:
std::pair<void*, MemoryDeleter> New(size_t nbytes) const override;

std::unique_ptr<BaseContext> CreateContext() override {
return caffe2::make_unique<CUDAContext>();
}

std::unique_ptr<BaseContext> CreateContext(
const DeviceOption& option) override {
return caffe2::make_unique<CUDAContext>(option);
}

std::unique_ptr<BaseContext> CreateContext(int gpu_id = -1) {
return caffe2::make_unique<CUDAContext>(gpu_id);
}

DeviceType GetDeviceType() override {
return CUDA;
}
Expand Down
18 changes: 11 additions & 7 deletions caffe2/core/hip/context_hip.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ CAFFE2_DEFINE_int(caffe2_gpu_memory_report_interval_mb,
128,
"The threshold in MB on how frequently to report memory changes");

namespace at {

// Registers caffe2::HIPContext in ContextRegistry under DeviceType::HIP.
REGISTER_CONTEXT(DeviceType::HIP, caffe2::HIPContext);
} // namespace at

namespace caffe2 {

thread_local ThreadLocalHIPObjects HIPContext::hip_objects_;
Expand Down Expand Up @@ -408,13 +413,12 @@ void HIPStaticContext::Delete(void* ptr) {
g_hip_device_affiliation.erase(it);
break;
}
case HipMemoryPoolType::THC:
{
HIP_ENFORCE(g_thc_allocator->Free(ptr));
if (FLAGS_caffe2_gpu_memory_tracking) {
g_hip_device_affiliation.erase(g_hip_device_affiliation.find(ptr));
}
break;
case HipMemoryPoolType::THC: {
HIP_ENFORCE(g_thc_allocator->Free(ptr));
if (FLAGS_caffe2_gpu_memory_tracking) {
g_hip_device_affiliation.erase(g_hip_device_affiliation.find(ptr));
}
break;
}
}
}
Expand Down
15 changes: 2 additions & 13 deletions caffe2/core/hip/context_hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ class HIPContext final : public BaseContext {
// The default HIP context constructor.
explicit HIPContext(const int gpu_id = -1);
explicit HIPContext(const DeviceOption& option);
// Constructs the context from an at::Device: the device is converted with
// DeviceToOption and construction is delegated to another HIPContext
// constructor taking that result.
explicit HIPContext(const at::Device& device)
: HIPContext(DeviceToOption(device)) {}

~HIPContext() override {
if (hiprand_generator_) {
Expand Down Expand Up @@ -374,19 +376,6 @@ class HIPStaticContext final : public BaseStaticContext {
public:
std::pair<void*, MemoryDeleter> New(size_t nbytes) const override;

std::unique_ptr<BaseContext> CreateContext() override {
return caffe2::make_unique<HIPContext>();
}

std::unique_ptr<BaseContext> CreateContext(
const DeviceOption& option) override {
return caffe2::make_unique<HIPContext>(option);
}

std::unique_ptr<BaseContext> CreateContext(int gpu_id = -1) {
return caffe2::make_unique<HIPContext>(gpu_id);
}

DeviceType GetDeviceType() override {
return HIP;
}
Expand Down
8 changes: 4 additions & 4 deletions caffe2/core/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,14 @@ class CAFFE2_API Tensor final {
return impl_.get()->GetStaticContext();
}

std::unique_ptr<BaseContext> CreateContext() const {
return impl_.get()->CreateContext();
}

DeviceType GetDeviceType() const {
return impl_->device_type();
}

// Returns the device of this tensor by forwarding to the underlying
// implementation's GetDevice().
at::Device GetDevice() const {
return impl_.get()->GetDevice();
}

void CopyFrom(const Tensor& src, BaseContext* context = nullptr) const {
impl_.get()->CopyFrom(*src.impl_.get(), context);
}
Expand Down
1 change: 0 additions & 1 deletion caffe2/core/tensor_impl.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "caffe2/core/tensor_impl.h"

#include "caffe2/core/flags.h"

CAFFE2_DEFINE_bool(
Expand Down
11 changes: 2 additions & 9 deletions caffe2/ideep/utils/ideep_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class IDEEPContext final : public BaseContext {
: RandomNumberSeed()) {
CAFFE_ENFORCE_EQ(option.device_type(), PROTO_IDEEP);
}
// Constructs the context from an at::Device: the device is converted with
// DeviceToOption and construction is delegated to another IDEEPContext
// constructor taking that result.
explicit IDEEPContext(const at::Device& device)
: IDEEPContext(DeviceToOption(device)) {}

~IDEEPContext() noexcept override {}

Expand Down Expand Up @@ -178,15 +180,6 @@ class IDEEPStaticContext : public BaseStaticContext {
return GetCPUAllocator()->New(nbytes);
}

std::unique_ptr<BaseContext> CreateContext() override {
return caffe2::make_unique<IDEEPContext>();
}

std::unique_ptr<BaseContext> CreateContext(
const DeviceOption& option) override {
return caffe2::make_unique<IDEEPContext>(option);
}

DeviceType GetDeviceType() override {
return IDEEP;
}
Expand Down
3 changes: 3 additions & 0 deletions caffe2/ideep/utils/ideep_register.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include <ideep_pin_singletons.hpp>
#include "ideep_context.h"

namespace at {
// Registers caffe2::IDEEPContext in ContextRegistry under DeviceType::IDEEP.
REGISTER_CONTEXT(DeviceType::IDEEP, caffe2::IDEEPContext);
} // namespace at
namespace caffe2 {

CAFFE_KNOWN_TYPE(ideep::tensor);
Expand Down
4 changes: 4 additions & 0 deletions caffe2/mkl/utils/mkl_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#include "mkl_context.h"
#include "caffe2/core/event_cpu.h"

namespace at {

// Registers caffe2::MKLContext in ContextRegistry under DeviceType::MKLDNN.
REGISTER_CONTEXT(DeviceType::MKLDNN, caffe2::MKLContext);
} // namespace at
namespace caffe2 {

// MKL events are the same as CPU events
Expand Down
11 changes: 2 additions & 9 deletions caffe2/mkl/utils/mkl_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class MKLContext : public BaseContext {
: RandomNumberSeed()) {
CAFFE_ENFORCE_EQ(option.device_type(), PROTO_MKLDNN);
}
// Constructs the context from an at::Device: the device is converted with
// DeviceToOption and construction is delegated to another MKLContext
// constructor taking that result.
explicit MKLContext(const at::Device& device)
: MKLContext(DeviceToOption(device)) {}

~MKLContext() override {}

Expand Down Expand Up @@ -155,15 +157,6 @@ class MKLStaticContext : public BaseStaticContext {
return GetCPUAllocator()->New(nbytes);
}

std::unique_ptr<BaseContext> CreateContext() override {
return caffe2::make_unique<MKLContext>();
}

std::unique_ptr<BaseContext> CreateContext(
const DeviceOption& option) override {
return caffe2::make_unique<MKLContext>(option);
}

DeviceType GetDeviceType() override {
return MKLDNN;
}
Expand Down
Loading